implement statuscode in packets
All checks were successful
Build and Publish / BuildAndDeploy (push) Successful in 2m2s

This commit is contained in:
matst80
2024-11-11 23:24:03 +01:00
parent 9c15251f67
commit 0b290a32bf
17 changed files with 295 additions and 226 deletions

View File

@@ -54,7 +54,8 @@ type CartGrain struct {
type Grain interface { type Grain interface {
GetId() CartId GetId() CartId
HandleMessage(message *Message, isReplay bool) ([]byte, error) HandleMessage(message *Message, isReplay bool) (*CallResult, error)
GetCurrentState() (*CallResult, error)
} }
func (c *CartGrain) GetId() CartId { func (c *CartGrain) GetId() CartId {
@@ -68,6 +69,14 @@ func (c *CartGrain) GetLastChange() int64 {
return *c.storageMessages[len(c.storageMessages)-1].TimeStamp return *c.storageMessages[len(c.storageMessages)-1].TimeStamp
} }
func (c *CartGrain) GetCurrentState() (*CallResult, error) {
result, err := json.Marshal(c)
return &CallResult{
StatusCode: 200,
Data: result,
}, err
}
func getItemData(sku string, qty int) (*messages.AddItem, error) { func getItemData(sku string, qty int) (*messages.AddItem, error) {
item, err := FetchItem(sku) item, err := FetchItem(sku)
if err != nil { if err != nil {
@@ -99,7 +108,7 @@ func getItemData(sku string, qty int) (*messages.AddItem, error) {
}, nil }, nil
} }
func (c *CartGrain) AddItem(sku string, qty int) ([]byte, error) { func (c *CartGrain) AddItem(sku string, qty int) (*CallResult, error) {
cartItem, err := getItemData(sku, qty) cartItem, err := getItemData(sku, qty)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -171,7 +180,7 @@ func (c *CartGrain) FindItemWithSku(sku string) (*CartItem, bool) {
return nil, false return nil, false
} }
func (c *CartGrain) HandleMessage(message *Message, isReplay bool) ([]byte, error) { func (c *CartGrain) HandleMessage(message *Message, isReplay bool) (*CallResult, error) {
if message.TimeStamp == nil { if message.TimeStamp == nil {
now := time.Now().Unix() now := time.Now().Unix()
message.TimeStamp = &now message.TimeStamp = &now
@@ -294,5 +303,9 @@ func (c *CartGrain) HandleMessage(message *Message, isReplay bool) ([]byte, erro
c.storageMessages = append(c.storageMessages, *message) c.storageMessages = append(c.storageMessages, *message)
c.mu.Unlock() c.mu.Unlock()
} }
return json.Marshal(c) result, err := json.Marshal(c)
return &CallResult{
StatusCode: 200,
Data: result,
}, err
} }

View File

@@ -88,8 +88,8 @@ func TestAddToCart(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("Error handling message: %v\n", err) t.Errorf("Error handling message: %v\n", err)
} }
if len(result) == 0 { if result.StatusCode != 200 {
t.Errorf("Expected result, got nil\n") t.Errorf("Call failed\n")
} }
if grain.TotalPrice != 200 { if grain.TotalPrice != 200 {
t.Errorf("Expected total price 200, got %d\n", grain.TotalPrice) t.Errorf("Expected total price 200, got %d\n", grain.TotalPrice)
@@ -104,8 +104,8 @@ func TestAddToCart(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("Error handling message: %v\n", err) t.Errorf("Error handling message: %v\n", err)
} }
if len(result) == 0 { if result.StatusCode != 200 {
t.Errorf("Expected result, got nil\n") t.Errorf("Call failed\n")
} }
if grain.Items[0].Quantity != 4 { if grain.Items[0].Quantity != 4 {
t.Errorf("Expected quantity 4, got %d\n", grain.Items[0].Quantity) t.Errorf("Expected quantity 4, got %d\n", grain.Items[0].Quantity)
@@ -146,8 +146,8 @@ func TestSetDelivery(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("Error handling message: %v\n", err) t.Errorf("Error handling message: %v\n", err)
} }
if len(result) == 0 { if result.StatusCode != 200 {
t.Errorf("Expected result, got nil\n") t.Errorf("Call failed\n")
} }
setDelivery := GetMessage(SetDeliveryType, &messages.SetDelivery{ setDelivery := GetMessage(SetDeliveryType, &messages.SetDelivery{
@@ -198,8 +198,8 @@ func TestSetDeliveryOnAll(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("Error handling message: %v\n", err) t.Errorf("Error handling message: %v\n", err)
} }
if len(result) == 0 { if result.StatusCode != 200 {
t.Errorf("Expected result, got nil\n") t.Errorf("Call failed\n")
} }
setDelivery := GetMessage(SetDeliveryType, &messages.SetDelivery{ setDelivery := GetMessage(SetDeliveryType, &messages.SetDelivery{

View File

@@ -43,7 +43,10 @@ func (p *CartPacketQueue) HandleConnection(connection net.Conn) error {
continue continue
} }
if packet.DataLength == 0 { if packet.DataLength == 0 {
go p.HandleData(packet.MessageType, packet.Id, []byte{}) go p.HandleData(packet.MessageType, packet.Id, CallResult{
StatusCode: packet.StatusCode,
Data: []byte{},
})
continue continue
} }
data, err := GetPacketData(connection, packet.DataLength) data, err := GetPacketData(connection, packet.DataLength)
@@ -51,11 +54,14 @@ func (p *CartPacketQueue) HandleConnection(connection net.Conn) error {
log.Printf("Error receiving packet data: %v\n", err) log.Printf("Error receiving packet data: %v\n", err)
return err return err
} }
go p.HandleData(packet.MessageType, packet.Id, data) go p.HandleData(packet.MessageType, packet.Id, CallResult{
StatusCode: packet.StatusCode,
Data: data,
})
} }
} }
func (p *CartPacketQueue) HandleData(t uint32, id CartId, data []byte) { func (p *CartPacketQueue) HandleData(t uint32, id CartId, data CallResult) {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()
pl, ok := p.expectedPackages[t] pl, ok := p.expectedPackages[t]
@@ -70,10 +76,9 @@ func (p *CartPacketQueue) HandleData(t uint32, id CartId, data []byte) {
} }
} }
} }
data = nil
} }
func (p *CartPacketQueue) Expect(messageType uint32, id CartId) <-chan []byte { func (p *CartPacketQueue) Expect(messageType uint32, id CartId) <-chan CallResult {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()
l, ok := p.expectedPackages[messageType] l, ok := p.expectedPackages[messageType]
@@ -82,7 +87,7 @@ func (p *CartPacketQueue) Expect(messageType uint32, id CartId) <-chan []byte {
idl.Count++ idl.Count++
return idl.Chan return idl.Chan
} }
ch := make(chan []byte) ch := make(chan CallResult)
(*l)[id] = Listener{ (*l)[id] = Listener{
Chan: ch, Chan: ch,
Count: 1, Count: 1,
@@ -90,7 +95,7 @@ func (p *CartPacketQueue) Expect(messageType uint32, id CartId) <-chan []byte {
return ch return ch
} }
ch := make(chan []byte) ch := make(chan CallResult)
p.expectedPackages[messageType] = &CartListener{ p.expectedPackages[messageType] = &CartListener{
id: Listener{ id: Listener{
Chan: ch, Chan: ch,

View File

@@ -27,8 +27,8 @@ var (
) )
type GrainPool interface { type GrainPool interface {
Process(id CartId, messages ...Message) ([]byte, error) Process(id CartId, messages ...Message) (*CallResult, error)
Get(id CartId) ([]byte, error) Get(id CartId) (*CallResult, error)
} }
type Ttl struct { type Ttl struct {

View File

@@ -7,33 +7,24 @@ import (
"sync" "sync"
) )
// type PacketWithData struct {
// MessageType uint32
// Added time.Time
// Consumed bool
// Data []byte
// }
type PacketQueue struct { type PacketQueue struct {
mu sync.RWMutex mu sync.RWMutex
expectedPackages map[uint32]*Listener expectedPackages map[uint32]*Listener
//Packets []PacketWithData
//connection net.Conn
} }
//const cap = 150 type CallResult struct {
StatusCode uint32
Data []byte
}
type Listener struct { type Listener struct {
Count int Count int
Chan chan []byte Chan chan CallResult
} }
func NewPacketQueue(connection net.Conn) *PacketQueue { func NewPacketQueue(connection net.Conn) *PacketQueue {
queue := &PacketQueue{ queue := &PacketQueue{
expectedPackages: make(map[uint32]*Listener), expectedPackages: make(map[uint32]*Listener),
//Packets: make([]PacketWithData, 0, cap+1),
//connection: connection,
} }
go queue.HandleConnection(connection) go queue.HandleConnection(connection)
return queue return queue
@@ -57,7 +48,10 @@ func (p *PacketQueue) HandleConnection(connection net.Conn) error {
continue continue
} }
if packet.DataLength == 0 { if packet.DataLength == 0 {
go p.HandleData(packet.MessageType, []byte{}) go p.HandleData(packet.MessageType, CallResult{
StatusCode: packet.StatusCode,
Data: []byte{},
})
continue continue
} }
data, err := GetPacketData(connection, packet.DataLength) data, err := GetPacketData(connection, packet.DataLength)
@@ -65,12 +59,15 @@ func (p *PacketQueue) HandleConnection(connection net.Conn) error {
log.Printf("Error receiving packet data: %v\n", err) log.Printf("Error receiving packet data: %v\n", err)
//return err //return err
} else { } else {
go p.HandleData(packet.MessageType, data) go p.HandleData(packet.MessageType, CallResult{
StatusCode: packet.StatusCode,
Data: data,
})
} }
} }
} }
func (p *PacketQueue) HandleData(t uint32, data []byte) { func (p *PacketQueue) HandleData(t uint32, data CallResult) {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()
l, ok := p.expectedPackages[t] l, ok := p.expectedPackages[t]
@@ -83,10 +80,9 @@ func (p *PacketQueue) HandleData(t uint32, data []byte) {
} }
return return
} }
data = nil
} }
func (p *PacketQueue) Expect(messageType uint32) <-chan []byte { func (p *PacketQueue) Expect(messageType uint32) <-chan CallResult {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()
l, ok := p.expectedPackages[messageType] l, ok := p.expectedPackages[messageType]
@@ -95,7 +91,7 @@ func (p *PacketQueue) Expect(messageType uint32) <-chan []byte {
return l.Chan return l.Chan
} }
ch := make(chan []byte) ch := make(chan CallResult)
p.expectedPackages[messageType] = &Listener{ p.expectedPackages[messageType] = &Listener{
Count: 1, Count: 1,
Chan: ch, Chan: ch,

View File

@@ -16,14 +16,16 @@ const (
type CartPacket struct { type CartPacket struct {
Version uint32 Version uint32
MessageType uint32 MessageType uint32
DataLength uint64 DataLength uint32
StatusCode uint32
Id CartId Id CartId
} }
type Packet struct { type Packet struct {
Version uint32 Version uint32
MessageType uint32 MessageType uint32
DataLength uint64 DataLength uint32
StatusCode uint32
} }
func ReadPacket(conn io.Reader, packet *Packet) error { func ReadPacket(conn io.Reader, packet *Packet) error {
@@ -34,7 +36,7 @@ func ReadCartPacket(conn io.Reader, packet *CartPacket) error {
return binary.Read(conn, binary.LittleEndian, packet) return binary.Read(conn, binary.LittleEndian, packet)
} }
func GetPacketData(conn io.Reader, len uint64) ([]byte, error) { func GetPacketData(conn io.Reader, len uint32) ([]byte, error) {
if len == 0 { if len == 0 {
return []byte{}, nil return []byte{}, nil
} }

View File

@@ -52,11 +52,16 @@ func ErrorHandler(fn func(w http.ResponseWriter, r *http.Request) error) func(w
} }
} }
func (s *PoolServer) WriteResult(w http.ResponseWriter, data []byte) error { func (s *PoolServer) WriteResult(w http.ResponseWriter, result *CallResult) error {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Pod-Name", s.pod_name) w.Header().Set("X-Pod-Name", s.pod_name)
if result.StatusCode != 200 {
w.WriteHeader(int(result.StatusCode))
w.Write([]byte(result.Data))
return nil
}
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
_, err := w.Write(data) _, err := w.Write(result.Data)
return err return err
} }

67
remote-grain-pool.go Normal file
View File

@@ -0,0 +1,67 @@
package main
import "sync"
type RemoteGrainPool struct {
mu sync.RWMutex
Host string
grains map[CartId]*RemoteGrain
}
func NewRemoteGrainPool(addr string) *RemoteGrainPool {
return &RemoteGrainPool{
Host: addr,
grains: make(map[CartId]*RemoteGrain),
}
}
func (p *RemoteGrainPool) findRemoteGrain(id CartId) *RemoteGrain {
p.mu.RLock()
grain, ok := p.grains[id]
p.mu.RUnlock()
if !ok {
return nil
}
return grain
}
func (p *RemoteGrainPool) findOrCreateGrain(id CartId) (*RemoteGrain, error) {
grain := p.findRemoteGrain(id)
if grain == nil {
grain, err := NewRemoteGrain(id, p.Host)
if err != nil {
return nil, err
}
p.mu.Lock()
p.grains[id] = grain
p.mu.Unlock()
}
return grain, nil
}
func (p *RemoteGrainPool) Delete(id CartId) {
p.mu.Lock()
delete(p.grains, id)
p.mu.Unlock()
}
func (p *RemoteGrainPool) Process(id CartId, messages ...Message) (*CallResult, error) {
var result *CallResult
grain, err := p.findOrCreateGrain(id)
if err != nil {
return nil, err
}
for _, message := range messages {
result, err = grain.HandleMessage(&message, false)
}
return result, err
}
func (p *RemoteGrainPool) Get(id CartId) (*CallResult, error) {
grain, err := p.findOrCreateGrain(id)
if err != nil {
return nil, err
}
return grain.GetCurrentState()
}

91
remote-grain.go Normal file
View File

@@ -0,0 +1,91 @@
package main
import (
"fmt"
"strings"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
func (id CartId) String() string {
return strings.Trim(string(id[:]), "\x00")
}
func ToCartId(id string) CartId {
var result [16]byte
copy(result[:], []byte(id))
return result
}
type RemoteGrain struct {
*CartClient
Id CartId
Host string
}
func NewRemoteGrain(id CartId, host string) (*RemoteGrain, error) {
client, err := CartDial(fmt.Sprintf("%s:1337", host))
if err != nil {
return nil, err
}
return &RemoteGrain{
Id: id,
Host: host,
CartClient: client,
}, nil
}
var (
remoteCartLatency = promauto.NewCounter(prometheus.CounterOpts{
Name: "cart_remote_grain_calls_total_latency",
Help: "The total latency of remote grains",
})
remoteCartCallsTotal = promauto.NewCounter(prometheus.CounterOpts{
Name: "cart_remote_grain_calls_total",
Help: "The total number of calls to remote grains",
})
)
var start time.Time
func MeasureLatency(fn func() (*CallResult, error)) (*CallResult, error) {
start = time.Now()
data, err := fn()
if err != nil {
return data, err
}
elapsed := time.Since(start).Milliseconds()
go func() {
remoteCartLatency.Add(float64(elapsed))
remoteCartCallsTotal.Inc()
}()
return data, nil
}
func (g *RemoteGrain) HandleMessage(message *Message, isReplay bool) (*CallResult, error) {
data, err := GetData(message.Write)
if err != nil {
return nil, err
}
reply, err := MeasureLatency(func() (*CallResult, error) {
return g.Call(RemoteHandleMutation, g.Id, RemoteHandleMutationReply, data)
})
if err != nil {
return nil, err
}
return reply, err
}
func (g *RemoteGrain) GetId() CartId {
return g.Id
}
func (g *RemoteGrain) GetCurrentState() (*CallResult, error) {
return MeasureLatency(func() (*CallResult, error) { return g.Call(RemoteGetState, g.Id, RemoteGetStateReply, []byte{}) })
}

View File

@@ -1,154 +0,0 @@
package main
import (
"fmt"
"strings"
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
type RemoteGrainPool struct {
mu sync.RWMutex
Host string
grains map[CartId]*RemoteGrain
}
func (id CartId) String() string {
return strings.Trim(string(id[:]), "\x00")
}
func ToCartId(id string) CartId {
var result [16]byte
copy(result[:], []byte(id))
return result
}
type RemoteGrain struct {
*CartClient
Id CartId
Host string
}
func NewRemoteGrain(id CartId, host string) (*RemoteGrain, error) {
client, err := CartDial(fmt.Sprintf("%s:1337", host))
if err != nil {
return nil, err
}
return &RemoteGrain{
Id: id,
Host: host,
CartClient: client,
}, nil
}
var (
remoteCartLatency = promauto.NewCounter(prometheus.CounterOpts{
Name: "cart_remote_grain_calls_total_latency",
Help: "The total latency of remote grains",
})
remoteCartCallsTotal = promauto.NewCounter(prometheus.CounterOpts{
Name: "cart_remote_grain_calls_total",
Help: "The total number of calls to remote grains",
})
)
var start time.Time
func MeasureLatency(fn func() ([]byte, error)) ([]byte, error) {
start = time.Now()
data, err := fn()
if err != nil {
return data, err
}
elapsed := time.Since(start).Milliseconds()
go func() {
remoteCartLatency.Add(float64(elapsed))
remoteCartCallsTotal.Inc()
}()
return data, nil
}
func (g *RemoteGrain) HandleMessage(message *Message, isReplay bool) ([]byte, error) {
data, err := GetData(message.Write)
if err != nil {
return nil, err
}
reply, err := MeasureLatency(func() ([]byte, error) { return g.Call(RemoteHandleMutation, g.Id, RemoteHandleMutationReply, data) })
if err != nil {
return nil, err
}
return reply, err
}
func (g *RemoteGrain) GetId() CartId {
return g.Id
}
func (g *RemoteGrain) GetCurrentState() ([]byte, error) {
return MeasureLatency(func() ([]byte, error) { return g.Call(RemoteGetState, g.Id, RemoteGetStateReply, []byte{}) })
}
func NewRemoteGrainPool(addr string) *RemoteGrainPool {
return &RemoteGrainPool{
Host: addr,
grains: make(map[CartId]*RemoteGrain),
}
}
func (p *RemoteGrainPool) findRemoteGrain(id CartId) *RemoteGrain {
p.mu.RLock()
grain, ok := p.grains[id]
p.mu.RUnlock()
if !ok {
return nil
}
return grain
}
func (p *RemoteGrainPool) findOrCreateGrain(id CartId) (*RemoteGrain, error) {
grain := p.findRemoteGrain(id)
if grain == nil {
grain, err := NewRemoteGrain(id, p.Host)
if err != nil {
return nil, err
}
p.mu.Lock()
p.grains[id] = grain
p.mu.Unlock()
}
return grain, nil
}
func (p *RemoteGrainPool) Delete(id CartId) {
p.mu.Lock()
delete(p.grains, id)
p.mu.Unlock()
}
func (p *RemoteGrainPool) Process(id CartId, messages ...Message) ([]byte, error) {
var result []byte
grain, err := p.findOrCreateGrain(id)
if err != nil {
return nil, err
}
for _, message := range messages {
result, err = grain.HandleMessage(&message, false)
}
return result, err
}
func (p *RemoteGrainPool) Get(id CartId) ([]byte, error) {
grain, err := p.findOrCreateGrain(id)
if err != nil {
return nil, err
}
return grain.GetCurrentState()
}

View File

@@ -1,7 +1,6 @@
package main package main
import ( import (
"encoding/json"
"fmt" "fmt"
"log" "log"
"strings" "strings"
@@ -26,7 +25,6 @@ type RemoteHost struct {
*Client *Client
Host string Host string
MissedPings int MissedPings int
//Pool *RemoteGrainPool
} }
type SyncedPool struct { type SyncedPool struct {
@@ -248,21 +246,27 @@ const (
) )
func (h *RemoteHost) Negotiate(knownHosts []string) ([]string, error) { func (h *RemoteHost) Negotiate(knownHosts []string) ([]string, error) {
data, err := h.Call(RemoteNegotiate, RemoteNegotiateResponse, []byte(strings.Join(knownHosts, ";"))) reply, err := h.Call(RemoteNegotiate, RemoteNegotiateResponse, []byte(strings.Join(knownHosts, ";")))
if err != nil { if err != nil {
return nil, err return nil, err
} }
if reply.StatusCode != 200 {
return nil, fmt.Errorf("remote returned error on negotiate: %s", string(reply.Data))
}
return strings.Split(string(data), ";"), nil return strings.Split(string(reply.Data), ";"), nil
} }
func (g *RemoteHost) GetCartMappings() ([]CartId, error) { func (g *RemoteHost) GetCartMappings() ([]CartId, error) {
data, err := g.Call(GetCartIds, CartIdsResponse, []byte{}) reply, err := g.Call(GetCartIds, CartIdsResponse, []byte{})
if err != nil { if err != nil {
return nil, err return nil, err
} }
parts := strings.Split(string(data), ";") if reply.StatusCode != 200 {
return nil, fmt.Errorf("remote returned error: %s", string(reply.Data))
}
parts := strings.Split(string(reply.Data), ";")
ids := make([]CartId, 0, len(parts)) ids := make([]CartId, 0, len(parts))
for _, p := range parts { for _, p := range parts {
ids = append(ids, ToCartId(p)) ids = append(ids, ToCartId(p))
@@ -289,13 +293,13 @@ func (p *SyncedPool) Negotiate(knownHosts []string) ([]string, error) {
} }
func (r *RemoteHost) ConfirmChange(id CartId, host string) error { func (r *RemoteHost) ConfirmChange(id CartId, host string) error {
data, err := r.Call(RemoteGrainChanged, AckChange, []byte(fmt.Sprintf("%s;%s", id, host))) reply, err := r.Call(RemoteGrainChanged, AckChange, []byte(fmt.Sprintf("%s;%s", id, host)))
if err != nil { if err != nil {
return err return err
} }
if string(data) != "ok" { if string(reply.Data) != "ok" {
return fmt.Errorf("remote grain change failed %s", string(data)) return fmt.Errorf("remote grain change failed %s", string(reply.Data))
} }
return nil return nil
@@ -443,9 +447,9 @@ func (p *SyncedPool) getGrain(id CartId) (Grain, error) {
return localGrain, nil return localGrain, nil
} }
func (p *SyncedPool) Process(id CartId, messages ...Message) ([]byte, error) { func (p *SyncedPool) Process(id CartId, messages ...Message) (*CallResult, error) {
pool, err := p.getGrain(id) pool, err := p.getGrain(id)
var res []byte var res *CallResult
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -458,7 +462,7 @@ func (p *SyncedPool) Process(id CartId, messages ...Message) ([]byte, error) {
return res, nil return res, nil
} }
func (p *SyncedPool) Get(id CartId) ([]byte, error) { func (p *SyncedPool) Get(id CartId) (*CallResult, error) {
grain, err := p.getGrain(id) grain, err := p.getGrain(id)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -467,5 +471,5 @@ func (p *SyncedPool) Get(id CartId) ([]byte, error) {
return remoteGrain.GetCurrentState() return remoteGrain.GetCurrentState()
} }
return json.Marshal(grain) return grain.GetCurrentState()
} }

View File

@@ -73,7 +73,7 @@ func (m *CartTCPClient) SendPacket(messageType uint32, id CartId, data []byte) e
err = binary.Write(m.Conn, binary.LittleEndian, CartPacket{ err = binary.Write(m.Conn, binary.LittleEndian, CartPacket{
Version: CurrentPacketVersion, Version: CurrentPacketVersion,
MessageType: messageType, MessageType: messageType,
DataLength: uint64(len(data)), DataLength: uint32(len(data)),
Id: id, Id: id,
}) })
if err != nil { if err != nil {
@@ -91,7 +91,7 @@ func (m *CartTCPClient) SendPacket(messageType uint32, id CartId, data []byte) e
// return m.SendPacket(messageType, id, data) // return m.SendPacket(messageType, id, data)
// } // }
func (m *CartTCPClient) Call(messageType uint32, id CartId, responseType uint32, data []byte) ([]byte, error) { func (m *CartTCPClient) Call(messageType uint32, id CartId, responseType uint32, data []byte) (*CallResult, error) {
packetChan := m.Expect(responseType, id) packetChan := m.Expect(responseType, id)
err := m.SendPacket(messageType, id, data) err := m.SendPacket(messageType, id, data)
if err != nil { if err != nil {
@@ -99,7 +99,7 @@ func (m *CartTCPClient) Call(messageType uint32, id CartId, responseType uint32,
} }
select { select {
case ret := <-packetChan: case ret := <-packetChan:
return ret, nil return &ret, nil
case <-time.After(time.Second): case <-time.After(time.Second):
return nil, fmt.Errorf("timeout") return nil, fmt.Errorf("timeout")
} }

View File

@@ -69,13 +69,24 @@ func (m *TCPCartServerMux) handleFunction(connection net.Conn, messageType uint3
m.mu.RUnlock() m.mu.RUnlock()
if ok { if ok {
responseType, responseData, err := fn(id, data) responseType, responseData, err := fn(id, data)
if err != nil { if err != nil {
errData := []byte(err.Error())
err = binary.Write(connection, binary.LittleEndian, CartPacket{
Version: CurrentPacketVersion,
MessageType: responseType,
DataLength: uint32(len(errData)),
StatusCode: 500,
Id: id,
})
_, err = connection.Write(errData)
return true, err return true, err
} }
err = binary.Write(connection, binary.LittleEndian, CartPacket{ err = binary.Write(connection, binary.LittleEndian, CartPacket{
Version: CurrentPacketVersion, Version: CurrentPacketVersion,
MessageType: responseType, MessageType: responseType,
DataLength: uint64(len(responseData)), DataLength: uint32(len(responseData)),
StatusCode: 200,
Id: id, Id: id,
}) })
if err != nil { if err != nil {
@@ -101,7 +112,10 @@ func (m *TCPCartServerMux) HandleConnection(connection net.Conn) error {
log.Printf("Error receiving packet: %v\n", err) log.Printf("Error receiving packet: %v\n", err)
return err return err
} }
if packet.Version != CurrentPacketVersion {
log.Printf("Incorrect packet version: %d\n", packet.Version)
continue
}
data, err := GetPacketData(connection, packet.DataLength) data, err := GetPacketData(connection, packet.DataLength)
if err != nil { if err != nil {
log.Printf("Error getting packet data: %v\n", err) log.Printf("Error getting packet data: %v\n", err)

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"fmt"
"log" "log"
"testing" "testing"
) )
@@ -21,6 +22,10 @@ func TestCartTcpHelpers(t *testing.T) {
messageData = string(data) messageData = string(data)
return nil return nil
}) })
server.HandleCall(666, func(id CartId, data []byte) (uint32, []byte, error) {
log.Printf("Received call: %s\n", string(data))
return 3, []byte("Hello, client!"), fmt.Errorf("Det blev fel")
})
server.HandleCall(2, func(id CartId, data []byte) (uint32, []byte, error) { server.HandleCall(2, func(id CartId, data []byte) (uint32, []byte, error) {
log.Printf("Received call: %s\n", string(data)) log.Printf("Received call: %s\n", string(data))
return 3, []byte("Hello, client!"), nil return 3, []byte("Hello, client!"), nil
@@ -34,6 +39,13 @@ func TestCartTcpHelpers(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("Error calling: %v\n", err) t.Errorf("Error calling: %v\n", err)
} }
s, err := client.Call(666, id, 3, []byte("Hello, server!"))
if err != nil {
t.Errorf("Error calling: %v\n", err)
}
if s.StatusCode != 500 {
t.Errorf("Expected 500, got %d\n", s.StatusCode)
}
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
_, err = client.Call(Ping, id, Pong, nil) _, err = client.Call(Ping, id, Pong, nil)
if err != nil { if err != nil {
@@ -44,8 +56,8 @@ func TestCartTcpHelpers(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("Error calling: %v\n", err) t.Errorf("Error calling: %v\n", err)
} }
if string(answer) != "Hello, client!" { if string(answer.Data) != "Hello, client!" {
t.Errorf("Expected answer 'Hello, client!', got %s\n", string(answer)) t.Errorf("Expected answer 'Hello, client!', got %s\n", string(answer.Data))
} }
if messageData != "Hello, world!" { if messageData != "Hello, world!" {
t.Errorf("Expected message 'Hello, world!', got %s\n", messageData) t.Errorf("Expected message 'Hello, world!', got %s\n", messageData)

View File

@@ -77,7 +77,8 @@ func (m *TCPClient) SendPacket(messageType uint32, data []byte) error {
err = binary.Write(m.Conn, binary.LittleEndian, Packet{ err = binary.Write(m.Conn, binary.LittleEndian, Packet{
Version: CurrentPacketVersion, Version: CurrentPacketVersion,
MessageType: messageType, MessageType: messageType,
DataLength: uint64(len(data)), StatusCode: 0,
DataLength: uint32(len(data)),
}) })
if err != nil { if err != nil {
return m.HandleConnectionError(err) return m.HandleConnectionError(err)
@@ -94,7 +95,7 @@ func (m *TCPClient) SendPacket(messageType uint32, data []byte) error {
// return m.SendPacket(messageType, data) // return m.SendPacket(messageType, data)
// } // }
func (m *TCPClient) Call(messageType uint32, responseType uint32, data []byte) ([]byte, error) { func (m *TCPClient) Call(messageType uint32, responseType uint32, data []byte) (*CallResult, error) {
packetChan := m.Expect(responseType) packetChan := m.Expect(responseType)
err := m.SendPacket(messageType, data) err := m.SendPacket(messageType, data)
if err != nil { if err != nil {
@@ -103,7 +104,7 @@ func (m *TCPClient) Call(messageType uint32, responseType uint32, data []byte) (
select { select {
case ret := <-packetChan: case ret := <-packetChan:
return ret, nil return &ret, nil
case <-time.After(time.Second): case <-time.After(time.Second):
return nil, fmt.Errorf("timeout") return nil, fmt.Errorf("timeout")
} }

View File

@@ -70,12 +70,21 @@ func (m *TCPServerMux) handleFunction(connection net.Conn, messageType uint32, d
if ok { if ok {
responseType, responseData, err := function(data) responseType, responseData, err := function(data)
if err != nil { if err != nil {
errData := []byte(err.Error())
err = binary.Write(connection, binary.LittleEndian, Packet{
Version: CurrentPacketVersion,
MessageType: responseType,
StatusCode: 500,
DataLength: uint32(len(errData)),
})
_, err = connection.Write(errData)
return true, err return true, err
} }
err = binary.Write(connection, binary.LittleEndian, Packet{ err = binary.Write(connection, binary.LittleEndian, Packet{
Version: CurrentPacketVersion, Version: CurrentPacketVersion,
MessageType: responseType, MessageType: responseType,
DataLength: uint64(len(responseData)), StatusCode: 200,
DataLength: uint32(len(responseData)),
}) })
if err != nil { if err != nil {
return true, err return true, err
@@ -100,6 +109,10 @@ func (m *TCPServerMux) HandleConnection(connection net.Conn) error {
log.Printf("Error receiving packet: %v\n", err) log.Printf("Error receiving packet: %v\n", err)
return err return err
} }
if packet.Version != CurrentPacketVersion {
log.Printf("Incorrect package version: %v\n", err)
continue
}
data, err := GetPacketData(connection, packet.DataLength) data, err := GetPacketData(connection, packet.DataLength)
if err != nil { if err != nil {
log.Printf("Error receiving packet data: %v\n", err) log.Printf("Error receiving packet data: %v\n", err)

View File

@@ -39,8 +39,8 @@ func TestTcpHelpers(t *testing.T) {
t.Errorf("Error calling: %v\n", err) t.Errorf("Error calling: %v\n", err)
} }
client.Close() client.Close()
if string(answer) != "Hello, client!" { if string(answer.Data) != "Hello, client!" {
t.Errorf("Expected answer 'Hello, client!', got %s\n", string(answer)) t.Errorf("Expected answer 'Hello, client!', got %s\n", string(answer.Data))
} }
if messageData != "Hello, world!" { if messageData != "Hello, world!" {
t.Errorf("Expected message 'Hello, world!', got %s\n", messageData) t.Errorf("Expected message 'Hello, world!', got %s\n", messageData)