implement statuscode in packets
All checks were successful
Build and Publish / BuildAndDeploy (push) Successful in 2m2s

This commit is contained in:
matst80
2024-11-11 23:24:03 +01:00
parent 9c15251f67
commit 0b290a32bf
17 changed files with 295 additions and 226 deletions

View File

@@ -54,7 +54,8 @@ type CartGrain struct {
type Grain interface {
GetId() CartId
HandleMessage(message *Message, isReplay bool) ([]byte, error)
HandleMessage(message *Message, isReplay bool) (*CallResult, error)
GetCurrentState() (*CallResult, error)
}
func (c *CartGrain) GetId() CartId {
@@ -68,6 +69,14 @@ func (c *CartGrain) GetLastChange() int64 {
return *c.storageMessages[len(c.storageMessages)-1].TimeStamp
}
func (c *CartGrain) GetCurrentState() (*CallResult, error) {
result, err := json.Marshal(c)
return &CallResult{
StatusCode: 200,
Data: result,
}, err
}
func getItemData(sku string, qty int) (*messages.AddItem, error) {
item, err := FetchItem(sku)
if err != nil {
@@ -99,7 +108,7 @@ func getItemData(sku string, qty int) (*messages.AddItem, error) {
}, nil
}
func (c *CartGrain) AddItem(sku string, qty int) ([]byte, error) {
func (c *CartGrain) AddItem(sku string, qty int) (*CallResult, error) {
cartItem, err := getItemData(sku, qty)
if err != nil {
return nil, err
@@ -171,7 +180,7 @@ func (c *CartGrain) FindItemWithSku(sku string) (*CartItem, bool) {
return nil, false
}
func (c *CartGrain) HandleMessage(message *Message, isReplay bool) ([]byte, error) {
func (c *CartGrain) HandleMessage(message *Message, isReplay bool) (*CallResult, error) {
if message.TimeStamp == nil {
now := time.Now().Unix()
message.TimeStamp = &now
@@ -294,5 +303,9 @@ func (c *CartGrain) HandleMessage(message *Message, isReplay bool) ([]byte, erro
c.storageMessages = append(c.storageMessages, *message)
c.mu.Unlock()
}
return json.Marshal(c)
result, err := json.Marshal(c)
return &CallResult{
StatusCode: 200,
Data: result,
}, err
}

View File

@@ -88,8 +88,8 @@ func TestAddToCart(t *testing.T) {
if err != nil {
t.Errorf("Error handling message: %v\n", err)
}
if len(result) == 0 {
t.Errorf("Expected result, got nil\n")
if result.StatusCode != 200 {
t.Errorf("Call failed\n")
}
if grain.TotalPrice != 200 {
t.Errorf("Expected total price 200, got %d\n", grain.TotalPrice)
@@ -104,8 +104,8 @@ func TestAddToCart(t *testing.T) {
if err != nil {
t.Errorf("Error handling message: %v\n", err)
}
if len(result) == 0 {
t.Errorf("Expected result, got nil\n")
if result.StatusCode != 200 {
t.Errorf("Call failed\n")
}
if grain.Items[0].Quantity != 4 {
t.Errorf("Expected quantity 4, got %d\n", grain.Items[0].Quantity)
@@ -146,8 +146,8 @@ func TestSetDelivery(t *testing.T) {
if err != nil {
t.Errorf("Error handling message: %v\n", err)
}
if len(result) == 0 {
t.Errorf("Expected result, got nil\n")
if result.StatusCode != 200 {
t.Errorf("Call failed\n")
}
setDelivery := GetMessage(SetDeliveryType, &messages.SetDelivery{
@@ -198,8 +198,8 @@ func TestSetDeliveryOnAll(t *testing.T) {
if err != nil {
t.Errorf("Error handling message: %v\n", err)
}
if len(result) == 0 {
t.Errorf("Expected result, got nil\n")
if result.StatusCode != 200 {
t.Errorf("Call failed\n")
}
setDelivery := GetMessage(SetDeliveryType, &messages.SetDelivery{

View File

@@ -43,7 +43,10 @@ func (p *CartPacketQueue) HandleConnection(connection net.Conn) error {
continue
}
if packet.DataLength == 0 {
go p.HandleData(packet.MessageType, packet.Id, []byte{})
go p.HandleData(packet.MessageType, packet.Id, CallResult{
StatusCode: packet.StatusCode,
Data: []byte{},
})
continue
}
data, err := GetPacketData(connection, packet.DataLength)
@@ -51,11 +54,14 @@ func (p *CartPacketQueue) HandleConnection(connection net.Conn) error {
log.Printf("Error receiving packet data: %v\n", err)
return err
}
go p.HandleData(packet.MessageType, packet.Id, data)
go p.HandleData(packet.MessageType, packet.Id, CallResult{
StatusCode: packet.StatusCode,
Data: data,
})
}
}
func (p *CartPacketQueue) HandleData(t uint32, id CartId, data []byte) {
func (p *CartPacketQueue) HandleData(t uint32, id CartId, data CallResult) {
p.mu.Lock()
defer p.mu.Unlock()
pl, ok := p.expectedPackages[t]
@@ -70,10 +76,9 @@ func (p *CartPacketQueue) HandleData(t uint32, id CartId, data []byte) {
}
}
}
data = nil
}
func (p *CartPacketQueue) Expect(messageType uint32, id CartId) <-chan []byte {
func (p *CartPacketQueue) Expect(messageType uint32, id CartId) <-chan CallResult {
p.mu.Lock()
defer p.mu.Unlock()
l, ok := p.expectedPackages[messageType]
@@ -82,7 +87,7 @@ func (p *CartPacketQueue) Expect(messageType uint32, id CartId) <-chan []byte {
idl.Count++
return idl.Chan
}
ch := make(chan []byte)
ch := make(chan CallResult)
(*l)[id] = Listener{
Chan: ch,
Count: 1,
@@ -90,7 +95,7 @@ func (p *CartPacketQueue) Expect(messageType uint32, id CartId) <-chan []byte {
return ch
}
ch := make(chan []byte)
ch := make(chan CallResult)
p.expectedPackages[messageType] = &CartListener{
id: Listener{
Chan: ch,

View File

@@ -27,8 +27,8 @@ var (
)
type GrainPool interface {
Process(id CartId, messages ...Message) ([]byte, error)
Get(id CartId) ([]byte, error)
Process(id CartId, messages ...Message) (*CallResult, error)
Get(id CartId) (*CallResult, error)
}
type Ttl struct {

View File

@@ -7,33 +7,24 @@ import (
"sync"
)
// type PacketWithData struct {
// MessageType uint32
// Added time.Time
// Consumed bool
// Data []byte
// }
type PacketQueue struct {
mu sync.RWMutex
expectedPackages map[uint32]*Listener
//Packets []PacketWithData
//connection net.Conn
}
//const cap = 150
type CallResult struct {
StatusCode uint32
Data []byte
}
type Listener struct {
Count int
Chan chan []byte
Chan chan CallResult
}
func NewPacketQueue(connection net.Conn) *PacketQueue {
queue := &PacketQueue{
expectedPackages: make(map[uint32]*Listener),
//Packets: make([]PacketWithData, 0, cap+1),
//connection: connection,
}
go queue.HandleConnection(connection)
return queue
@@ -57,7 +48,10 @@ func (p *PacketQueue) HandleConnection(connection net.Conn) error {
continue
}
if packet.DataLength == 0 {
go p.HandleData(packet.MessageType, []byte{})
go p.HandleData(packet.MessageType, CallResult{
StatusCode: packet.StatusCode,
Data: []byte{},
})
continue
}
data, err := GetPacketData(connection, packet.DataLength)
@@ -65,12 +59,15 @@ func (p *PacketQueue) HandleConnection(connection net.Conn) error {
log.Printf("Error receiving packet data: %v\n", err)
//return err
} else {
go p.HandleData(packet.MessageType, data)
go p.HandleData(packet.MessageType, CallResult{
StatusCode: packet.StatusCode,
Data: data,
})
}
}
}
func (p *PacketQueue) HandleData(t uint32, data []byte) {
func (p *PacketQueue) HandleData(t uint32, data CallResult) {
p.mu.Lock()
defer p.mu.Unlock()
l, ok := p.expectedPackages[t]
@@ -83,10 +80,9 @@ func (p *PacketQueue) HandleData(t uint32, data []byte) {
}
return
}
data = nil
}
func (p *PacketQueue) Expect(messageType uint32) <-chan []byte {
func (p *PacketQueue) Expect(messageType uint32) <-chan CallResult {
p.mu.Lock()
defer p.mu.Unlock()
l, ok := p.expectedPackages[messageType]
@@ -95,7 +91,7 @@ func (p *PacketQueue) Expect(messageType uint32) <-chan []byte {
return l.Chan
}
ch := make(chan []byte)
ch := make(chan CallResult)
p.expectedPackages[messageType] = &Listener{
Count: 1,
Chan: ch,

View File

@@ -16,14 +16,16 @@ const (
type CartPacket struct {
Version uint32
MessageType uint32
DataLength uint64
DataLength uint32
StatusCode uint32
Id CartId
}
type Packet struct {
Version uint32
MessageType uint32
DataLength uint64
DataLength uint32
StatusCode uint32
}
func ReadPacket(conn io.Reader, packet *Packet) error {
@@ -34,7 +36,7 @@ func ReadCartPacket(conn io.Reader, packet *CartPacket) error {
return binary.Read(conn, binary.LittleEndian, packet)
}
func GetPacketData(conn io.Reader, len uint64) ([]byte, error) {
func GetPacketData(conn io.Reader, len uint32) ([]byte, error) {
if len == 0 {
return []byte{}, nil
}

View File

@@ -52,11 +52,16 @@ func ErrorHandler(fn func(w http.ResponseWriter, r *http.Request) error) func(w
}
}
func (s *PoolServer) WriteResult(w http.ResponseWriter, data []byte) error {
func (s *PoolServer) WriteResult(w http.ResponseWriter, result *CallResult) error {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Pod-Name", s.pod_name)
if result.StatusCode != 200 {
w.WriteHeader(int(result.StatusCode))
w.Write([]byte(result.Data))
return nil
}
w.WriteHeader(http.StatusOK)
_, err := w.Write(data)
_, err := w.Write(result.Data)
return err
}

67
remote-grain-pool.go Normal file
View File

@@ -0,0 +1,67 @@
package main
import "sync"
type RemoteGrainPool struct {
mu sync.RWMutex
Host string
grains map[CartId]*RemoteGrain
}
func NewRemoteGrainPool(addr string) *RemoteGrainPool {
return &RemoteGrainPool{
Host: addr,
grains: make(map[CartId]*RemoteGrain),
}
}
func (p *RemoteGrainPool) findRemoteGrain(id CartId) *RemoteGrain {
p.mu.RLock()
grain, ok := p.grains[id]
p.mu.RUnlock()
if !ok {
return nil
}
return grain
}
func (p *RemoteGrainPool) findOrCreateGrain(id CartId) (*RemoteGrain, error) {
grain := p.findRemoteGrain(id)
if grain == nil {
grain, err := NewRemoteGrain(id, p.Host)
if err != nil {
return nil, err
}
p.mu.Lock()
p.grains[id] = grain
p.mu.Unlock()
}
return grain, nil
}
func (p *RemoteGrainPool) Delete(id CartId) {
p.mu.Lock()
delete(p.grains, id)
p.mu.Unlock()
}
func (p *RemoteGrainPool) Process(id CartId, messages ...Message) (*CallResult, error) {
var result *CallResult
grain, err := p.findOrCreateGrain(id)
if err != nil {
return nil, err
}
for _, message := range messages {
result, err = grain.HandleMessage(&message, false)
}
return result, err
}
func (p *RemoteGrainPool) Get(id CartId) (*CallResult, error) {
grain, err := p.findOrCreateGrain(id)
if err != nil {
return nil, err
}
return grain.GetCurrentState()
}

91
remote-grain.go Normal file
View File

@@ -0,0 +1,91 @@
package main
import (
"fmt"
"strings"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
func (id CartId) String() string {
return strings.Trim(string(id[:]), "\x00")
}
func ToCartId(id string) CartId {
var result [16]byte
copy(result[:], []byte(id))
return result
}
type RemoteGrain struct {
*CartClient
Id CartId
Host string
}
func NewRemoteGrain(id CartId, host string) (*RemoteGrain, error) {
client, err := CartDial(fmt.Sprintf("%s:1337", host))
if err != nil {
return nil, err
}
return &RemoteGrain{
Id: id,
Host: host,
CartClient: client,
}, nil
}
var (
remoteCartLatency = promauto.NewCounter(prometheus.CounterOpts{
Name: "cart_remote_grain_calls_total_latency",
Help: "The total latency of remote grains",
})
remoteCartCallsTotal = promauto.NewCounter(prometheus.CounterOpts{
Name: "cart_remote_grain_calls_total",
Help: "The total number of calls to remote grains",
})
)
var start time.Time
func MeasureLatency(fn func() (*CallResult, error)) (*CallResult, error) {
start = time.Now()
data, err := fn()
if err != nil {
return data, err
}
elapsed := time.Since(start).Milliseconds()
go func() {
remoteCartLatency.Add(float64(elapsed))
remoteCartCallsTotal.Inc()
}()
return data, nil
}
func (g *RemoteGrain) HandleMessage(message *Message, isReplay bool) (*CallResult, error) {
data, err := GetData(message.Write)
if err != nil {
return nil, err
}
reply, err := MeasureLatency(func() (*CallResult, error) {
return g.Call(RemoteHandleMutation, g.Id, RemoteHandleMutationReply, data)
})
if err != nil {
return nil, err
}
return reply, err
}
func (g *RemoteGrain) GetId() CartId {
return g.Id
}
func (g *RemoteGrain) GetCurrentState() (*CallResult, error) {
return MeasureLatency(func() (*CallResult, error) { return g.Call(RemoteGetState, g.Id, RemoteGetStateReply, []byte{}) })
}

View File

@@ -1,154 +0,0 @@
package main
import (
"fmt"
"strings"
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
type RemoteGrainPool struct {
mu sync.RWMutex
Host string
grains map[CartId]*RemoteGrain
}
func (id CartId) String() string {
return strings.Trim(string(id[:]), "\x00")
}
func ToCartId(id string) CartId {
var result [16]byte
copy(result[:], []byte(id))
return result
}
type RemoteGrain struct {
*CartClient
Id CartId
Host string
}
func NewRemoteGrain(id CartId, host string) (*RemoteGrain, error) {
client, err := CartDial(fmt.Sprintf("%s:1337", host))
if err != nil {
return nil, err
}
return &RemoteGrain{
Id: id,
Host: host,
CartClient: client,
}, nil
}
var (
remoteCartLatency = promauto.NewCounter(prometheus.CounterOpts{
Name: "cart_remote_grain_calls_total_latency",
Help: "The total latency of remote grains",
})
remoteCartCallsTotal = promauto.NewCounter(prometheus.CounterOpts{
Name: "cart_remote_grain_calls_total",
Help: "The total number of calls to remote grains",
})
)
var start time.Time
func MeasureLatency(fn func() ([]byte, error)) ([]byte, error) {
start = time.Now()
data, err := fn()
if err != nil {
return data, err
}
elapsed := time.Since(start).Milliseconds()
go func() {
remoteCartLatency.Add(float64(elapsed))
remoteCartCallsTotal.Inc()
}()
return data, nil
}
func (g *RemoteGrain) HandleMessage(message *Message, isReplay bool) ([]byte, error) {
data, err := GetData(message.Write)
if err != nil {
return nil, err
}
reply, err := MeasureLatency(func() ([]byte, error) { return g.Call(RemoteHandleMutation, g.Id, RemoteHandleMutationReply, data) })
if err != nil {
return nil, err
}
return reply, err
}
func (g *RemoteGrain) GetId() CartId {
return g.Id
}
func (g *RemoteGrain) GetCurrentState() ([]byte, error) {
return MeasureLatency(func() ([]byte, error) { return g.Call(RemoteGetState, g.Id, RemoteGetStateReply, []byte{}) })
}
func NewRemoteGrainPool(addr string) *RemoteGrainPool {
return &RemoteGrainPool{
Host: addr,
grains: make(map[CartId]*RemoteGrain),
}
}
func (p *RemoteGrainPool) findRemoteGrain(id CartId) *RemoteGrain {
p.mu.RLock()
grain, ok := p.grains[id]
p.mu.RUnlock()
if !ok {
return nil
}
return grain
}
func (p *RemoteGrainPool) findOrCreateGrain(id CartId) (*RemoteGrain, error) {
grain := p.findRemoteGrain(id)
if grain == nil {
grain, err := NewRemoteGrain(id, p.Host)
if err != nil {
return nil, err
}
p.mu.Lock()
p.grains[id] = grain
p.mu.Unlock()
}
return grain, nil
}
func (p *RemoteGrainPool) Delete(id CartId) {
p.mu.Lock()
delete(p.grains, id)
p.mu.Unlock()
}
func (p *RemoteGrainPool) Process(id CartId, messages ...Message) ([]byte, error) {
var result []byte
grain, err := p.findOrCreateGrain(id)
if err != nil {
return nil, err
}
for _, message := range messages {
result, err = grain.HandleMessage(&message, false)
}
return result, err
}
func (p *RemoteGrainPool) Get(id CartId) ([]byte, error) {
grain, err := p.findOrCreateGrain(id)
if err != nil {
return nil, err
}
return grain.GetCurrentState()
}

View File

@@ -1,7 +1,6 @@
package main
import (
"encoding/json"
"fmt"
"log"
"strings"
@@ -26,7 +25,6 @@ type RemoteHost struct {
*Client
Host string
MissedPings int
//Pool *RemoteGrainPool
}
type SyncedPool struct {
@@ -248,21 +246,27 @@ const (
)
func (h *RemoteHost) Negotiate(knownHosts []string) ([]string, error) {
data, err := h.Call(RemoteNegotiate, RemoteNegotiateResponse, []byte(strings.Join(knownHosts, ";")))
reply, err := h.Call(RemoteNegotiate, RemoteNegotiateResponse, []byte(strings.Join(knownHosts, ";")))
if err != nil {
return nil, err
}
if reply.StatusCode != 200 {
return nil, fmt.Errorf("remote returned error on negotiate: %s", string(reply.Data))
}
return strings.Split(string(data), ";"), nil
return strings.Split(string(reply.Data), ";"), nil
}
func (g *RemoteHost) GetCartMappings() ([]CartId, error) {
data, err := g.Call(GetCartIds, CartIdsResponse, []byte{})
reply, err := g.Call(GetCartIds, CartIdsResponse, []byte{})
if err != nil {
return nil, err
}
parts := strings.Split(string(data), ";")
if reply.StatusCode != 200 {
return nil, fmt.Errorf("remote returned error: %s", string(reply.Data))
}
parts := strings.Split(string(reply.Data), ";")
ids := make([]CartId, 0, len(parts))
for _, p := range parts {
ids = append(ids, ToCartId(p))
@@ -289,13 +293,13 @@ func (p *SyncedPool) Negotiate(knownHosts []string) ([]string, error) {
}
func (r *RemoteHost) ConfirmChange(id CartId, host string) error {
data, err := r.Call(RemoteGrainChanged, AckChange, []byte(fmt.Sprintf("%s;%s", id, host)))
reply, err := r.Call(RemoteGrainChanged, AckChange, []byte(fmt.Sprintf("%s;%s", id, host)))
if err != nil {
return err
}
if string(data) != "ok" {
return fmt.Errorf("remote grain change failed %s", string(data))
if string(reply.Data) != "ok" {
return fmt.Errorf("remote grain change failed %s", string(reply.Data))
}
return nil
@@ -443,9 +447,9 @@ func (p *SyncedPool) getGrain(id CartId) (Grain, error) {
return localGrain, nil
}
func (p *SyncedPool) Process(id CartId, messages ...Message) ([]byte, error) {
func (p *SyncedPool) Process(id CartId, messages ...Message) (*CallResult, error) {
pool, err := p.getGrain(id)
var res []byte
var res *CallResult
if err != nil {
return nil, err
}
@@ -458,7 +462,7 @@ func (p *SyncedPool) Process(id CartId, messages ...Message) ([]byte, error) {
return res, nil
}
func (p *SyncedPool) Get(id CartId) ([]byte, error) {
func (p *SyncedPool) Get(id CartId) (*CallResult, error) {
grain, err := p.getGrain(id)
if err != nil {
return nil, err
@@ -467,5 +471,5 @@ func (p *SyncedPool) Get(id CartId) ([]byte, error) {
return remoteGrain.GetCurrentState()
}
return json.Marshal(grain)
return grain.GetCurrentState()
}

View File

@@ -73,7 +73,7 @@ func (m *CartTCPClient) SendPacket(messageType uint32, id CartId, data []byte) e
err = binary.Write(m.Conn, binary.LittleEndian, CartPacket{
Version: CurrentPacketVersion,
MessageType: messageType,
DataLength: uint64(len(data)),
DataLength: uint32(len(data)),
Id: id,
})
if err != nil {
@@ -91,7 +91,7 @@ func (m *CartTCPClient) SendPacket(messageType uint32, id CartId, data []byte) e
// return m.SendPacket(messageType, id, data)
// }
func (m *CartTCPClient) Call(messageType uint32, id CartId, responseType uint32, data []byte) ([]byte, error) {
func (m *CartTCPClient) Call(messageType uint32, id CartId, responseType uint32, data []byte) (*CallResult, error) {
packetChan := m.Expect(responseType, id)
err := m.SendPacket(messageType, id, data)
if err != nil {
@@ -99,7 +99,7 @@ func (m *CartTCPClient) Call(messageType uint32, id CartId, responseType uint32,
}
select {
case ret := <-packetChan:
return ret, nil
return &ret, nil
case <-time.After(time.Second):
return nil, fmt.Errorf("timeout")
}

View File

@@ -69,13 +69,24 @@ func (m *TCPCartServerMux) handleFunction(connection net.Conn, messageType uint3
m.mu.RUnlock()
if ok {
responseType, responseData, err := fn(id, data)
if err != nil {
errData := []byte(err.Error())
err = binary.Write(connection, binary.LittleEndian, CartPacket{
Version: CurrentPacketVersion,
MessageType: responseType,
DataLength: uint32(len(errData)),
StatusCode: 500,
Id: id,
})
_, err = connection.Write(errData)
return true, err
}
err = binary.Write(connection, binary.LittleEndian, CartPacket{
Version: CurrentPacketVersion,
MessageType: responseType,
DataLength: uint64(len(responseData)),
DataLength: uint32(len(responseData)),
StatusCode: 200,
Id: id,
})
if err != nil {
@@ -101,7 +112,10 @@ func (m *TCPCartServerMux) HandleConnection(connection net.Conn) error {
log.Printf("Error receiving packet: %v\n", err)
return err
}
if packet.Version != CurrentPacketVersion {
log.Printf("Incorrect packet version: %d\n", packet.Version)
continue
}
data, err := GetPacketData(connection, packet.DataLength)
if err != nil {
log.Printf("Error getting packet data: %v\n", err)

View File

@@ -1,6 +1,7 @@
package main
import (
"fmt"
"log"
"testing"
)
@@ -21,6 +22,10 @@ func TestCartTcpHelpers(t *testing.T) {
messageData = string(data)
return nil
})
server.HandleCall(666, func(id CartId, data []byte) (uint32, []byte, error) {
log.Printf("Received call: %s\n", string(data))
return 3, []byte("Hello, client!"), fmt.Errorf("Det blev fel")
})
server.HandleCall(2, func(id CartId, data []byte) (uint32, []byte, error) {
log.Printf("Received call: %s\n", string(data))
return 3, []byte("Hello, client!"), nil
@@ -34,6 +39,13 @@ func TestCartTcpHelpers(t *testing.T) {
if err != nil {
t.Errorf("Error calling: %v\n", err)
}
s, err := client.Call(666, id, 3, []byte("Hello, server!"))
if err != nil {
t.Errorf("Error calling: %v\n", err)
}
if s.StatusCode != 500 {
t.Errorf("Expected 500, got %d\n", s.StatusCode)
}
for i := 0; i < 100; i++ {
_, err = client.Call(Ping, id, Pong, nil)
if err != nil {
@@ -44,8 +56,8 @@ func TestCartTcpHelpers(t *testing.T) {
if err != nil {
t.Errorf("Error calling: %v\n", err)
}
if string(answer) != "Hello, client!" {
t.Errorf("Expected answer 'Hello, client!', got %s\n", string(answer))
if string(answer.Data) != "Hello, client!" {
t.Errorf("Expected answer 'Hello, client!', got %s\n", string(answer.Data))
}
if messageData != "Hello, world!" {
t.Errorf("Expected message 'Hello, world!', got %s\n", messageData)

View File

@@ -77,7 +77,8 @@ func (m *TCPClient) SendPacket(messageType uint32, data []byte) error {
err = binary.Write(m.Conn, binary.LittleEndian, Packet{
Version: CurrentPacketVersion,
MessageType: messageType,
DataLength: uint64(len(data)),
StatusCode: 0,
DataLength: uint32(len(data)),
})
if err != nil {
return m.HandleConnectionError(err)
@@ -94,7 +95,7 @@ func (m *TCPClient) SendPacket(messageType uint32, data []byte) error {
// return m.SendPacket(messageType, data)
// }
func (m *TCPClient) Call(messageType uint32, responseType uint32, data []byte) ([]byte, error) {
func (m *TCPClient) Call(messageType uint32, responseType uint32, data []byte) (*CallResult, error) {
packetChan := m.Expect(responseType)
err := m.SendPacket(messageType, data)
if err != nil {
@@ -103,7 +104,7 @@ func (m *TCPClient) Call(messageType uint32, responseType uint32, data []byte) (
select {
case ret := <-packetChan:
return ret, nil
return &ret, nil
case <-time.After(time.Second):
return nil, fmt.Errorf("timeout")
}

View File

@@ -70,12 +70,21 @@ func (m *TCPServerMux) handleFunction(connection net.Conn, messageType uint32, d
if ok {
responseType, responseData, err := function(data)
if err != nil {
errData := []byte(err.Error())
err = binary.Write(connection, binary.LittleEndian, Packet{
Version: CurrentPacketVersion,
MessageType: responseType,
StatusCode: 500,
DataLength: uint32(len(errData)),
})
_, err = connection.Write(errData)
return true, err
}
err = binary.Write(connection, binary.LittleEndian, Packet{
Version: CurrentPacketVersion,
MessageType: responseType,
DataLength: uint64(len(responseData)),
StatusCode: 200,
DataLength: uint32(len(responseData)),
})
if err != nil {
return true, err
@@ -100,6 +109,10 @@ func (m *TCPServerMux) HandleConnection(connection net.Conn) error {
log.Printf("Error receiving packet: %v\n", err)
return err
}
if packet.Version != CurrentPacketVersion {
log.Printf("Incorrect package version: %v\n", err)
continue
}
data, err := GetPacketData(connection, packet.DataLength)
if err != nil {
log.Printf("Error receiving packet data: %v\n", err)

View File

@@ -39,8 +39,8 @@ func TestTcpHelpers(t *testing.T) {
t.Errorf("Error calling: %v\n", err)
}
client.Close()
if string(answer) != "Hello, client!" {
t.Errorf("Expected answer 'Hello, client!', got %s\n", string(answer))
if string(answer.Data) != "Hello, client!" {
t.Errorf("Expected answer 'Hello, client!', got %s\n", string(answer.Data))
}
if messageData != "Hello, world!" {
t.Errorf("Expected message 'Hello, world!', got %s\n", messageData)