From 0b290a32bf1617ce74449b7765efb46784788fb8 Mon Sep 17 00:00:00 2001 From: matst80 Date: Mon, 11 Nov 2024 23:24:03 +0100 Subject: [PATCH] implement statuscode in packets --- cart-grain.go | 21 ++++-- cart-grain_test.go | 16 ++--- cart-packet-queue.go | 19 +++-- grain-pool.go | 4 +- packet-queue.go | 36 +++++----- packet.go | 8 ++- pool-server.go | 9 ++- remote-grain-pool.go | 67 ++++++++++++++++++ remote-grain.go | 91 ++++++++++++++++++++++++ rpc-pool.go | 154 ----------------------------------------- synced-pool.go | 30 ++++---- tcp-cart-client.go | 6 +- tcp-cart-mux-server.go | 18 ++++- tcp-cart_test.go | 16 ++++- tcp-client.go | 7 +- tcp-mux-server.go | 15 +++- tcp_test.go | 4 +- 17 files changed, 295 insertions(+), 226 deletions(-) create mode 100644 remote-grain-pool.go create mode 100644 remote-grain.go delete mode 100644 rpc-pool.go diff --git a/cart-grain.go b/cart-grain.go index 7f0d4b1..0b74014 100644 --- a/cart-grain.go +++ b/cart-grain.go @@ -54,7 +54,8 @@ type CartGrain struct { type Grain interface { GetId() CartId - HandleMessage(message *Message, isReplay bool) ([]byte, error) + HandleMessage(message *Message, isReplay bool) (*CallResult, error) + GetCurrentState() (*CallResult, error) } func (c *CartGrain) GetId() CartId { @@ -68,6 +69,14 @@ func (c *CartGrain) GetLastChange() int64 { return *c.storageMessages[len(c.storageMessages)-1].TimeStamp } +func (c *CartGrain) GetCurrentState() (*CallResult, error) { + result, err := json.Marshal(c) + return &CallResult{ + StatusCode: 200, + Data: result, + }, err +} + func getItemData(sku string, qty int) (*messages.AddItem, error) { item, err := FetchItem(sku) if err != nil { @@ -99,7 +108,7 @@ func getItemData(sku string, qty int) (*messages.AddItem, error) { }, nil } -func (c *CartGrain) AddItem(sku string, qty int) ([]byte, error) { +func (c *CartGrain) AddItem(sku string, qty int) (*CallResult, error) { cartItem, err := getItemData(sku, qty) if err != nil { return nil, err @@ -171,7 +180,7 @@ func (c *CartGrain) FindItemWithSku(sku string) (*CartItem, bool) { return nil, false } -func (c *CartGrain) HandleMessage(message *Message, isReplay bool) ([]byte, error) { +func (c *CartGrain) HandleMessage(message *Message, isReplay bool) (*CallResult, error) { if message.TimeStamp == nil { now := time.Now().Unix() message.TimeStamp = &now @@ -294,5 +303,9 @@ func (c *CartGrain) HandleMessage(message *Message, isReplay bool) ([]byte, erro c.storageMessages = append(c.storageMessages, *message) c.mu.Unlock() } - return json.Marshal(c) + result, err := json.Marshal(c) + return &CallResult{ + StatusCode: 200, + Data: result, + }, err } diff --git a/cart-grain_test.go b/cart-grain_test.go index e16df20..162c25d 100644 --- a/cart-grain_test.go +++ b/cart-grain_test.go @@ -88,8 +88,8 @@ func TestAddToCart(t *testing.T) { if err != nil { t.Errorf("Error handling message: %v\n", err) } - if len(result) == 0 { - t.Errorf("Expected result, got nil\n") + if result.StatusCode != 200 { + t.Errorf("Call failed\n") } if grain.TotalPrice != 200 { t.Errorf("Expected total price 200, got %d\n", grain.TotalPrice) @@ -104,8 +104,8 @@ func TestAddToCart(t *testing.T) { if err != nil { t.Errorf("Error handling message: %v\n", err) } - if len(result) == 0 { - t.Errorf("Expected result, got nil\n") + if result.StatusCode != 200 { + t.Errorf("Call failed\n") } if grain.Items[0].Quantity != 4 { t.Errorf("Expected quantity 4, got %d\n", grain.Items[0].Quantity) @@ -146,8 +146,8 @@ func TestSetDelivery(t *testing.T) { if err != nil { t.Errorf("Error handling message: %v\n", err) } - if len(result) == 0 { - t.Errorf("Expected result, got nil\n") + if result.StatusCode != 200 { + t.Errorf("Call failed\n") } setDelivery := GetMessage(SetDeliveryType, &messages.SetDelivery{ @@ -198,8 +198,8 @@ func TestSetDeliveryOnAll(t *testing.T) { if err != nil { t.Errorf("Error handling message: %v\n", err) } - if len(result) == 0 { - t.Errorf("Expected result, got nil\n") + if result.StatusCode != 200 { + t.Errorf("Call failed\n") } setDelivery := GetMessage(SetDeliveryType, &messages.SetDelivery{ diff --git a/cart-packet-queue.go b/cart-packet-queue.go index bca122f..4669719 100644 --- a/cart-packet-queue.go +++ b/cart-packet-queue.go @@ -43,7 +43,10 @@ func (p *CartPacketQueue) HandleConnection(connection net.Conn) error { continue } if packet.DataLength == 0 { - go p.HandleData(packet.MessageType, packet.Id, []byte{}) + go p.HandleData(packet.MessageType, packet.Id, CallResult{ + StatusCode: packet.StatusCode, + Data: []byte{}, + }) continue } data, err := GetPacketData(connection, packet.DataLength) @@ -51,11 +54,14 @@ func (p *CartPacketQueue) HandleConnection(connection net.Conn) error { log.Printf("Error receiving packet data: %v\n", err) return err } - go p.HandleData(packet.MessageType, packet.Id, data) + go p.HandleData(packet.MessageType, packet.Id, CallResult{ + StatusCode: packet.StatusCode, + Data: data, + }) } } -func (p *CartPacketQueue) HandleData(t uint32, id CartId, data []byte) { +func (p *CartPacketQueue) HandleData(t uint32, id CartId, data CallResult) { p.mu.Lock() defer p.mu.Unlock() pl, ok := p.expectedPackages[t] @@ -70,10 +76,9 @@ func (p *CartPacketQueue) HandleData(t uint32, id CartId, data []byte) { } } } - data = nil } -func (p *CartPacketQueue) Expect(messageType uint32, id CartId) <-chan []byte { +func (p *CartPacketQueue) Expect(messageType uint32, id CartId) <-chan CallResult { p.mu.Lock() defer p.mu.Unlock() l, ok := p.expectedPackages[messageType] @@ -82,7 +87,7 @@ func (p *CartPacketQueue) Expect(messageType uint32, id CartId) <-chan []byte { idl.Count++ return idl.Chan } - ch := make(chan []byte) + ch := make(chan CallResult) (*l)[id] = Listener{ Chan: ch, Count: 1, @@ -90,7 +95,7 @@ func (p *CartPacketQueue) Expect(messageType uint32, id CartId) <-chan []byte { return ch } - ch := make(chan []byte) + ch := make(chan CallResult) p.expectedPackages[messageType] = &CartListener{ id: Listener{ Chan: ch, diff --git a/grain-pool.go b/grain-pool.go index 9a26804..081e2e5 100644 --- a/grain-pool.go +++ b/grain-pool.go @@ -27,8 +27,8 @@ var ( ) type GrainPool interface { - Process(id CartId, messages ...Message) ([]byte, error) - Get(id CartId) ([]byte, error) + Process(id CartId, messages ...Message) (*CallResult, error) + Get(id CartId) (*CallResult, error) } type Ttl struct { diff --git a/packet-queue.go b/packet-queue.go index 9e8fd09..654c403 100644 --- a/packet-queue.go +++ b/packet-queue.go @@ -7,33 +7,24 @@ import ( "sync" ) -// type PacketWithData struct { -// MessageType uint32 -// Added time.Time -// Consumed bool -// Data []byte -// } - type PacketQueue struct { mu sync.RWMutex expectedPackages map[uint32]*Listener - //Packets []PacketWithData - //connection net.Conn } -//const cap = 150 +type CallResult struct { + StatusCode uint32 + Data []byte +} type Listener struct { Count int - Chan chan []byte + Chan chan CallResult } func NewPacketQueue(connection net.Conn) *PacketQueue { - queue := &PacketQueue{ expectedPackages: make(map[uint32]*Listener), - //Packets: make([]PacketWithData, 0, cap+1), - //connection: connection, } go queue.HandleConnection(connection) return queue @@ -57,7 +48,10 @@ func (p *PacketQueue) HandleConnection(connection net.Conn) error { continue } if packet.DataLength == 0 { - go p.HandleData(packet.MessageType, []byte{}) + go p.HandleData(packet.MessageType, CallResult{ + StatusCode: packet.StatusCode, + Data: []byte{}, + }) continue } data, err := GetPacketData(connection, packet.DataLength) @@ -65,12 +59,15 @@ func (p *PacketQueue) HandleConnection(connection net.Conn) error { log.Printf("Error receiving packet data: %v\n", err) //return err } else { - go p.HandleData(packet.MessageType, data) + go p.HandleData(packet.MessageType, CallResult{ + StatusCode: packet.StatusCode, + Data: data, + }) } } } -func (p *PacketQueue) HandleData(t uint32, data []byte) { +func (p *PacketQueue) HandleData(t uint32, data CallResult) { p.mu.Lock() defer p.mu.Unlock() l, ok := p.expectedPackages[t] @@ -83,10 +80,9 @@ func (p *PacketQueue) HandleData(t uint32, data []byte) { } return } - data = nil } -func (p *PacketQueue) Expect(messageType uint32) <-chan []byte { +func (p *PacketQueue) Expect(messageType uint32) <-chan CallResult { p.mu.Lock() defer p.mu.Unlock() l, ok := p.expectedPackages[messageType] @@ -95,7 +91,7 @@ func (p *PacketQueue) Expect(messageType uint32) <-chan []byte { return l.Chan } - ch := make(chan []byte) + ch := make(chan CallResult) p.expectedPackages[messageType] = &Listener{ Count: 1, Chan: ch, diff --git a/packet.go b/packet.go index 29a2f55..8438f60 100644 --- a/packet.go +++ b/packet.go @@ -16,14 +16,16 @@ const ( type CartPacket struct { Version uint32 MessageType uint32 - DataLength uint64 + DataLength uint32 + StatusCode uint32 Id CartId } type Packet struct { Version uint32 MessageType uint32 - DataLength uint64 + DataLength uint32 + StatusCode uint32 } func ReadPacket(conn io.Reader, packet *Packet) error { @@ -34,7 +36,7 @@ func ReadCartPacket(conn io.Reader, packet *CartPacket) error { return binary.Read(conn, binary.LittleEndian, packet) } -func GetPacketData(conn io.Reader, len uint64) ([]byte, error) { +func GetPacketData(conn io.Reader, len uint32) ([]byte, error) { if len == 0 { return []byte{}, nil } diff --git a/pool-server.go b/pool-server.go index 29f8180..4b46f40 100644 --- a/pool-server.go +++ b/pool-server.go @@ -52,11 +52,16 @@ func ErrorHandler(fn func(w http.ResponseWriter, r *http.Request) error) func(w } } -func (s *PoolServer) WriteResult(w http.ResponseWriter, data []byte) error { +func (s *PoolServer) WriteResult(w http.ResponseWriter, result *CallResult) error { w.Header().Set("Content-Type", "application/json") w.Header().Set("X-Pod-Name", s.pod_name) + if result.StatusCode != 200 { + w.WriteHeader(int(result.StatusCode)) + w.Write([]byte(result.Data)) + return nil + } w.WriteHeader(http.StatusOK) - _, err := w.Write(data) + _, err := w.Write(result.Data) return err } diff --git a/remote-grain-pool.go b/remote-grain-pool.go new file mode 100644 index 0000000..8a678b2 --- /dev/null +++ b/remote-grain-pool.go @@ -0,0 +1,67 @@ +package main + +import "sync" + +type RemoteGrainPool struct { + mu sync.RWMutex + Host string + grains map[CartId]*RemoteGrain +} + +func NewRemoteGrainPool(addr string) *RemoteGrainPool { + return &RemoteGrainPool{ + Host: addr, + grains: make(map[CartId]*RemoteGrain), + } +} + +func (p *RemoteGrainPool) findRemoteGrain(id CartId) *RemoteGrain { + p.mu.RLock() + grain, ok := p.grains[id] + p.mu.RUnlock() + if !ok { + return nil + } + return grain +} + +func (p *RemoteGrainPool) findOrCreateGrain(id CartId) (*RemoteGrain, error) { + grain := p.findRemoteGrain(id) + + if grain == nil { + grain, err := NewRemoteGrain(id, p.Host) + if err != nil { + return nil, err + } + p.mu.Lock() + p.grains[id] = grain + p.mu.Unlock() + } + return grain, nil +} + +func (p *RemoteGrainPool) Delete(id CartId) { + p.mu.Lock() + delete(p.grains, id) + p.mu.Unlock() +} + +func (p *RemoteGrainPool) Process(id CartId, messages ...Message) (*CallResult, error) { + var result *CallResult + grain, err := p.findOrCreateGrain(id) + if err != nil { + return nil, err + } + for _, message := range messages { + result, err = grain.HandleMessage(&message, false) + } + return result, err +} + +func (p *RemoteGrainPool) Get(id CartId) (*CallResult, error) { + grain, err := p.findOrCreateGrain(id) + if err != nil { + return nil, err + } + return grain.GetCurrentState() +} diff --git a/remote-grain.go b/remote-grain.go new file mode 100644 index 0000000..ed62981 --- /dev/null +++ b/remote-grain.go @@ -0,0 +1,91 @@ +package main + +import ( + "fmt" + "strings" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +func (id CartId) String() string { + return strings.Trim(string(id[:]), "\x00") +} + +func ToCartId(id string) CartId { + var result [16]byte + copy(result[:], []byte(id)) + return result +} + +type RemoteGrain struct { + *CartClient + Id CartId + Host string +} + +func NewRemoteGrain(id CartId, host string) (*RemoteGrain, error) { + client, err := CartDial(fmt.Sprintf("%s:1337", host)) + if err != nil { + return nil, err + } + + return &RemoteGrain{ + Id: id, + Host: host, + CartClient: client, + }, nil +} + +var ( + remoteCartLatency = promauto.NewCounter(prometheus.CounterOpts{ + Name: "cart_remote_grain_calls_total_latency", + Help: "The total latency of remote grains", + }) + remoteCartCallsTotal = promauto.NewCounter(prometheus.CounterOpts{ + Name: "cart_remote_grain_calls_total", + Help: "The total number of calls to remote grains", + }) +) + +var start time.Time + +func MeasureLatency(fn func() (*CallResult, error)) (*CallResult, error) { + start = time.Now() + data, err := fn() + if err != nil { + return data, err + } + elapsed := time.Since(start).Milliseconds() + go func() { + remoteCartLatency.Add(float64(elapsed)) + remoteCartCallsTotal.Inc() + }() + return data, nil +} + +func (g *RemoteGrain) HandleMessage(message *Message, isReplay bool) (*CallResult, error) { + + data, err := GetData(message.Write) + if err != nil { + return nil, err + } + reply, err := MeasureLatency(func() (*CallResult, error) { + return g.Call(RemoteHandleMutation, g.Id, RemoteHandleMutationReply, data) + }) + + if err != nil { + return nil, err + } + + return reply, err +} + +func (g *RemoteGrain) GetId() CartId { + return g.Id +} + +func (g *RemoteGrain) GetCurrentState() (*CallResult, error) { + return MeasureLatency(func() (*CallResult, error) { return g.Call(RemoteGetState, g.Id, RemoteGetStateReply, []byte{}) }) +} diff --git a/rpc-pool.go b/rpc-pool.go deleted file mode 100644 index 3819546..0000000 --- a/rpc-pool.go +++ /dev/null @@ -1,154 +0,0 @@ -package main - -import ( - "fmt" - "strings" - "sync" - "time" - - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" -) - -type RemoteGrainPool struct { - mu sync.RWMutex - Host string - grains map[CartId]*RemoteGrain -} - -func (id CartId) String() string { - return strings.Trim(string(id[:]), "\x00") -} - -func ToCartId(id string) CartId { - var result [16]byte - copy(result[:], []byte(id)) - return result -} - -type RemoteGrain struct { - *CartClient - Id CartId - Host string -} - -func NewRemoteGrain(id CartId, host string) (*RemoteGrain, error) { - client, err := CartDial(fmt.Sprintf("%s:1337", host)) - if err != nil { - return nil, err - } - - return &RemoteGrain{ - Id: id, - Host: host, - CartClient: client, - }, nil -} - -var ( - remoteCartLatency = promauto.NewCounter(prometheus.CounterOpts{ - Name: "cart_remote_grain_calls_total_latency", - Help: "The total latency of remote grains", - }) - remoteCartCallsTotal = promauto.NewCounter(prometheus.CounterOpts{ - Name: "cart_remote_grain_calls_total", - Help: "The total number of calls to remote grains", - }) -) - -var start time.Time - -func MeasureLatency(fn func() ([]byte, error)) ([]byte, error) { - start = time.Now() - data, err := fn() - if err != nil { - return data, err - } - elapsed := time.Since(start).Milliseconds() - go func() { - remoteCartLatency.Add(float64(elapsed)) - remoteCartCallsTotal.Inc() - }() - return data, nil -} - -func (g *RemoteGrain) HandleMessage(message *Message, isReplay bool) ([]byte, error) { - - data, err := GetData(message.Write) - if err != nil { - return nil, err - } - reply, err := MeasureLatency(func() ([]byte, error) { return g.Call(RemoteHandleMutation, g.Id, RemoteHandleMutationReply, data) }) - - if err != nil { - return nil, err - } - - return reply, err -} - -func (g *RemoteGrain) GetId() CartId { - return g.Id -} - -func (g *RemoteGrain) GetCurrentState() ([]byte, error) { - return MeasureLatency(func() ([]byte, error) { return g.Call(RemoteGetState, g.Id, RemoteGetStateReply, []byte{}) }) -} - -func NewRemoteGrainPool(addr string) *RemoteGrainPool { - return &RemoteGrainPool{ - Host: addr, - grains: make(map[CartId]*RemoteGrain), - } -} - -func (p *RemoteGrainPool) findRemoteGrain(id CartId) *RemoteGrain { - p.mu.RLock() - grain, ok := p.grains[id] - p.mu.RUnlock() - if !ok { - return nil - } - return grain -} - -func (p *RemoteGrainPool) findOrCreateGrain(id CartId) (*RemoteGrain, error) { - grain := p.findRemoteGrain(id) - - if grain == nil { - grain, err := NewRemoteGrain(id, p.Host) - if err != nil { - return nil, err - } - p.mu.Lock() - p.grains[id] = grain - p.mu.Unlock() - } - return grain, nil -} - -func (p *RemoteGrainPool) Delete(id CartId) { - p.mu.Lock() - delete(p.grains, id) - p.mu.Unlock() -} - -func (p *RemoteGrainPool) Process(id CartId, messages ...Message) ([]byte, error) { - var result []byte - grain, err := p.findOrCreateGrain(id) - if err != nil { - return nil, err - } - for _, message := range messages { - result, err = grain.HandleMessage(&message, false) - } - return result, err -} - -func (p *RemoteGrainPool) Get(id CartId) ([]byte, error) { - grain, err := p.findOrCreateGrain(id) - if err != nil { - return nil, err - } - return grain.GetCurrentState() -} diff --git a/synced-pool.go b/synced-pool.go index 6bc1a7b..2abc216 100644 --- a/synced-pool.go +++ b/synced-pool.go @@ -1,7 +1,6 @@ package main import ( - "encoding/json" "fmt" "log" "strings" @@ -26,7 +25,6 @@ type RemoteHost struct { *Client Host string MissedPings int - //Pool *RemoteGrainPool } type SyncedPool struct { @@ -248,21 +246,27 @@ const ( ) func (h *RemoteHost) Negotiate(knownHosts []string) ([]string, error) { - data, err := h.Call(RemoteNegotiate, RemoteNegotiateResponse, []byte(strings.Join(knownHosts, ";"))) + reply, err := h.Call(RemoteNegotiate, RemoteNegotiateResponse, []byte(strings.Join(knownHosts, ";"))) if err != nil { return nil, err } + if reply.StatusCode != 200 { + return nil, fmt.Errorf("remote returned error on negotiate: %s", string(reply.Data)) + } - return strings.Split(string(data), ";"), nil + return strings.Split(string(reply.Data), ";"), nil } func (g *RemoteHost) GetCartMappings() ([]CartId, error) { - data, err := g.Call(GetCartIds, CartIdsResponse, []byte{}) + reply, err := g.Call(GetCartIds, CartIdsResponse, []byte{}) if err != nil { return nil, err } - parts := strings.Split(string(data), ";") + if reply.StatusCode != 200 { + return nil, fmt.Errorf("remote returned error: %s", string(reply.Data)) + } + parts := strings.Split(string(reply.Data), ";") ids := make([]CartId, 0, len(parts)) for _, p := range parts { ids = append(ids, ToCartId(p)) @@ -289,13 +293,13 @@ func (p *SyncedPool) Negotiate(knownHosts []string) ([]string, error) { } func (r *RemoteHost) ConfirmChange(id CartId, host string) error { - data, err := r.Call(RemoteGrainChanged, AckChange, []byte(fmt.Sprintf("%s;%s", id, host))) + reply, err := r.Call(RemoteGrainChanged, AckChange, []byte(fmt.Sprintf("%s;%s", id, host))) if err != nil { return err } - if string(data) != "ok" { - return fmt.Errorf("remote grain change failed %s", string(data)) + if string(reply.Data) != "ok" { + return fmt.Errorf("remote grain change failed %s", string(reply.Data)) } return nil @@ -443,9 +447,9 @@ func (p *SyncedPool) getGrain(id CartId) (Grain, error) { return localGrain, nil } -func (p *SyncedPool) Process(id CartId, messages ...Message) ([]byte, error) { +func (p *SyncedPool) Process(id CartId, messages ...Message) (*CallResult, error) { pool, err := p.getGrain(id) - var res []byte + var res *CallResult if err != nil { return nil, err } @@ -458,7 +462,7 @@ func (p *SyncedPool) Process(id CartId, messages ...Message) ([]byte, error) { return res, nil } -func (p *SyncedPool) Get(id CartId) ([]byte, error) { +func (p *SyncedPool) Get(id CartId) (*CallResult, error) { grain, err := p.getGrain(id) if err != nil { return nil, err @@ -467,5 +471,5 @@ func (p *SyncedPool) Get(id CartId) ([]byte, error) { return remoteGrain.GetCurrentState() } - return json.Marshal(grain) + return grain.GetCurrentState() } diff --git a/tcp-cart-client.go b/tcp-cart-client.go index 47fb459..28ef304 100644 --- a/tcp-cart-client.go +++ b/tcp-cart-client.go @@ -73,7 +73,7 @@ func (m *CartTCPClient) SendPacket(messageType uint32, id CartId, data []byte) e err = binary.Write(m.Conn, binary.LittleEndian, CartPacket{ Version: CurrentPacketVersion, MessageType: messageType, - DataLength: uint64(len(data)), + DataLength: uint32(len(data)), Id: id, }) if err != nil { @@ -91,7 +91,7 @@ func (m *CartTCPClient) SendPacket(messageType uint32, id CartId, data []byte) e // return m.SendPacket(messageType, id, data) // } -func (m *CartTCPClient) Call(messageType uint32, id CartId, responseType uint32, data []byte) ([]byte, error) { +func (m *CartTCPClient) Call(messageType uint32, id CartId, responseType uint32, data []byte) (*CallResult, error) { packetChan := m.Expect(responseType, id) err := m.SendPacket(messageType, id, data) if err != nil { @@ -99,7 +99,7 @@ func (m *CartTCPClient) Call(messageType uint32, id CartId, responseType uint32, } select { case ret := <-packetChan: - return ret, nil + return &ret, nil case <-time.After(time.Second): return nil, fmt.Errorf("timeout") } diff --git a/tcp-cart-mux-server.go b/tcp-cart-mux-server.go index 2729ff2..3b3ae8d 100644 --- a/tcp-cart-mux-server.go +++ b/tcp-cart-mux-server.go @@ -69,13 +69,24 @@ func (m *TCPCartServerMux) handleFunction(connection net.Conn, messageType uint3 m.mu.RUnlock() if ok { responseType, responseData, err := fn(id, data) + if err != nil { + errData := []byte(err.Error()) + err = binary.Write(connection, binary.LittleEndian, CartPacket{ + Version: CurrentPacketVersion, + MessageType: responseType, + DataLength: uint32(len(errData)), + StatusCode: 500, + Id: id, + }) + _, err = connection.Write(errData) return true, err } err = binary.Write(connection, binary.LittleEndian, CartPacket{ Version: CurrentPacketVersion, MessageType: responseType, - DataLength: uint64(len(responseData)), + DataLength: uint32(len(responseData)), + StatusCode: 200, Id: id, }) if err != nil { @@ -101,7 +112,10 @@ func (m *TCPCartServerMux) HandleConnection(connection net.Conn) error { log.Printf("Error receiving packet: %v\n", err) return err } - + if packet.Version != CurrentPacketVersion { + log.Printf("Incorrect packet version: %d\n", packet.Version) + continue + } data, err := GetPacketData(connection, packet.DataLength) if err != nil { log.Printf("Error getting packet data: %v\n", err) diff --git a/tcp-cart_test.go b/tcp-cart_test.go index 589945e..e43b672 100644 --- a/tcp-cart_test.go +++ b/tcp-cart_test.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "log" "testing" ) @@ -21,6 +22,10 @@ func TestCartTcpHelpers(t *testing.T) { messageData = string(data) return nil }) + server.HandleCall(666, func(id CartId, data []byte) (uint32, []byte, error) { + log.Printf("Received call: %s\n", string(data)) + return 3, []byte("Hello, client!"), fmt.Errorf("Det blev fel") + }) server.HandleCall(2, func(id CartId, data []byte) (uint32, []byte, error) { log.Printf("Received call: %s\n", string(data)) return 3, []byte("Hello, client!"), nil @@ -34,6 +39,13 @@ func TestCartTcpHelpers(t *testing.T) { if err != nil { t.Errorf("Error calling: %v\n", err) } + s, err := client.Call(666, id, 3, []byte("Hello, server!")) + if err != nil { + t.Errorf("Error calling: %v\n", err) + } + if s.StatusCode != 500 { + t.Errorf("Expected 500, got %d\n", s.StatusCode) + } for i := 0; i < 100; i++ { _, err = client.Call(Ping, id, Pong, nil) if err != nil { @@ -44,8 +56,8 @@ func TestCartTcpHelpers(t *testing.T) { if err != nil { t.Errorf("Error calling: %v\n", err) } - if string(answer) != "Hello, client!" { - t.Errorf("Expected answer 'Hello, client!', got %s\n", string(answer)) + if string(answer.Data) != "Hello, client!" { + t.Errorf("Expected answer 'Hello, client!', got %s\n", string(answer.Data)) } if messageData != "Hello, world!" { t.Errorf("Expected message 'Hello, world!', got %s\n", messageData) diff --git a/tcp-client.go b/tcp-client.go index 4d57572..f80b7b5 100644 --- a/tcp-client.go +++ b/tcp-client.go @@ -77,7 +77,8 @@ func (m *TCPClient) SendPacket(messageType uint32, data []byte) error { err = binary.Write(m.Conn, binary.LittleEndian, Packet{ Version: CurrentPacketVersion, MessageType: messageType, - DataLength: uint64(len(data)), + StatusCode: 0, + DataLength: uint32(len(data)), }) if err != nil { return m.HandleConnectionError(err) @@ -94,7 +95,7 @@ func (m *TCPClient) SendPacket(messageType uint32, data []byte) error { // return m.SendPacket(messageType, data) // } -func (m *TCPClient) Call(messageType uint32, responseType uint32, data []byte) ([]byte, error) { +func (m *TCPClient) Call(messageType uint32, responseType uint32, data []byte) (*CallResult, error) { packetChan := m.Expect(responseType) err := m.SendPacket(messageType, data) if err != nil { @@ -103,7 +104,7 @@ func (m *TCPClient) Call(messageType uint32, responseType uint32, data []byte) ( select { case ret := <-packetChan: - return ret, nil + return &ret, nil case <-time.After(time.Second): return nil, fmt.Errorf("timeout") } diff --git a/tcp-mux-server.go b/tcp-mux-server.go index e2a83b3..8ca37c3 100644 --- a/tcp-mux-server.go +++ b/tcp-mux-server.go @@ -70,12 +70,21 @@ func (m *TCPServerMux) handleFunction(connection net.Conn, messageType uint32, d if ok { responseType, responseData, err := function(data) if err != nil { + errData := []byte(err.Error()) + err = binary.Write(connection, binary.LittleEndian, Packet{ + Version: CurrentPacketVersion, + MessageType: responseType, + StatusCode: 500, + DataLength: uint32(len(errData)), + }) + _, err = connection.Write(errData) return true, err } err = binary.Write(connection, binary.LittleEndian, Packet{ Version: CurrentPacketVersion, MessageType: responseType, - DataLength: uint64(len(responseData)), + StatusCode: 200, + DataLength: uint32(len(responseData)), }) if err != nil { return true, err @@ -100,6 +109,10 @@ func (m *TCPServerMux) HandleConnection(connection net.Conn) error { log.Printf("Error receiving packet: %v\n", err) return err } + if packet.Version != CurrentPacketVersion { + log.Printf("Incorrect package version: %v\n", err) + continue + } data, err := GetPacketData(connection, packet.DataLength) if err != nil { log.Printf("Error receiving packet data: %v\n", err) diff --git a/tcp_test.go b/tcp_test.go index 1aafeb1..7cb331d 100644 --- a/tcp_test.go +++ b/tcp_test.go @@ -39,8 +39,8 @@ func TestTcpHelpers(t *testing.T) { t.Errorf("Error calling: %v\n", err) } client.Close() - if string(answer) != "Hello, client!" { - t.Errorf("Expected answer 'Hello, client!', got %s\n", string(answer)) + if string(answer.Data) != "Hello, client!" { + t.Errorf("Expected answer 'Hello, client!', got %s\n", string(answer.Data)) } if messageData != "Hello, world!" { t.Errorf("Expected message 'Hello, world!', got %s\n", messageData)