From abf561c3fec7c8ad29d8bf7dba8f1a6c0bad697b Mon Sep 17 00:00:00 2001 From: matst80 Date: Wed, 13 Nov 2024 21:56:40 +0100 Subject: [PATCH] major refactor --- cart-grain.go | 28 ++++--- cart-packet-queue.go | 163 ----------------------------------------- grain-pool.go | 22 ++++-- packet-queue.go | 122 ------------------------------ packet-queue_test.go | 28 ------- packet.go | 130 +++++++++++++++----------------- pool-server.go | 6 +- remote-grain-pool.go | 6 +- remote-grain.go | 77 +++++++++---------- remote-host.go | 35 ++++----- rpc-server.go | 44 ++++++----- synced-pool.go | 118 ++++++++++++++--------------- tcp-cart-client.go | 96 ------------------------ tcp-cart-mux-server.go | 160 ---------------------------------------- tcp-cart_test.go | 57 -------------- tcp-client.go | 144 ------------------------------------ tcp-connection.go | 119 ++++++++++++++++++------------ tcp-connection_test.go | 32 ++------ tcp-mux-server.go | 161 ---------------------------------------- tcp_test.go | 54 -------------- 20 files changed, 310 insertions(+), 1292 deletions(-) delete mode 100644 cart-packet-queue.go delete mode 100644 packet-queue.go delete mode 100644 packet-queue_test.go delete mode 100644 tcp-cart-client.go delete mode 100644 tcp-cart-mux-server.go delete mode 100644 tcp-cart_test.go delete mode 100644 tcp-client.go delete mode 100644 tcp-mux-server.go delete mode 100644 tcp_test.go diff --git a/cart-grain.go b/cart-grain.go index 05bddff..5ffd0ef 100644 --- a/cart-grain.go +++ b/cart-grain.go @@ -54,8 +54,8 @@ type CartGrain struct { type Grain interface { GetId() CartId - HandleMessage(message *Message, isReplay bool) (*CallResult, error) - GetCurrentState() (*CallResult, error) + HandleMessage(message *Message, isReplay bool) (*FrameWithPayload, error) + GetCurrentState() (*FrameWithPayload, error) } func (c *CartGrain) GetId() CartId { @@ -69,12 +69,14 @@ func (c *CartGrain) GetLastChange() int64 { return *c.storageMessages[len(c.storageMessages)-1].TimeStamp } -func (c *CartGrain) GetCurrentState() (*CallResult, error) { +func (c *CartGrain) GetCurrentState() (*FrameWithPayload, error) { result, err := json.Marshal(c) - return &CallResult{ - StatusCode: 200, - Data: result, - }, err + if err != nil { + ret := MakeFrameWithPayload(0, 400, []byte(err.Error())) + return &ret, nil + } + ret := MakeFrameWithPayload(0, 200, result) + return &ret, nil } func getItemData(sku string, qty int) (*messages.AddItem, error) { @@ -108,7 +110,7 @@ func getItemData(sku string, qty int) (*messages.AddItem, error) { }, nil } -func (c *CartGrain) AddItem(sku string, qty int) (*CallResult, error) { +func (c *CartGrain) AddItem(sku string, qty int) (*FrameWithPayload, error) { cartItem, err := getItemData(sku, qty) if err != nil { return nil, err @@ -180,7 +182,7 @@ func (c *CartGrain) FindItemWithSku(sku string) (*CartItem, bool) { return nil, false } -func (c *CartGrain) HandleMessage(message *Message, isReplay bool) (*CallResult, error) { +func (c *CartGrain) HandleMessage(message *Message, isReplay bool) (*FrameWithPayload, error) { if message.TimeStamp == nil { now := time.Now().Unix() message.TimeStamp = &now @@ -305,8 +307,10 @@ func (c *CartGrain) HandleMessage(message *Message, isReplay bool) (*CallResult, c.mu.Unlock() } result, err := json.Marshal(c) - return &CallResult{ - StatusCode: 200, - Data: result, + return &FrameWithPayload{ + Frame: Frame{ + StatusCode: 200, + }, + Payload: result, }, err } diff --git a/cart-packet-queue.go b/cart-packet-queue.go deleted file mode 100644 index 5584dd5..0000000 --- a/cart-packet-queue.go +++ /dev/null @@ -1,163 +0,0 @@ -package main - -import ( - "bufio" - "fmt" - "log" - "sync" - "time" -) - -type CartPacketQueue struct { - mu sync.RWMutex - expectedPackages map[CartMessage]*CartListener -} - -const CurrentPacketVersion = 2 - -type CartListener map[CartId]Listener - -func NewCartPacketQueue(connection *PersistentConnection) *CartPacketQueue { - queue := &CartPacketQueue{ - expectedPackages: make(map[CartMessage]*CartListener), - } - go queue.HandleConnection(connection) - return queue -} - -func (p *CartPacketQueue) RemoveListeners() { - p.mu.Lock() - defer p.mu.Unlock() - for _, l := range p.expectedPackages { - for _, l := range *l { - close(l.Chan) - } - } - p.expectedPackages = make(map[CartMessage]*CartListener) -} - -func (p *CartPacketQueue) HandleConnection(connection *PersistentConnection) error { - defer p.RemoveListeners() - defer connection.Close() - var packet CartPacket - reader := bufio.NewReader(connection) - for { - err := ReadCartPacket(reader, &packet) - if err != nil { - log.Printf("Error receiving packet: %v\n", err) - return connection.HandleConnectionError(err) - } - if packet.Version != CurrentPacketVersion { - log.Printf("Incorrect version: %v\n", packet.Version) - return connection.HandleConnectionError(fmt.Errorf("incorrect version: %d", packet.Version)) - } - if packet.DataLength == 0 { - go p.HandleData(packet.MessageType, packet.Id, CallResult{ - StatusCode: packet.StatusCode, - Data: []byte{}, - }) - continue - } - data, err := GetPacketData(reader, packet.DataLength) - if err != nil { - log.Printf("Error receiving packet data: %v\n", err) - return connection.HandleConnectionError(err) - } - go p.HandleData(packet.MessageType, packet.Id, CallResult{ - StatusCode: packet.StatusCode, - Data: data, - }) - } -} - -func (p *CartPacketQueue) HandleData(t CartMessage, id CartId, data CallResult) { - p.getListener(t, id, func(l *Listener) { - l.Chan <- data - l.Count-- - }) - // p.mu.Lock() - // defer p.mu.Unlock() - // pl, ok := p.expectedPackages[t] - // if ok { - // l, ok := (*pl)[id] - // if ok { - // l.Chan <- data - // l.Count-- - // if l.Count == 0 { - // close(l.Chan) - // delete(*pl, id) - // } - // } - // } -} - -func (p *CartPacketQueue) getListener(t CartMessage, id CartId, fn func(*Listener)) { - p.mu.Lock() - defer p.mu.Unlock() - pl, ok := p.expectedPackages[t] - if ok { - l, ok := (*pl)[id] - if ok { - fn(&l) - if l.Count == 0 { - close(l.Chan) - delete(*pl, id) - } - } - } -} - -func CallResultWithTimeout(onTimeout func() CallResult) chan CallResult { - ch := make(chan CallResult, 1) - resultCh := make(chan CallResult, 1) - select { - case ret := <-resultCh: - ch <- ret - case <-time.After(300 * time.Millisecond): - ch <- onTimeout() - } - return ch -} - -func (p *CartPacketQueue) MakeChannel(messageType CartMessage, id CartId) chan CallResult { - return CallResultWithTimeout(func() CallResult { - p.getListener(messageType, id, func(l *Listener) { - l.Count-- - }) - return CallResult{ - StatusCode: 504, - Data: []byte("timeout cart call"), - } - }) -} - -func (p *CartPacketQueue) Expect(messageType CartMessage, id CartId) <-chan CallResult { - p.mu.Lock() - defer p.mu.Unlock() - l, ok := p.expectedPackages[messageType] - if ok { - if idl, idOk := (*l)[id]; idOk { - idl.Count++ - return idl.Chan - } - ch := p.MakeChannel(messageType, id) - - (*l)[id] = Listener{ - Chan: ch, - Count: 1, - } - - return ch - } - - ch := p.MakeChannel(messageType, id) - p.expectedPackages[messageType] = &CartListener{ - id: Listener{ - Chan: ch, - Count: 1, - }, - } - - return ch - -} diff --git a/grain-pool.go b/grain-pool.go index 114fba0..991cec8 100644 --- a/grain-pool.go +++ b/grain-pool.go @@ -27,8 +27,8 @@ var ( ) type GrainPool interface { - Process(id CartId, messages ...Message) (*CallResult, error) - Get(id CartId) (*CallResult, error) + Process(id CartId, messages ...Message) (*FrameWithPayload, error) + Get(id CartId) (*FrameWithPayload, error) } type Ttl struct { @@ -142,23 +142,29 @@ func (p *GrainLocalPool) GetGrain(id CartId) (*CartGrain, error) { return grain, err } -func (p *GrainLocalPool) Process(id CartId, messages ...Message) ([]byte, error) { +func (p *GrainLocalPool) Process(id CartId, messages ...Message) (*FrameWithPayload, error) { grain, err := p.GetGrain(id) + var result *FrameWithPayload if err == nil && grain != nil { for _, message := range messages { - _, err = grain.HandleMessage(&message, false) + result, err = grain.HandleMessage(&message, false) } } if err != nil { - return nil, err + return result, err } - return json.Marshal(grain) + return result, err } -func (p *GrainLocalPool) Get(id CartId) ([]byte, error) { +func (p *GrainLocalPool) Get(id CartId) (*FrameWithPayload, error) { grain, err := p.GetGrain(id) if err != nil { return nil, err } - return json.Marshal(grain) + data, err := json.Marshal(grain) + if err != nil { + return nil, err + } + ret := MakeFrameWithPayload(0, 200, data) + return &ret, nil } diff --git a/packet-queue.go b/packet-queue.go deleted file mode 100644 index ad37026..0000000 --- a/packet-queue.go +++ /dev/null @@ -1,122 +0,0 @@ -package main - -import ( - "bufio" - "fmt" - "log" - "sync" - "time" -) - -type PacketQueue struct { - mu sync.RWMutex - expectedPackages map[PoolMessage]*Listener -} - -type CallResult struct { - StatusCode uint32 - Data []byte -} - -type Listener struct { - Count int - Chan chan CallResult -} - -func NewPacketQueue(connection *PersistentConnection) *PacketQueue { - queue := &PacketQueue{ - expectedPackages: make(map[PoolMessage]*Listener), - } - go queue.HandleConnection(connection) - return queue -} - -func (p *PacketQueue) RemoveListeners() { - p.mu.Lock() - defer p.mu.Unlock() - for _, l := range p.expectedPackages { - close(l.Chan) - } - p.expectedPackages = make(map[PoolMessage]*Listener) -} - -func (p *PacketQueue) HandleConnection(connection *PersistentConnection) error { - defer connection.Close() - defer p.RemoveListeners() - var packet Packet - reader := bufio.NewReader(connection) - - for { - err := ReadPacket(reader, &packet) - if err != nil { - return connection.HandleConnectionError(err) - } - if packet.Version != CurrentPacketVersion { - log.Printf("Incorrect packet version: %v\n", packet.Version) - return connection.HandleConnectionError(fmt.Errorf("incorrect packet version: %d", packet.Version)) - } - if packet.DataLength == 0 { - go p.HandleData(packet.MessageType, CallResult{ - StatusCode: packet.StatusCode, - Data: []byte{}, - }) - continue - } - data, err := GetPacketData(reader, packet.DataLength) - if err != nil { - log.Printf("Error receiving packet data: %v\n", err) - return connection.HandleConnectionError(err) - } else { - go p.HandleData(packet.MessageType, CallResult{ - StatusCode: packet.StatusCode, - Data: data, - }) - } - } -} - -func (p *PacketQueue) HandleData(t PoolMessage, data CallResult) { - p.mu.Lock() - defer p.mu.Unlock() - l, ok := p.expectedPackages[t] - if ok { - l.Chan <- data - l.Count-- - if l.Count == 0 { - close(l.Chan) - delete(p.expectedPackages, t) - } - return - } -} - -func (p *PacketQueue) Expect(messageType PoolMessage) <-chan CallResult { - p.mu.Lock() - defer p.mu.Unlock() - l, ok := p.expectedPackages[messageType] - if ok { - l.Count++ - return l.Chan - } - - ch := make(chan CallResult, 1) - go func() { - time.Sleep(time.Millisecond * 300) - p.mu.Lock() - defer p.mu.Unlock() - - ch <- CallResult{ - StatusCode: 504, - Data: []byte("timeout cart call"), - } - - close(ch) - }() - p.expectedPackages[messageType] = &Listener{ - Count: 1, - Chan: ch, - } - - return ch - -} diff --git a/packet-queue_test.go b/packet-queue_test.go deleted file mode 100644 index b2fbdb0..0000000 --- a/packet-queue_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package main - -import ( - "testing" - "time" -) - -func TestQueue(t *testing.T) { - localPool := NewGrainLocalPool(100, time.Minute, func(id CartId) (*CartGrain, error) { - return &CartGrain{ - Id: id, - storageMessages: []Message{}, - Items: []*CartItem{}, - TotalPrice: 0, - }, nil - }) - pool, err := NewSyncedPool(localPool, "localhost", nil) - if err != nil { - t.Errorf("Error creating pool: %v", err) - } - - err = pool.AddRemote("localhost") - if err != nil { - t.Errorf("Error adding remote: %v", err) - return - } - -} diff --git a/packet.go b/packet.go index 231df19..4146059 100644 --- a/packet.go +++ b/packet.go @@ -1,85 +1,77 @@ package main -import ( - "encoding/binary" - "io" -) - -type CartMessage uint32 -type PackageVersion uint32 - const ( - RemoteGetState = CartMessage(0x01) - RemoteHandleMutation = CartMessage(0x02) - ResponseBody = CartMessage(0x03) - RemoteGetStateReply = CartMessage(0x04) - RemoteHandleMutationReply = CartMessage(0x05) + RemoteGetState = FrameType(0x01) + RemoteHandleMutation = FrameType(0x02) + ResponseBody = FrameType(0x03) + RemoteGetStateReply = FrameType(0x04) + RemoteHandleMutationReply = FrameType(0x05) ) -type CartPacket struct { - Version PackageVersion - MessageType CartMessage - DataLength uint32 - StatusCode uint32 - Id CartId -} +// type CartPacket struct { +// Version PackageVersion +// MessageType CartMessage +// DataLength uint32 +// StatusCode uint32 +// Id CartId +// } -type Packet struct { - Version PackageVersion - MessageType PoolMessage - DataLength uint32 - StatusCode uint32 -} +// type Packet struct { +// Version PackageVersion +// MessageType PoolMessage +// DataLength uint32 +// StatusCode uint32 +// } -var headerData = make([]byte, 4) +// var headerData = make([]byte, 4) -func matchHeader(conn io.Reader) error { +// func matchHeader(conn io.Reader) error { - pos := 0 - for pos < 4 { +// pos := 0 +// for pos < 4 { - l, err := conn.Read(headerData) - if err != nil { - return err - } - for i := 0; i < l; i++ { - if headerData[i] == header[pos] { - pos++ - if pos == 4 { - return nil - } - } else { - pos = 0 - } - } - } - return nil -} +// l, err := conn.Read(headerData) +// if err != nil { +// return err +// } +// for i := 0; i < l; i++ { +// if headerData[i] == header[pos] { +// pos++ +// if pos == 4 { +// return nil +// } +// } else { +// pos = 0 +// } +// } +// } +// return nil +// } -func ReadPacket(conn io.Reader, packet *Packet) error { - err := matchHeader(conn) - if err != nil { - return err - } - return binary.Read(conn, binary.LittleEndian, packet) -} +// func ReadPacket(conn io.Reader, packet *Packet) error { +// err := matchHeader(conn) +// if err != nil { +// return err +// } +// return binary.Read(conn, binary.LittleEndian, packet) +// } -func ReadCartPacket(conn io.Reader, packet *CartPacket) error { - err := matchHeader(conn) - if err != nil { - return err - } - return binary.Read(conn, binary.LittleEndian, packet) -} +// func ReadCartPacket(conn io.Reader, packet *CartPacket) error { +// err := matchHeader(conn) +// if err != nil { +// return err +// } +// return binary.Read(conn, binary.LittleEndian, packet) +// } -func GetPacketData(conn io.Reader, len uint32) ([]byte, error) { - if len == 0 { - return []byte{}, nil - } - data := make([]byte, len) - _, err := conn.Read(data) - return data, err -} +// func GetPacketData(conn io.Reader, len uint32) ([]byte, error) { +// if len == 0 { +// return []byte{}, nil +// } +// data := make([]byte, len) +// _, err := conn.Read(data) +// return data, err +// } // func ReceivePacket(conn io.Reader) (uint32, []byte, error) { // var packet Packet diff --git a/pool-server.go b/pool-server.go index 54c38ff..e7f5db9 100644 --- a/pool-server.go +++ b/pool-server.go @@ -55,7 +55,7 @@ func ErrorHandler(fn func(w http.ResponseWriter, r *http.Request) error) func(w } } -func (s *PoolServer) WriteResult(w http.ResponseWriter, result *CallResult) error { +func (s *PoolServer) WriteResult(w http.ResponseWriter, result *FrameWithPayload) error { w.Header().Set("Content-Type", "application/json") w.Header().Set("X-Pod-Name", s.pod_name) if result.StatusCode != 200 { @@ -65,11 +65,11 @@ func (s *PoolServer) WriteResult(w http.ResponseWriter, result *CallResult) erro } else { w.WriteHeader(http.StatusInternalServerError) } - w.Write([]byte(result.Data)) + w.Write([]byte(result.Payload)) return nil } w.WriteHeader(http.StatusOK) - _, err := w.Write(result.Data) + _, err := w.Write(result.Payload) return err } diff --git a/remote-grain-pool.go b/remote-grain-pool.go index 8a678b2..3be28ff 100644 --- a/remote-grain-pool.go +++ b/remote-grain-pool.go @@ -46,8 +46,8 @@ func (p *RemoteGrainPool) Delete(id CartId) { p.mu.Unlock() } -func (p *RemoteGrainPool) Process(id CartId, messages ...Message) (*CallResult, error) { - var result *CallResult +func (p *RemoteGrainPool) Process(id CartId, messages ...Message) (*FrameWithPayload, error) { + var result *FrameWithPayload grain, err := p.findOrCreateGrain(id) if err != nil { return nil, err @@ -58,7 +58,7 @@ func (p *RemoteGrainPool) Process(id CartId, messages ...Message) (*CallResult, return result, err } -func (p *RemoteGrainPool) Get(id CartId) (*CallResult, error) { +func (p *RemoteGrainPool) Get(id CartId) (*FrameWithPayload, error) { grain, err := p.findOrCreateGrain(id) if err != nil { return nil, err diff --git a/remote-grain.go b/remote-grain.go index c5ff0de..8fde920 100644 --- a/remote-grain.go +++ b/remote-grain.go @@ -3,7 +3,6 @@ package main import ( "fmt" "strings" - "time" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -13,6 +12,25 @@ func (id CartId) String() string { return strings.Trim(string(id[:]), "\x00") } +type CartIdPayload struct { + Id CartId + Data []byte +} + +func MakeCartInnerFrame(id CartId, payload []byte) []byte { + return append(id[:], payload...) +} + +func GetCartFrame(data []byte) (*CartIdPayload, error) { + if len(data) < 16 { + return nil, fmt.Errorf("data too short") + } + return &CartIdPayload{ + Id: CartId(data[:16]), + Data: data[16:], + }, nil +} + func ToCartId(id string) CartId { var result [16]byte copy(result[:], []byte(id)) @@ -20,21 +38,16 @@ func ToCartId(id string) CartId { } type RemoteGrain struct { - *CartClient + *Connection Id CartId Host string } func NewRemoteGrain(id CartId, host string) (*RemoteGrain, error) { - client, err := CartDial(fmt.Sprintf("%s:1337", host)) - if err != nil { - return nil, err - } - return &RemoteGrain{ Id: id, Host: host, - CartClient: client, + Connection: NewConnection(fmt.Sprintf("%s:1337", host)), }, nil } @@ -49,47 +62,35 @@ var ( }) ) -var start time.Time +// var start time.Time -func MeasureLatency(fn func() (*CallResult, error)) (*CallResult, error) { - start = time.Now() - data, err := fn() - if err != nil { - return data, err - } - elapsed := time.Since(start).Milliseconds() - go func() { - remoteCartLatency.Add(float64(elapsed)) - remoteCartCallsTotal.Inc() - }() - return data, nil -} +// func MeasureLatency(fn func() (*CallResult, error)) (*CallResult, error) { +// start = time.Now() +// data, err := fn() +// if err != nil { +// return data, err +// } +// elapsed := time.Since(start).Milliseconds() +// go func() { +// remoteCartLatency.Add(float64(elapsed)) +// remoteCartCallsTotal.Inc() +// }() +// return data, nil +// } -func (g *RemoteGrain) HandleMessage(message *Message, isReplay bool) (*CallResult, error) { +func (g *RemoteGrain) HandleMessage(message *Message, isReplay bool) (*FrameWithPayload, error) { data, err := GetData(message.Write) if err != nil { return nil, err } - reply, err := MeasureLatency(func() (*CallResult, error) { - return g.Call(RemoteHandleMutation, g.Id, RemoteHandleMutationReply, data) - }) - - if err != nil { - return nil, err - } - - return reply, err -} - -func (g *RemoteGrain) Close() { - g.CartClient.PersistentConnection.Close() + return g.Call(RemoteHandleMutation, MakeCartInnerFrame(g.Id, data)) } func (g *RemoteGrain) GetId() CartId { return g.Id } -func (g *RemoteGrain) GetCurrentState() (*CallResult, error) { - return MeasureLatency(func() (*CallResult, error) { return g.Call(RemoteGetState, g.Id, RemoteGetStateReply, []byte{}) }) +func (g *RemoteGrain) GetCurrentState() (*FrameWithPayload, error) { + return g.Call(RemoteGetState, MakeCartInnerFrame(g.Id, nil)) } diff --git a/remote-host.go b/remote-host.go index bcd1ceb..2dff28c 100644 --- a/remote-host.go +++ b/remote-host.go @@ -7,13 +7,13 @@ import ( ) type RemoteHost struct { - *Client + *Connection Host string MissedPings int } func (h *RemoteHost) IsHealthy() bool { - return !h.PersistentConnection.Dead && h.MissedPings < 3 + return h.MissedPings < 3 } func (h *RemoteHost) Initialize(p *SyncedPool) { @@ -38,15 +38,11 @@ func (h *RemoteHost) Initialize(p *SyncedPool) { } func (h *RemoteHost) Ping() error { - _, err := h.Call(Ping, Pong, []byte{}) + result, err := h.Call(Ping, nil) - if err != nil { + if err != nil || result.StatusCode != 200 || result.Type != Pong { h.MissedPings++ log.Printf("Error pinging remote %s, missed pings: %d", h.Host, h.MissedPings) - if !h.IsHealthy() { - h.Close() - return fmt.Errorf("remote %s is dead", h.Host) - } } else { h.MissedPings = 0 } @@ -54,28 +50,28 @@ func (h *RemoteHost) Ping() error { } func (h *RemoteHost) Negotiate(knownHosts []string) ([]string, error) { - reply, err := h.Call(RemoteNegotiate, RemoteNegotiateResponse, []byte(strings.Join(knownHosts, ";"))) + reply, err := h.Call(RemoteNegotiate, []byte(strings.Join(knownHosts, ";"))) if err != nil { return nil, err } if reply.StatusCode != 200 { - return nil, fmt.Errorf("remote returned error on negotiate: %s", string(reply.Data)) + return nil, fmt.Errorf("remote returned error on negotiate: %s", string(reply.Payload)) } - return strings.Split(string(reply.Data), ";"), nil + return strings.Split(string(reply.Payload), ";"), nil } func (g *RemoteHost) GetCartMappings() ([]CartId, error) { - reply, err := g.Call(GetCartIds, CartIdsResponse, []byte{}) + reply, err := g.Call(GetCartIds, []byte{}) if err != nil { return nil, err } - if reply.StatusCode != 200 { - log.Printf("Remote returned error on get cart mappings: %s", string(reply.Data)) - return nil, fmt.Errorf("remote returned error: %s", string(reply.Data)) + if reply.StatusCode != 200 || reply.Type != CartIdsResponse { + log.Printf("Remote returned error on get cart mappings: %s", string(reply.Payload)) + return nil, fmt.Errorf("remote returned incorrect data") } - parts := strings.Split(string(reply.Data), ";") + parts := strings.Split(string(reply.Payload), ";") ids := make([]CartId, 0, len(parts)) for _, p := range parts { ids = append(ids, ToCartId(p)) @@ -84,14 +80,11 @@ func (g *RemoteHost) GetCartMappings() ([]CartId, error) { } func (r *RemoteHost) ConfirmChange(id CartId, host string) error { - reply, err := r.Call(RemoteGrainChanged, AckChange, []byte(fmt.Sprintf("%s;%s", id, host))) + reply, err := r.Call(RemoteGrainChanged, []byte(fmt.Sprintf("%s;%s", id, host))) - if err != nil { + if err != nil || reply.StatusCode != 200 || reply.Type != AckChange { return err } - if string(reply.Data) != "ok" { - return fmt.Errorf("remote grain change failed %s", string(reply.Data)) - } return nil } diff --git a/rpc-server.go b/rpc-server.go index 037431d..fcf4e81 100644 --- a/rpc-server.go +++ b/rpc-server.go @@ -6,7 +6,7 @@ import ( ) type GrainHandler struct { - *CartServer + *GenericListener pool *GrainLocalPool } @@ -20,13 +20,14 @@ func (h *GrainHandler) GetState(id CartId, reply *Grain) error { } func NewGrainHandler(pool *GrainLocalPool, listen string) (*GrainHandler, error) { - server, err := CartListen(listen) + conn := NewConnection(listen) + server, err := conn.Listen() handler := &GrainHandler{ - CartServer: server, - pool: pool, + GenericListener: server, + pool: pool, } - server.HandleCall(RemoteHandleMutation, handler.RemoteHandleMessageHandler) - server.HandleCall(RemoteGetState, handler.RemoteGetStateHandler) + server.AddHandler(RemoteHandleMutation, handler.RemoteHandleMessageHandler) + server.AddHandler(RemoteGetState, handler.RemoteGetStateHandler) return handler, err } @@ -34,29 +35,36 @@ func (h *GrainHandler) IsHealthy() bool { return len(h.pool.grains) < h.pool.PoolSize } -func (h *GrainHandler) RemoteHandleMessageHandler(id CartId, data []byte) (CartMessage, []byte, error) { +func (h *GrainHandler) RemoteHandleMessageHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error { + cartData, err := GetCartFrame(data.Payload) + if err != nil { + return err + } var msg Message - err := ReadMessage(bytes.NewReader(data), &msg) + err = ReadMessage(bytes.NewReader(cartData.Data), &msg) if err != nil { fmt.Println("Error reading message:", err) - return RemoteHandleMutationReply, nil, err + return err } - replyData, err := h.pool.Process(id, msg) + replyData, err := h.pool.Process(cartData.Id, msg) if err != nil { fmt.Println("Error handling message:", err) } - if err != nil { - return RemoteHandleMutationReply, nil, err - } - return RemoteHandleMutationReply, replyData, nil + resultChan <- *replyData + return nil } -func (h *GrainHandler) RemoteGetStateHandler(id CartId, data []byte) (CartMessage, []byte, error) { - reply, err := h.pool.Get(id) +func (h *GrainHandler) RemoteGetStateHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error { + cartData, err := GetCartFrame(data.Payload) if err != nil { - return RemoteGetStateReply, nil, err + return err } - return RemoteGetStateReply, reply, nil + reply, err := h.pool.Get(cartData.Id) + if err != nil { + return err + } + resultChan <- *reply + return nil } diff --git a/synced-pool.go b/synced-pool.go index 60a8d53..1fbb4cb 100644 --- a/synced-pool.go +++ b/synced-pool.go @@ -22,7 +22,7 @@ type HealthHandler interface { } type SyncedPool struct { - *Server + Server *GenericListener mu sync.RWMutex Hostname string local *GrainLocalPool @@ -61,11 +61,16 @@ var ( }) ) -func (p *SyncedPool) PongHandler(data []byte) (PoolMessage, []byte, error) { - return Pong, data, nil +var ( + PongResponse = MakeFrameWithPayload(Pong, 200, nil) +) + +func (p *SyncedPool) PongHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error { + resultChan <- PongResponse + return nil } -func (p *SyncedPool) GetCartIdHandler(data []byte) (PoolMessage, []byte, error) { +func (p *SyncedPool) GetCartIdHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error { ids := make([]string, 0, len(p.local.grains)) for id := range p.local.grains { if p.local.grains[id] == nil { @@ -78,45 +83,45 @@ func (p *SyncedPool) GetCartIdHandler(data []byte) (PoolMessage, []byte, error) ids = append(ids, s) } log.Printf("Returning %d cart ids\n", len(ids)) - return CartIdsResponse, []byte(strings.Join(ids, ";")), nil + resultChan <- MakeFrameWithPayload(CartIdsResponse, 200, []byte(strings.Join(ids, ";"))) + return nil } -func (p *SyncedPool) NegotiateHandler(data []byte) (PoolMessage, []byte, error) { +func (p *SyncedPool) NegotiateHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error { negotiationCount.Inc() log.Printf("Handling negotiation\n") - for _, host := range p.ExcludeKnown(strings.Split(string(data), ";")) { + for _, host := range p.ExcludeKnown(strings.Split(string(data.Payload), ";")) { if host == "" { continue } go p.AddRemote(host) } - - return RemoteNegotiateResponse, []byte("ok"), nil + resultChan <- MakeFrameWithPayload(RemoteNegotiateResponse, 200, []byte("ok")) + return nil } -func (p *SyncedPool) GrainOwnerChangeHandler(data []byte) (PoolMessage, []byte, error) { +func (p *SyncedPool) GrainOwnerChangeHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error { grainSyncCount.Inc() - idAndHostParts := strings.Split(string(data), ";") + idAndHostParts := strings.Split(string(data.Payload), ";") if len(idAndHostParts) != 2 { log.Printf("Invalid remote grain change message\n") - return AckChange, []byte("incorrect"), fmt.Errorf("invalid remote grain change message") + resultChan <- MakeFrameWithPayload(AckError, 400, []byte("invalid")) + return nil } id := ToCartId(idAndHostParts[0]) host := idAndHostParts[1] log.Printf("Handling remote grain owner change to %s for id %s\n", host, id) for _, r := range p.remotes { if r.Host == host && r.IsHealthy() { - // log.Printf("Remote grain %s changed to %s\n", id, host) - go p.SpawnRemoteGrain(id, host) - - return AckChange, []byte("ok"), nil + break } } go p.AddRemote(host) - return AckChange, []byte("ok"), nil + resultChan <- MakeFrameWithPayload(AckChange, 200, []byte("ok")) + return nil } func (p *SyncedPool) RemoveRemoteGrain(id CartId) { @@ -142,12 +147,12 @@ func (p *SyncedPool) SpawnRemoteGrain(id CartId, host string) { log.Printf("Error creating remote grain %v\n", err) return } - go func() { - <-remote.PersistentConnection.Died - p.RemoveRemoteGrain(id) - p.HandleHostError(host) - log.Printf("Remote grain %s died, host: %s\n", id.String(), host) - }() + // go func() { + // <-remote.Died + // p.RemoveRemoteGrain(id) + // p.HandleHostError(host) + // log.Printf("Remote grain %s died, host: %s\n", id.String(), host) + // }() p.mu.Lock() p.remoteIndex[id] = remote @@ -159,8 +164,6 @@ func (p *SyncedPool) HandleHostError(host string) { if r.Host == host { if !r.IsHealthy() { p.RemoveHost(r) - } else { - r.ErrorCount++ } return } @@ -169,8 +172,8 @@ func (p *SyncedPool) HandleHostError(host string) { func NewSyncedPool(local *GrainLocalPool, hostname string, discovery Discovery) (*SyncedPool, error) { listen := fmt.Sprintf("%s:1338", hostname) - - server, err := Listen(listen) + conn := NewConnection(listen) + server, err := conn.Listen() if err != nil { return nil, err } @@ -186,10 +189,10 @@ func NewSyncedPool(local *GrainLocalPool, hostname string, discovery Discovery) remoteIndex: make(map[CartId]*RemoteGrain), } - server.HandleCall(Ping, pool.PongHandler) - server.HandleCall(GetCartIds, pool.GetCartIdHandler) - server.HandleCall(RemoteNegotiate, pool.NegotiateHandler) - server.HandleCall(RemoteGrainChanged, pool.GrainOwnerChangeHandler) + server.AddHandler(Ping, pool.PongHandler) + server.AddHandler(GetCartIds, pool.GetCartIdHandler) + server.AddHandler(RemoteNegotiate, pool.NegotiateHandler) + server.AddHandler(RemoteGrainChanged, pool.GrainOwnerChangeHandler) if discovery != nil { go func() { @@ -259,18 +262,11 @@ func (p *SyncedPool) ExcludeKnown(hosts []string) []string { } func (p *SyncedPool) RemoveHost(host *RemoteHost) { - if p.remotes[host.Host] == nil { - return - } + p.mu.Lock() - defer p.mu.Unlock() - - h := p.remotes[host.Host] - if h != nil { - h.Close() - } delete(p.remotes, host.Host) - + p.mu.Unlock() + p.RemoveHostMappedCarts(host) connectedRemotes.Set(float64(len(p.remotes))) } @@ -279,24 +275,21 @@ func (p *SyncedPool) RemoveHostMappedCarts(host *RemoteHost) { defer p.mu.Unlock() for id, r := range p.remoteIndex { if r.Host == host.Host { - p.remoteIndex[id].Close() delete(p.remoteIndex, id) } } } -type PoolMessage uint32 - const ( - RemoteNegotiate = PoolMessage(3) - RemoteGrainChanged = PoolMessage(4) - AckChange = PoolMessage(5) - //AckError = PoolMessage(6) - Ping = PoolMessage(7) - Pong = PoolMessage(8) - GetCartIds = PoolMessage(9) - CartIdsResponse = PoolMessage(10) - RemoteNegotiateResponse = PoolMessage(11) + RemoteNegotiate = FrameType(3) + RemoteGrainChanged = FrameType(4) + AckChange = FrameType(5) + AckError = FrameType(6) + Ping = FrameType(7) + Pong = FrameType(8) + GetCartIds = FrameType(9) + CartIdsResponse = FrameType(10) + RemoteNegotiateResponse = FrameType(11) ) func (p *SyncedPool) Negotiate() { @@ -377,25 +370,22 @@ func (p *SyncedPool) AddRemote(host string) error { if host == "" || p.IsKnown(host) || hasHost { return nil } - client, err := Dial(fmt.Sprintf("%s:1338", host)) - if err != nil { + client := NewConnection(fmt.Sprintf("%s:1338", host)) + response, err := client.Call(Ping, nil) + if err != nil || response.StatusCode != 200 || response.Type != Pong { log.Printf("Error connecting to remote %s: %v\n", host, err) return err } remote := RemoteHost{ - Client: client, + Connection: client, MissedPings: 0, Host: host, } p.mu.Lock() p.remotes[host] = &remote p.mu.Unlock() - go func() { - <-remote.PersistentConnection.Died - log.Printf("Removing host, remote died %s", host) - p.RemoveHost(&remote) - }() + go func() { for range time.Tick(time.Second * 3) { @@ -450,9 +440,9 @@ func (p *SyncedPool) getGrain(id CartId) (Grain, error) { return localGrain, nil } -func (p *SyncedPool) Process(id CartId, messages ...Message) (*CallResult, error) { +func (p *SyncedPool) Process(id CartId, messages ...Message) (*FrameWithPayload, error) { pool, err := p.getGrain(id) - var res *CallResult + var res *FrameWithPayload if err != nil { return nil, err } @@ -465,7 +455,7 @@ func (p *SyncedPool) Process(id CartId, messages ...Message) (*CallResult, error return res, nil } -func (p *SyncedPool) Get(id CartId) (*CallResult, error) { +func (p *SyncedPool) Get(id CartId) (*FrameWithPayload, error) { grain, err := p.getGrain(id) if err != nil { return nil, err diff --git a/tcp-cart-client.go b/tcp-cart-client.go deleted file mode 100644 index 694cdae..0000000 --- a/tcp-cart-client.go +++ /dev/null @@ -1,96 +0,0 @@ -package main - -import ( - "encoding/binary" - "log" - "sync" -) - -type CartClient struct { - *CartTCPClient -} - -func CartDial(address string) (*CartClient, error) { - - mux, err := NewCartTCPClient(address) - if err != nil { - return nil, err - } - client := &CartClient{ - CartTCPClient: mux, - } - return client, nil -} - -func (c *Client) Close() { - log.Printf("Closing connection to %s\n", c.PersistentConnection.address) - c.PersistentConnection.Close() -} - -type CartTCPClient struct { - PersistentConnection *PersistentConnection - sendMux sync.Mutex - ErrorCount int - address string - *CartPacketQueue -} - -func NewCartTCPClient(address string) (*CartTCPClient, error) { - connection, err := NewPersistentConnection(address) - if err != nil { - return nil, err - } - return &CartTCPClient{ - ErrorCount: 0, - PersistentConnection: connection, - address: address, - CartPacketQueue: NewCartPacketQueue(connection), - }, nil -} - -func (m *CartTCPClient) SendPacket(messageType CartMessage, id CartId, data []byte) error { - m.sendMux.Lock() - defer m.sendMux.Unlock() - m.PersistentConnection.Conn.Write(header[:]) - err := binary.Write(m.PersistentConnection, binary.LittleEndian, CartPacket{ - Version: CurrentPacketVersion, - MessageType: messageType, - DataLength: uint32(len(data)), - Id: id, - }) - if err != nil { - return m.PersistentConnection.HandleConnectionError(err) - } - _, err = m.PersistentConnection.Write(data) - return m.PersistentConnection.HandleConnectionError(err) -} - -func (m *CartTCPClient) call(messageType CartMessage, id CartId, responseType CartMessage, data []byte) (*CallResult, error) { - packetChan := m.Expect(responseType, id) - err := m.SendPacket(messageType, id, data) - if err != nil { - return nil, m.PersistentConnection.HandleConnectionError(err) - } - - ret := <-packetChan - return &ret, nil -} - -func isRetirableError(err error) bool { - log.Printf("is retryable error: %v", err) - return false -} - -func (m *CartTCPClient) Call(messageType CartMessage, id CartId, responseType CartMessage, data []byte) (*CallResult, error) { - retries := 0 - result, err := m.call(messageType, id, responseType, data) - for err != nil && retries < 3 { - if !isRetirableError(err) { - break - } - retries++ - log.Printf("Retrying call to %d\n", messageType) - result, err = m.call(messageType, id, responseType, data) - } - return result, err -} diff --git a/tcp-cart-mux-server.go b/tcp-cart-mux-server.go deleted file mode 100644 index 0ca4100..0000000 --- a/tcp-cart-mux-server.go +++ /dev/null @@ -1,160 +0,0 @@ -package main - -import ( - "bufio" - "encoding/binary" - "io" - "log" - "net" - "sync" -) - -type CartServer struct { - *TCPCartServerMux -} - -func CartListen(address string) (*CartServer, error) { - listener, err := net.Listen("tcp", address) - server := &CartServer{ - NewCartTCPServerMux(), - } - - if err != nil { - return nil, err - } - go func() { - for { - conn, err := listener.Accept() - if err != nil { - log.Printf("Error accepting connection: %v\n", err) - continue - } - go server.HandleConnection(conn) - } - }() - return server, nil -} - -type TCPCartServerMux struct { - mu sync.RWMutex - sendMux sync.Mutex - listeners map[CartMessage]func(CartId, []byte) error - functions map[CartMessage]func(CartId, []byte) (CartMessage, []byte, error) -} - -func NewCartTCPServerMux() *TCPCartServerMux { - m := &TCPCartServerMux{ - mu: sync.RWMutex{}, - listeners: make(map[CartMessage]func(CartId, []byte) error), - functions: make(map[CartMessage]func(CartId, []byte) (CartMessage, []byte, error)), - } - - return m -} - -func (m *TCPCartServerMux) handleListener(messageType CartMessage, id CartId, data []byte) (bool, error) { - m.mu.RLock() - handler, ok := m.listeners[messageType] - m.mu.RUnlock() - if ok { - err := handler(id, data) - if err != nil { - return true, err - } - } - return false, nil -} - -func (m *TCPCartServerMux) handleFunction(connection net.Conn, messageType CartMessage, id CartId, data []byte) (bool, error) { - m.mu.RLock() - fn, ok := m.functions[messageType] - m.mu.RUnlock() - m.sendMux.Lock() - defer m.sendMux.Unlock() - if ok { - responseType, responseData, err := fn(id, data) - connection.Write(header[:]) - if err != nil { - errData := []byte(err.Error()) - err = binary.Write(connection, binary.LittleEndian, CartPacket{ - Version: CurrentPacketVersion, - MessageType: responseType, - DataLength: uint32(len(errData)), - StatusCode: 500, - Id: id, - }) - _, err = connection.Write(errData) - return true, err - } - err = binary.Write(connection, binary.LittleEndian, CartPacket{ - Version: CurrentPacketVersion, - MessageType: responseType, - DataLength: uint32(len(responseData)), - StatusCode: 200, - Id: id, - }) - if err != nil { - return true, err - } - packetsSent.Inc() - _, err = connection.Write(responseData) - return true, err - } else { - log.Printf("No cart handler for type: %d\n", messageType) - } - return false, nil -} - -func (m *TCPCartServerMux) HandleConnection(connection net.Conn) error { - var packet CartPacket - var err error - defer connection.Close() - reader := bufio.NewReader(connection) - for { - err = ReadCartPacket(reader, &packet) - if err != nil { - if err == io.EOF { - return nil - } - log.Printf("Error receiving packet: %v\n", err) - return err - } - if packet.Version != CurrentPacketVersion { - log.Printf("Incorrect packet version: %d\n", packet.Version) - continue - } - data, err := GetPacketData(reader, packet.DataLength) - if err != nil { - log.Printf("Error getting packet data: %v\n", err) - } - go m.HandleData(connection, packet.MessageType, packet.Id, data) - } -} - -func (m *TCPCartServerMux) HandleData(connection net.Conn, t CartMessage, id CartId, data []byte) { - status, err := m.handleListener(t, id, data) - if err != nil { - log.Printf("Error handling listener: %v\n", err) - } - if !status { - status, err = m.handleFunction(connection, t, id, data) - if err != nil { - log.Printf("Error handling function: %v\n", err) - } - if !status { - log.Printf("Unknown message type: %d\n", t) - } - } -} - -func (m *TCPCartServerMux) ListenFor(messageType CartMessage, handler func(CartId, []byte) error) { - m.mu.Lock() - m.listeners[messageType] = handler - m.mu.Unlock() -} - -func (m *TCPCartServerMux) HandleCall(messageType CartMessage, handler func(CartId, []byte) (CartMessage, []byte, error)) { - m.mu.Lock() - m.functions[messageType] = handler - m.mu.Unlock() -} diff --git a/tcp-cart_test.go b/tcp-cart_test.go deleted file mode 100644 index 0f11cf8..0000000 --- a/tcp-cart_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package main - -import ( - "fmt" - "log" - "testing" -) - -func TestCartTcpHelpers(t *testing.T) { - - server, err := CartListen("localhost:51337") - if err != nil { - t.Errorf("Error listening: %v\n", err) - } - client, err := CartDial("localhost:51337") - if err != nil { - t.Errorf("Error dialing: %v\n", err) - } - var messageData string - server.ListenFor(1, func(id CartId, data []byte) error { - log.Printf("Received message: %s\n", string(data)) - messageData = string(data) - return nil - }) - server.HandleCall(666, func(id CartId, data []byte) (CartMessage, []byte, error) { - log.Printf("Received 666 call: %s\n", string(data)) - return 3, []byte("Hello, client!"), fmt.Errorf("Det blev fel") - }) - server.HandleCall(2, func(id CartId, data []byte) (CartMessage, []byte, error) { - log.Printf("Received 2 call: %s\n", string(data)) - return 4, []byte("Hello, client!"), nil - }) - // server.HandleCall(Ping, func(id CartId, data []byte) (CartMessage, []byte, error) { - // return Pong, nil, nil - // }) - id := ToCartId("kalle") - client.SendPacket(1, id, []byte("Hello, world!")) - answer, err := client.Call(2, id, 4, []byte("Hello, server!")) - if err != nil { - t.Errorf("Error calling: %v\n", err) - } - s, err := client.Call(666, id, 3, []byte("Hello, server!")) - client.PersistentConnection.Close() - if err != nil { - t.Errorf("Error calling: %v\n", err) - } - if s.StatusCode != 500 { - t.Errorf("Expected 500, got %d\n", s.StatusCode) - } - - if string(answer.Data) != "Hello, client!" { - t.Errorf("Expected answer 'Hello, client!', got %s\n", string(answer.Data)) - } - if messageData != "Hello, world!" { - t.Errorf("Expected message 'Hello, world!', got %s\n", messageData) - } -} diff --git a/tcp-client.go b/tcp-client.go deleted file mode 100644 index 498085b..0000000 --- a/tcp-client.go +++ /dev/null @@ -1,144 +0,0 @@ -package main - -import ( - "encoding/binary" - "log" - "net" - "sync" - "time" -) - -type Client struct { - *TCPClient -} - -func Dial(address string) (*Client, error) { - - mux, err := NewTCPClient(address) - if err != nil { - return nil, err - } - client := &Client{ - TCPClient: mux, - } - return client, nil -} - -type TCPClient struct { - PersistentConnection *PersistentConnection - sendMux sync.Mutex - ErrorCount int - address string - *PacketQueue -} - -type PersistentConnection struct { - net.Conn - Died chan bool - Dead bool - address string -} - -func NewPersistentConnection(address string) (*PersistentConnection, error) { - - p := &PersistentConnection{ - Died: make(chan bool, 1), - Dead: false, - address: address, - } - err := p.Connect() - if err != nil { - return nil, err - } - return p, nil -} - -func (m *PersistentConnection) Connect() error { - fails := 0 - for { - connection, err := net.Dial("tcp", m.address) - if err != nil { - log.Printf("Can't connect to %s: %v, count: %d", m.address, err, fails) - fails++ - if fails > 15 { - log.Printf("Too many connection failures, closing connection to %s", m.address) - m.Died <- true - m.Dead = true - return err - } - } else { - m.Conn = connection - break - } - time.Sleep(time.Millisecond * 300) - - } - - return nil -} - -func (m *PersistentConnection) Close() { - log.Printf("Closing connection to %s\n", m.address) - m.Conn.Close() - m.Died <- true - m.Dead = true -} - -func (m *PersistentConnection) HandleConnectionError(err error) error { - if err != nil { - log.Printf("Error from to %s: %v\n", m.address, err) - m.Conn.Close() - m.Connect() - } - return err -} - -func NewTCPClient(address string) (*TCPClient, error) { - - connection, err := NewPersistentConnection(address) - if err != nil { - return nil, err - } - return &TCPClient{ - ErrorCount: 0, - PersistentConnection: connection, - address: address, - PacketQueue: NewPacketQueue(connection), - }, nil -} - -type PacketHeader [4]byte - -var ( - header = PacketHeader([4]byte{0x01, 0x02, 0x03, 0x04}) -) - -func (m *TCPClient) SendPacket(messageType PoolMessage, data []byte) error { - m.sendMux.Lock() - defer m.sendMux.Unlock() - m.PersistentConnection.Write(header[:]) - err := binary.Write(m.PersistentConnection, binary.LittleEndian, Packet{ - Version: CurrentPacketVersion, - MessageType: messageType, - StatusCode: 0, - DataLength: uint32(len(data)), - }) - if err != nil { - return m.PersistentConnection.HandleConnectionError(err) - } - _, err = m.PersistentConnection.Write(data) - return m.PersistentConnection.HandleConnectionError(err) -} - -func (m *TCPClient) Call(messageType PoolMessage, responseType PoolMessage, data []byte) (*CallResult, error) { - packetChan := m.Expect(responseType) - err := m.SendPacket(messageType, data) - if err != nil { - m.RemoveListeners() - return nil, m.PersistentConnection.HandleConnectionError(err) - } - - ret := <-packetChan - return &ret, nil - -} diff --git a/tcp-connection.go b/tcp-connection.go index 618c43a..2a2544a 100644 --- a/tcp-connection.go +++ b/tcp-connection.go @@ -15,12 +15,23 @@ type Connection struct { } type FrameType uint32 +type StatusCode uint32 +type CheckSum uint32 type Frame struct { - Id uint64 Type FrameType - StatusCode uint32 + StatusCode StatusCode Length uint32 + Checksum CheckSum +} + +func (f *Frame) IsValid() bool { + return f.Checksum == MakeChecksum(f.Type, f.StatusCode, f.Length) +} + +func MakeChecksum(msg FrameType, statusCode StatusCode, length uint32) CheckSum { + sum := CheckSum((uint32(msg) + uint32(statusCode) + length) / 8) + return sum } type FrameWithPayload struct { @@ -28,6 +39,19 @@ type FrameWithPayload struct { Payload []byte } +func MakeFrameWithPayload(msg FrameType, statusCode StatusCode, payload []byte) FrameWithPayload { + len := uint32(len(payload)) + return FrameWithPayload{ + Frame: Frame{ + Type: msg, + StatusCode: 0, + Length: len, + Checksum: MakeChecksum(msg, 0, len), + }, + Payload: payload, + } +} + type FrameData interface { ToBytes() []byte FromBytes([]byte) error @@ -41,11 +65,7 @@ func NewConnection(address string) *Connection { } func SendFrame(conn net.Conn, data *FrameWithPayload) error { - _, err := conn.Write(header[:]) - if err != nil { - return err - } - err = binary.Write(conn, binary.LittleEndian, data.Frame) + err := binary.Write(conn, binary.LittleEndian, data.Frame) if err != nil { return err } @@ -53,68 +73,67 @@ func SendFrame(conn net.Conn, data *FrameWithPayload) error { return err } -func (c *Connection) CallAsync(msg FrameType, data FrameData, ch chan<- *FrameWithPayload) error { +func (c *Connection) CallAsync(msg FrameType, payload []byte, ch chan<- FrameWithPayload) (net.Conn, error) { conn, err := net.Dial("tcp", c.address) go WaitForFrame(conn, ch) if err != nil { - return err - } - payload := data.ToBytes() - toSend := &FrameWithPayload{ - Frame: Frame{ - Id: c.count, - Type: msg, - StatusCode: 0, - Length: uint32(len(payload)), - }, - Payload: payload, + return conn, err } + toSend := MakeFrameWithPayload(msg, 1, payload) - err = SendFrame(conn, toSend) + err = SendFrame(conn, &toSend) if err != nil { + conn.Close() close(ch) - return err + return nil, err } c.count++ - return nil + return conn, nil } -func (c *Connection) Call(msg FrameType, data FrameData) (*FrameWithPayload, error) { - ch := make(chan *FrameWithPayload, 1) - c.CallAsync(msg, data, ch) +func (c *Connection) Call(msg FrameType, data []byte) (*FrameWithPayload, error) { + ch := make(chan FrameWithPayload, 1) + conn, err := c.CallAsync(msg, data, ch) + if err != nil { + return nil, err + } + defer conn.Close() select { case ret := <-ch: - return ret, nil - case <-time.After(5 * time.Second): + return &ret, nil + case <-time.After(MaxCallDuration): return nil, fmt.Errorf("timeout") } } -func WaitForFrame(conn net.Conn, resultChan chan<- *FrameWithPayload) error { - defer conn.Close() +func WaitForFrame(conn net.Conn, resultChan chan<- FrameWithPayload) error { var err error + var frame Frame r := bufio.NewReader(conn) - h := make([]byte, 4) - r.Read(h) - if h[0] == header[0] && h[1] == header[1] && h[2] == header[2] && h[3] == header[3] { - frame := Frame{} - err = binary.Read(r, binary.LittleEndian, &frame) + + err = binary.Read(r, binary.LittleEndian, &frame) + if err != nil { + return err + } + if frame.IsValid() { payload := make([]byte, frame.Length) _, err = r.Read(payload) - resultChan <- &FrameWithPayload{ + if err != nil { + return err + } + resultChan <- FrameWithPayload{ Frame: frame, Payload: payload, } return err } - resultChan <- nil - return err + return fmt.Errorf("checksum mismatch") } type GenericListener struct { Closed bool - handlers map[FrameType]func(*FrameWithPayload, chan<- *FrameWithPayload) error + handlers map[FrameType]func(*FrameWithPayload, chan<- FrameWithPayload) error } func (c *Connection) Listen() (*GenericListener, error) { @@ -123,7 +142,7 @@ func (c *Connection) Listen() (*GenericListener, error) { return nil, err } ret := &GenericListener{ - handlers: make(map[FrameType]func(*FrameWithPayload, chan<- *FrameWithPayload) error), + handlers: make(map[FrameType]func(*FrameWithPayload, chan<- FrameWithPayload) error), } go func() { for !ret.Closed { @@ -137,36 +156,44 @@ func (c *Connection) Listen() (*GenericListener, error) { return ret, nil } +const ( + MaxCallDuration = 500 * time.Millisecond +) + func (l *GenericListener) HandleConnection(conn net.Conn) { - ch := make(chan *FrameWithPayload, 1) + ch := make(chan FrameWithPayload, 1) go WaitForFrame(conn, ch) select { case frame := <-ch: - go l.HandleFrame(conn, frame) - case <-time.After(1 * time.Second): + go l.HandleFrame(conn, &frame) + case <-time.After(MaxCallDuration): close(ch) log.Printf("Timeout waiting for frame\n") } } -func (l *GenericListener) AddHandler(msg FrameType, handler func(*FrameWithPayload, chan<- *FrameWithPayload) error) { +func (l *GenericListener) AddHandler(msg FrameType, handler func(*FrameWithPayload, chan<- FrameWithPayload) error) { l.handlers[msg] = handler } func (l *GenericListener) HandleFrame(conn net.Conn, frame *FrameWithPayload) { handler, ok := l.handlers[frame.Type] - defer conn.Close() if ok { go func() { - resultChan := make(chan *FrameWithPayload, 1) + resultChan := make(chan FrameWithPayload, 1) defer close(resultChan) err := handler(frame, resultChan) if err != nil { log.Fatalf("Error handling frame: %v\n", err) } - SendFrame(conn, <-resultChan) + result := <-resultChan + err = SendFrame(conn, &result) + if err != nil { + log.Fatalf("Error sending frame: %v\n", err) + } }() } else { + conn.Close() log.Fatalf("No handler for frame type %d\n", frame.Type) } } diff --git a/tcp-connection_test.go b/tcp-connection_test.go index d208f37..e7714eb 100644 --- a/tcp-connection_test.go +++ b/tcp-connection_test.go @@ -2,37 +2,19 @@ package main import "testing" -type StringData string - -func (s StringData) ToBytes() []byte { - return []byte(s) -} - -func (s StringData) FromBytes(data []byte) error { - s = StringData(data) - return nil -} - func TestGenericConnection(t *testing.T) { conn := NewConnection("localhost:51337") listener, err := conn.Listen() if err != nil { t.Errorf("Error listening: %v\n", err) } - listener.AddHandler(1, func(input *FrameWithPayload, resultChan chan<- *FrameWithPayload) error { - payload := []byte("Hello, world!") - resultChan <- &FrameWithPayload{ - Frame: Frame{ - Type: 2, - Id: input.Id, - StatusCode: 200, - Length: uint32(len("Hello, world!")), - }, - Payload: payload, - } + datta := []byte("Hello, world!") + listener.AddHandler(1, func(input *FrameWithPayload, resultChan chan<- FrameWithPayload) error { + + resultChan <- MakeFrameWithPayload(2, 200, datta) return nil }) - r, err := conn.Call(1, StringData("Hello, world!")) + r, err := conn.Call(1, datta) if err != nil { t.Errorf("Error calling: %v\n", err) } @@ -40,9 +22,9 @@ func TestGenericConnection(t *testing.T) { t.Errorf("Expected type 2, got %d\n", r.Type) } i := 100 - results := make(chan *FrameWithPayload, i) + results := make(chan FrameWithPayload, i) for i > 0 { - conn.CallAsync(1, StringData("Hello, world!"), results) + go conn.CallAsync(1, datta, results) i-- } for i < 100 { diff --git a/tcp-mux-server.go b/tcp-mux-server.go deleted file mode 100644 index a5c9c28..0000000 --- a/tcp-mux-server.go +++ /dev/null @@ -1,161 +0,0 @@ -package main - -import ( - "bufio" - "encoding/binary" - "io" - "log" - "net" - "sync" -) - -type Server struct { - *TCPServerMux -} - -func Listen(address string) (*Server, error) { - listener, err := net.Listen("tcp", address) - server := &Server{ - NewTCPServerMux(), - } - - if err != nil { - return nil, err - } - go func() { - for { - conn, err := listener.Accept() - if err != nil { - log.Printf("Error accepting connection: %v\n", err) - continue - } - go server.HandleConnection(conn) - } - }() - return server, nil -} - -type TCPServerMux struct { - mu sync.RWMutex - sendMux sync.Mutex - listeners map[PoolMessage]func(data []byte) error - functions map[PoolMessage]func(data []byte) (PoolMessage, []byte, error) -} - -func NewTCPServerMux() *TCPServerMux { - m := &TCPServerMux{ - mu: sync.RWMutex{}, - listeners: make(map[PoolMessage]func(data []byte) error), - functions: make(map[PoolMessage]func(data []byte) (PoolMessage, []byte, error)), - } - - return m -} - -func (m *TCPServerMux) handleListener(messageType PoolMessage, data []byte) (bool, error) { - m.mu.RLock() - handler, ok := m.listeners[messageType] - m.mu.RUnlock() - if ok { - err := handler(data) - if err != nil { - return true, err - } - } - return false, nil -} - -func (m *TCPServerMux) handleFunction(connection net.Conn, messageType PoolMessage, data []byte) (bool, error) { - m.mu.RLock() - function, ok := m.functions[messageType] - m.mu.RUnlock() - m.sendMux.Lock() - defer m.sendMux.Unlock() - if ok { - connection.Write(header[:]) - responseType, responseData, err := function(data) - if err != nil { - errData := []byte(err.Error()) - err = binary.Write(connection, binary.LittleEndian, Packet{ - Version: CurrentPacketVersion, - MessageType: responseType, - StatusCode: 500, - DataLength: uint32(len(errData)), - }) - _, err = connection.Write(errData) - return true, err - } - err = binary.Write(connection, binary.LittleEndian, Packet{ - Version: CurrentPacketVersion, - MessageType: responseType, - StatusCode: 200, - DataLength: uint32(len(responseData)), - }) - if err != nil { - return true, err - } - packetsSent.Inc() - _, err = connection.Write(responseData) - return true, err - } else { - log.Printf("No pool handler for type: %d\n", messageType) - } - return false, nil -} - -func (m *TCPServerMux) HandleConnection(connection net.Conn) error { - - defer connection.Close() - var packet Packet - reader := bufio.NewReader(connection) - for { - err := ReadPacket(reader, &packet) - if err != nil { - if err == io.EOF { - return nil - } - log.Printf("Error receiving packet: %v\n", err) - return err - } - if packet.Version != CurrentPacketVersion { - log.Printf("Incorrect package version: %v\n", err) - continue - } - data, err := GetPacketData(reader, packet.DataLength) - if err != nil { - log.Printf("Error receiving packet data: %v\n", err) - return err - } - go m.HandleData(connection, packet.MessageType, data) - } -} - -func (m *TCPServerMux) HandleData(connection net.Conn, t PoolMessage, data []byte) { - // listener := m.listeners[t] - // handler := m.functions[t] - status, err := m.handleListener(t, data) - if err != nil { - log.Printf("Error handling listener: %v\n", err) - } - if !status { - status, err = m.handleFunction(connection, t, data) - if err != nil { - log.Printf("Error handling function: %v\n", err) - } - if !status { - log.Printf("Unknown message type: %d\n", t) - } - } -} - -func (m *TCPServerMux) ListenFor(messageType PoolMessage, handler func(data []byte) error) { - m.mu.Lock() - m.listeners[messageType] = handler - m.mu.Unlock() -} - -func (m *TCPServerMux) HandleCall(messageType PoolMessage, handler func(data []byte) (PoolMessage, []byte, error)) { - m.mu.Lock() - m.functions[messageType] = handler - m.mu.Unlock() -} diff --git a/tcp_test.go b/tcp_test.go deleted file mode 100644 index 849c507..0000000 --- a/tcp_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package main - -import ( - "log" - "testing" -) - -func TestTcpHelpers(t *testing.T) { - - server, err := Listen("localhost:51337") - if err != nil { - t.Errorf("Error listening: %v\n", err) - } - client, err := Dial("localhost:51337") - if err != nil { - t.Errorf("Error dialing: %v\n", err) - } - var messageData string - server.ListenFor(1, func(data []byte) error { - log.Printf("Received message: %s\n", string(data)) - messageData = string(data) - return nil - }) - server.HandleCall(2, func(data []byte) (PoolMessage, []byte, error) { - log.Printf("Received call: %s\n", string(data)) - return 3, []byte("Hello, client!"), nil - }) - server.HandleCall(Ping, func(data []byte) (PoolMessage, []byte, error) { - return Pong, nil, nil - }) - - client.SendPacket(1, []byte("Hello, world!")) - answer, err := client.Call(2, 3, []byte("Hello, server!")) - if err != nil { - t.Errorf("Error calling: %v\n", err) - } - for i := 0; i < 100; i++ { - _, err = client.Call(Ping, Pong, nil) - if err != nil { - t.Errorf("Error calling: %v\n", err) - } - } - _, err = client.Call(Ping, Pong, nil) - if err != nil { - t.Errorf("Error calling: %v\n", err) - } - client.Close() - if string(answer.Data) != "Hello, client!" { - t.Errorf("Expected answer 'Hello, client!', got %s\n", string(answer.Data)) - } - if messageData != "Hello, world!" { - t.Errorf("Expected message 'Hello, world!', got %s\n", messageData) - } -}