diff --git a/cart-packet-queue.go b/cart-packet-queue.go new file mode 100644 index 0000000..384cae4 --- /dev/null +++ b/cart-packet-queue.go @@ -0,0 +1,94 @@ +package main + +import ( + "fmt" + "io" + "log" + "net" + "sync" + "time" +) + +type CartPacketWithData struct { + MessageType uint16 + Id CartId + Added time.Time + Consumed bool + Data []byte +} + +type CartPacketQueue struct { + mu sync.RWMutex + Packets []CartPacketWithData + connection net.Conn +} + +func NewCartPacketQueue(connection net.Conn) *CartPacketQueue { + + queue := &CartPacketQueue{ + Packets: make([]CartPacketWithData, 0), + connection: connection, + } + go func() { + defer connection.Close() + var packet CartPacket + for { + err := ReadPacket(queue.connection, &packet) + ts := time.Now() + if err != nil { + + if err == io.EOF { + return + } + log.Printf("Error receiving packet: %v\n", err) + //return + } + + data, err := GetPacketData(queue.connection, int(packet.DataLength)) + if err != nil { + log.Printf("Error receiving packet data: %v\n", err) + return + } + + queue.mu.Lock() + + l := make([]CartPacketWithData, 0, len(queue.Packets)) + + for _, packet := range queue.Packets { + if !packet.Consumed && packet.Added.After(ts.Add(-time.Second)) { + l = append(l, packet) + } + } + + queue.Packets = append(l, CartPacketWithData{ + MessageType: packet.MessageType, + Id: packet.Id, + Added: ts, + Data: data, + }) + queue.mu.Unlock() + + } + }() + return queue +} + +func (p *CartPacketQueue) Expect(messageType uint16, id CartId, timeToWait time.Duration) (*CartPacketWithData, error) { + start := time.Now().Add(-time.Millisecond) + + for { + if time.Since(start) > timeToWait { + return nil, fmt.Errorf("timeout waiting for message type %d", messageType) + } + p.mu.RLock() + for _, packet := range p.Packets { + if packet.MessageType == messageType && packet.Id == id && packet.Added.After(start) { + packet.Consumed = true + p.mu.RUnlock() + return &packet, nil + } + } + p.mu.RUnlock() + time.Sleep(time.Millisecond * 2) + } +} diff --git a/main.go b/main.go index cec630b..d469cb5 100644 --- a/main.go +++ b/main.go @@ -173,11 +173,10 @@ func main() { // if local //syncedPool.AddRemote("localhost") - rpcHandler, err := NewGrainHandler(app.pool, ":1337") + _, err = NewGrainHandler(app.pool, ":1337") if err != nil { log.Fatalf("Error creating handler: %v\n", err) } - go rpcHandler.Serve() go func() { for range time.Tick(time.Minute) { diff --git a/packet-queue.go b/packet-queue.go index faa365b..79712c0 100644 --- a/packet-queue.go +++ b/packet-queue.go @@ -30,18 +30,23 @@ func NewPacketQueue(connection net.Conn) *PacketQueue { } go func() { defer connection.Close() + var packet Packet for { - messageType, data, err := ReceivePacket(queue.connection) + err := ReadPacket(queue.connection, &packet) ts := time.Now() if err != nil { - log.Printf("Error receiving packet: %v\n", err) + if err == io.EOF { return } - + log.Printf("Error receiving packet: %v\n", err) //return } - + data, err := GetPacketData(queue.connection, int(packet.DataLength)) + if err != nil { + log.Printf("Error receiving packet data: %v\n", err) + return + } queue.mu.Lock() l := make([]PacketWithData, 0, len(queue.Packets)) @@ -53,7 +58,7 @@ func NewPacketQueue(connection net.Conn) *PacketQueue { } queue.Packets = append(l, PacketWithData{ - MessageType: messageType, + MessageType: packet.MessageType, Added: ts, Data: data, }) diff --git a/packet.go b/packet.go index a1e12be..75db01b 100644 --- a/packet.go +++ b/packet.go @@ -8,16 +8,18 @@ import ( ) const ( - RemoteGetState = uint16(0x01) - RemoteHandleMessage = uint16(0x02) - ResponseBody = uint16(0x03) + RemoteGetState = uint16(0x01) + RemoteHandleMessage = uint16(0x02) + ResponseBody = uint16(0x03) + RemoteGetStateReply = uint16(0x04) + RemoteHandleMessageReply = uint16(0x05) ) type CartPacket struct { Version uint16 MessageType uint16 - Id CartId DataLength uint16 + Id CartId } type Packet struct { @@ -37,8 +39,8 @@ func SendCartPacket(conn io.Writer, id CartId, messageType uint16, datafn func(w binary.Write(conn, binary.LittleEndian, CartPacket{ Version: 2, MessageType: messageType, - Id: id, DataLength: uint16(len(data)), + Id: id, }) _, err = conn.Write(data) return err @@ -86,19 +88,29 @@ func SendProxyResponse(conn io.Writer, data any) error { }) } +func ReadPacket[V Packet | CartPacket](conn io.Reader, packet *V) error { + return binary.Read(conn, binary.LittleEndian, packet) +} + +func GetPacketData(conn io.Reader, len int) ([]byte, error) { + data := make([]byte, len) + l, err := conn.Read(data) + if l != len { + return nil, fmt.Errorf("expected %d bytes, got %d", len, l) + } + return data, err +} + func ReceivePacket(conn io.Reader) (uint16, []byte, error) { var packet Packet - err := binary.Read(conn, binary.LittleEndian, &packet) + err := ReadPacket(conn, &packet) if err != nil { return packet.MessageType, nil, err } - data := make([]byte, packet.DataLength) - l, err := conn.Read(data) + + data, err := GetPacketData(conn, int(packet.DataLength)) if err != nil { return packet.MessageType, nil, err } - if l != int(packet.DataLength) { - return packet.MessageType, nil, fmt.Errorf("expected %d bytes, got %d", packet.DataLength, l) - } return packet.MessageType, data, nil } diff --git a/rpc-pool.go b/rpc-pool.go index aa78b88..c8a96f8 100644 --- a/rpc-pool.go +++ b/rpc-pool.go @@ -2,11 +2,9 @@ package main import ( "fmt" - "io" "net" "strings" "sync" - "time" ) type RemoteGrainPool struct { @@ -26,8 +24,7 @@ func ToCartId(id string) CartId { } type RemoteGrain struct { - net.Conn - *PacketQueue + *CartClient Id CartId Address string } @@ -59,19 +56,18 @@ func (g *RemoteGrain) Connect() error { } func (g *RemoteGrain) HandleMessage(message *Message, isReplay bool) ([]byte, error) { - err := g.Connect() + + data, err := GetData(message.Write) if err != nil { return nil, err } - err = SendCartPacket(g.connection, g.Id, RemoteHandleMessage, message.Write) + reply, err := g.Call(RemoteHandleMessage, g.Id, RemoteHandleMessageReply, data) + if err != nil { return nil, err } - packet, err := g.Expect(ResponseBody, time.Second) - if err != nil { - return nil, err - } - return packet.Data, err + + return reply, err } func (g *RemoteGrain) GetId() CartId { @@ -79,21 +75,7 @@ func (g *RemoteGrain) GetId() CartId { } func (g *RemoteGrain) GetCurrentState() ([]byte, error) { - err := g.Connect() - if err != nil { - return nil, err - } - err = SendCartPacket(g.connection, g.Id, RemoteGetState, func(w io.Writer) error { - return nil - }) - if err != nil { - return nil, err - } - packet, err := g.Expect(ResponseBody, time.Second) - if err != nil { - return nil, err - } - return packet.Data, nil + return g.Call(RemoteGetState, g.Id, RemoteGetStateReply, nil) } func NewRemoteGrainPool(addr string) *RemoteGrainPool { diff --git a/rpc-server.go b/rpc-server.go index e667f3a..d828452 100644 --- a/rpc-server.go +++ b/rpc-server.go @@ -1,15 +1,13 @@ package main import ( - "encoding/binary" + "bytes" "fmt" - "io" - "net" ) type GrainHandler struct { - listener net.Listener - pool *GrainLocalPool + *CartServer + pool *GrainLocalPool } func (h *GrainHandler) GetState(id CartId, reply *Grain) error { @@ -22,68 +20,83 @@ func (h *GrainHandler) GetState(id CartId, reply *Grain) error { } func NewGrainHandler(pool *GrainLocalPool, listen string) (*GrainHandler, error) { + server, err := CartListen(listen) handler := &GrainHandler{ - pool: pool, + CartServer: server, + pool: pool, } - l, err := net.Listen("tcp", listen) - handler.listener = l + server.HandleCall(RemoteHandleMessage, handler.RemoteHandleMessageHandler) + server.HandleCall(RemoteGetState, handler.RemoteGetStateHandler) return handler, err } -func (h *GrainHandler) Serve() { - for { - conn, err := h.listener.Accept() - if err != nil { - fmt.Println("Error accepting connection:", err) - continue - } - - go h.handleClient(conn) +func (h *GrainHandler) RemoteHandleMessageHandler(id CartId, data []byte) (uint16, []byte, error) { + var msg Message + err := ReadMessage(bytes.NewReader(data), &msg) + if err != nil { + fmt.Println("Error reading message:", err) + return RemoteHandleMessageReply, nil, err } + replyData, err := h.pool.Process(id, msg) + if err != nil { + fmt.Println("Error handling message:", err) + } + if err != nil { + return RemoteHandleMessageReply, nil, err + } + return RemoteHandleMessageReply, replyData, nil } -func (h *GrainHandler) handleClient(conn net.Conn) { - var err error - - defer conn.Close() - - var packet CartPacket - - for { - err = binary.Read(conn, binary.LittleEndian, &packet) - if err != nil { - if err == io.EOF { - break - } - fmt.Println("Error in connection:", err) - } - if packet.Version != 2 { - fmt.Printf("Unknown version %d", packet.Version) - break - } - - switch packet.MessageType { - case RemoteHandleMessage: - var msg Message - err = ReadMessage(conn, &msg) - if err != nil { - fmt.Println("Error reading message:", err) - } - - data, err := h.pool.Process(packet.Id, msg) - if err != nil { - fmt.Println("Error handling message:", err) - } - SendRawResponse(conn, data) - - case RemoteGetState: - data, err := h.pool.Get(packet.Id) - if err != nil { - fmt.Println("Error getting grain:", err) - } - SendRawResponse(conn, data) - } - +func (h *GrainHandler) RemoteGetStateHandler(id CartId, data []byte) (uint16, []byte, error) { + data, err := h.pool.Get(id) + if err != nil { + return RemoteGetStateReply, nil, err } - + return RemoteGetStateReply, data, nil } + +// func (h *GrainHandler) handleClient(conn net.Conn) { +// var err error + +// defer conn.Close() + +// var packet CartPacket + +// for { +// err = binary.Read(conn, binary.LittleEndian, &packet) +// if err != nil { +// if err == io.EOF { +// break +// } +// fmt.Println("Error in connection:", err) +// } +// if packet.Version != 2 { +// fmt.Printf("Unknown version %d", packet.Version) +// break +// } + +// switch packet.MessageType { +// case RemoteHandleMessage: +// var msg Message +// err = ReadMessage(conn, &msg) +// if err != nil { +// fmt.Println("Error reading message:", err) +// } + +// data, err := h.pool.Process(packet.Id, msg) +// if err != nil { +// fmt.Println("Error handling message:", err) +// } +// SendRawResponse(conn, data) + +// case RemoteGetState: +// data, err := h.pool.Get(packet.Id) +// if err != nil { +// fmt.Println("Error getting grain:", err) +// } +// SendRawResponse(conn, data) +// } + +// } + +// } diff --git a/tcp-cart-client.go b/tcp-cart-client.go new file mode 100644 index 0000000..94f1868 --- /dev/null +++ b/tcp-cart-client.go @@ -0,0 +1,104 @@ +package main + +import ( + "encoding/binary" + "io" + "net" + "time" +) + +type CartClient struct { + *CartTCPClient +} + +func CartDial(address string) (*CartClient, error) { + + mux, err := NewCartTCPClient(address) + if err != nil { + return nil, err + } + client := &CartClient{ + CartTCPClient: mux, + } + return client, nil +} + +func (c *Client) Close() { + c.Conn.Close() +} + +type CartTCPClient struct { + net.Conn + Errors chan error + ErrorCount int + address string + *PacketQueue +} + +func NewCartTCPClient(address string) (*CartTCPClient, error) { + connection, err := net.Dial("tcp", address) + if err != nil { + return nil, err + } + return &CartTCPClient{ + Errors: make(chan error), + ErrorCount: 0, + Conn: connection, + address: address, + PacketQueue: NewPacketQueue(connection), + }, nil +} + +func (m *CartTCPClient) Connect() error { + if m.Conn == nil { + connection, err := net.Dial("tcp", m.address) + if err != nil { + + m.Errors <- err + m.ErrorCount++ + + return err + } + m.ErrorCount = 0 + m.Conn = connection + } + return nil +} + +func (m *CartTCPClient) SendPacket(messageType uint16, id CartId, data []byte) error { + err := m.Connect() + if err != nil { + return err + } + err = binary.Write(m.Conn, binary.LittleEndian, CartPacket{ + Version: 1, + MessageType: messageType, + DataLength: uint16(len(data)), + Id: id, + }) + if err != nil { + return err + } + _, err = m.Conn.Write(data) + return err +} + +func (m *CartTCPClient) SendPacketFn(messageType uint16, id CartId, datafn func(w io.Writer) error) error { + data, err := GetData(datafn) + if err != nil { + return err + } + return m.SendPacket(messageType, id, data) +} + +func (m *CartTCPClient) Call(messageType uint16, id CartId, responseType uint16, data []byte) ([]byte, error) { + err := m.SendPacket(messageType, id, data) + if err != nil { + return nil, err + } + packet, err := m.Expect(responseType, time.Second) + if err != nil { + return nil, err + } + return packet.Data, nil +} diff --git a/tcp-cart-mux-server.go b/tcp-cart-mux-server.go new file mode 100644 index 0000000..40ebe9d --- /dev/null +++ b/tcp-cart-mux-server.go @@ -0,0 +1,138 @@ +package main + +import ( + "encoding/binary" + "io" + "log" + "net" + "sync" +) + +type CartServer struct { + *TCPCartServerMux +} + +func CartListen(address string) (*CartServer, error) { + listener, err := net.Listen("tcp", address) + server := &CartServer{ + NewCartTCPServerMux(100), + } + + if err != nil { + return nil, err + } + go func() { + for { + conn, err := listener.Accept() + if err != nil { + log.Printf("Error accepting connection: %v\n", err) + continue + } + go server.HandleConnection(conn) + } + }() + return server, nil +} + +type TCPCartServerMux struct { + mu sync.RWMutex + listeners map[uint16]func(CartId, []byte) error + functions map[uint16]func(CartId, []byte) (uint16, []byte, error) +} + +func NewCartTCPServerMux(maxClients int) *TCPCartServerMux { + m := &TCPCartServerMux{ + mu: sync.RWMutex{}, + listeners: make(map[uint16]func(CartId, []byte) error), + functions: make(map[uint16]func(CartId, []byte) (uint16, []byte, error)), + } + + return m +} + +func (m *TCPCartServerMux) handleListener(messageType uint16, id CartId, data []byte) (bool, error) { + m.mu.RLock() + handler, ok := m.listeners[messageType] + m.mu.RUnlock() + if ok { + err := handler(id, data) + if err != nil { + return true, err + } + } + return false, nil +} + +func (m *TCPCartServerMux) handleFunction(connection net.Conn, messageType uint16, id CartId, data []byte) (bool, error) { + m.mu.RLock() + function, ok := m.functions[messageType] + m.mu.RUnlock() + if ok { + responseType, responseData, err := function(id, data) + if err != nil { + return true, err + } + err = binary.Write(connection, binary.LittleEndian, CartPacket{ + Version: 1, + MessageType: responseType, + DataLength: uint16(len(responseData)), + Id: id, + }) + if err != nil { + return true, err + } + packetsSent.Inc() + _, err = connection.Write(responseData) + return true, err + } + return false, nil +} + +func (m *TCPCartServerMux) HandleConnection(connection net.Conn) error { + var packet *CartPacket + var err error + defer connection.Close() + for { + err = ReadPacket(connection, packet) + if err != nil { + if err == io.EOF { + return nil + } + log.Printf("Error receiving packet: %v\n", err) + return err + } + if packet == nil { + log.Println("Packet is nil") + continue + } + data, err := GetPacketData(connection, int(packet.DataLength)) + if err != nil { + log.Printf("Error getting packet data: %v\n", err) + } + status, err := m.handleListener(packet.MessageType, packet.Id, data) + if err != nil { + log.Printf("Error handling listener: %v\n", err) + } + if !status { + status, err = m.handleFunction(connection, packet.MessageType, packet.Id, data) + if err != nil { + log.Printf("Error handling function: %v\n", err) + } + if !status { + log.Printf("Unknown message type: %d\n", packet.MessageType) + } + } + } +} + +func (m *TCPCartServerMux) ListenFor(messageType uint16, handler func(CartId, []byte) error) { + m.mu.Lock() + m.listeners[messageType] = handler + m.mu.Unlock() +} + +func (m *TCPCartServerMux) HandleCall(messageType uint16, handler func(CartId, []byte) (uint16, []byte, error)) { + m.mu.Lock() + m.functions[messageType] = handler + m.mu.Unlock() +} diff --git a/tcp-mux-client.go b/tcp-client.go similarity index 69% rename from tcp-mux-client.go rename to tcp-client.go index 394e7fe..24c4545 100644 --- a/tcp-mux-client.go +++ b/tcp-client.go @@ -8,26 +8,22 @@ import ( ) type Client struct { - *TCPClientMux + *TCPClient } func Dial(address string) (*Client, error) { - mux, err := NewTCPClientMux(address) + mux, err := NewTCPClient(address) if err != nil { return nil, err } client := &Client{ - TCPClientMux: mux, + TCPClient: mux, } return client, nil } -func (c *Client) Close() { - c.Conn.Close() -} - -type TCPClientMux struct { +type TCPClient struct { net.Conn Errors chan error ErrorCount int @@ -35,12 +31,12 @@ type TCPClientMux struct { *PacketQueue } -func NewTCPClientMux(address string) (*TCPClientMux, error) { +func NewTCPClient(address string) (*TCPClient, error) { connection, err := net.Dial("tcp", address) if err != nil { return nil, err } - return &TCPClientMux{ + return &TCPClient{ Errors: make(chan error), ErrorCount: 0, Conn: connection, @@ -49,7 +45,7 @@ func NewTCPClientMux(address string) (*TCPClientMux, error) { }, nil } -func (m *TCPClientMux) Connect() error { +func (m *TCPClient) Connect() error { if m.Conn == nil { connection, err := net.Dial("tcp", m.address) if err != nil { @@ -65,11 +61,11 @@ func (m *TCPClientMux) Connect() error { return nil } -func (m *TCPClientMux) Close() { +func (m *TCPClient) Close() { m.Conn.Close() } -func (m *TCPClientMux) SendPacket(messageType uint16, data []byte) error { +func (m *TCPClient) SendPacket(messageType uint16, data []byte) error { err := m.Connect() if err != nil { return err @@ -86,7 +82,7 @@ func (m *TCPClientMux) SendPacket(messageType uint16, data []byte) error { return err } -func (m *TCPClientMux) SendPacketFn(messageType uint16, datafn func(w io.Writer) error) error { +func (m *TCPClient) SendPacketFn(messageType uint16, datafn func(w io.Writer) error) error { data, err := GetData(datafn) if err != nil { return err @@ -94,7 +90,7 @@ func (m *TCPClientMux) SendPacketFn(messageType uint16, datafn func(w io.Writer) return m.SendPacket(messageType, data) } -func (m *TCPClientMux) Call(messageType uint16, responseType uint16, data []byte) ([]byte, error) { +func (m *TCPClient) Call(messageType uint16, responseType uint16, data []byte) ([]byte, error) { err := m.SendPacket(messageType, data) if err != nil { return nil, err diff --git a/tcp-mux-server.go b/tcp-mux-server.go index 089e0c3..ac8ccbc 100644 --- a/tcp-mux-server.go +++ b/tcp-mux-server.go @@ -35,18 +35,16 @@ func Listen(address string) (*Server, error) { } type TCPServerMux struct { - mu sync.RWMutex - listeners map[uint16]func(data []byte) error - functions map[uint16]func(data []byte) (uint16, []byte, error) - connections []net.Conn + mu sync.RWMutex + listeners map[uint16]func(data []byte) error + functions map[uint16]func(data []byte) (uint16, []byte, error) } func NewTCPServerMux(maxClients int) *TCPServerMux { m := &TCPServerMux{ - connections: make([]net.Conn, 0, maxClients), - mu: sync.RWMutex{}, - listeners: make(map[uint16]func(data []byte) error), - functions: make(map[uint16]func(data []byte) (uint16, []byte, error)), + mu: sync.RWMutex{}, + listeners: make(map[uint16]func(data []byte) error), + functions: make(map[uint16]func(data []byte) (uint16, []byte, error)), } return m @@ -90,9 +88,7 @@ func (m *TCPServerMux) handleFunction(connection net.Conn, messageType uint16, d } func (m *TCPServerMux) HandleConnection(connection net.Conn) error { - m.mu.Lock() - m.connections = append(m.connections, connection) - m.mu.Unlock() + defer connection.Close() for { messageType, data, err := ReceivePacket(connection)