diff --git a/cart-packet-queue.go b/cart-packet-queue.go index de6ded8..092f4bb 100644 --- a/cart-packet-queue.go +++ b/cart-packet-queue.go @@ -10,7 +10,7 @@ import ( type CartPacketQueue struct { mu sync.RWMutex - expectedPackages map[uint32]*CartListener + expectedPackages map[CartMessage]*CartListener } const CurrentPacketVersion = 2 @@ -20,7 +20,7 @@ type CartListener map[CartId]Listener func NewCartPacketQueue(connection *PersistentConnection) *CartPacketQueue { queue := &CartPacketQueue{ - expectedPackages: make(map[uint32]*CartListener), + expectedPackages: make(map[CartMessage]*CartListener), } go queue.HandleConnection(connection) return queue @@ -34,7 +34,7 @@ func (p *CartPacketQueue) RemoveListeners() { close(l.Chan) } } - p.expectedPackages = make(map[uint32]*CartListener) + p.expectedPackages = make(map[CartMessage]*CartListener) } func (p *CartPacketQueue) HandleConnection(connection *PersistentConnection) error { @@ -73,7 +73,7 @@ func (p *CartPacketQueue) HandleConnection(connection *PersistentConnection) err } } -func (p *CartPacketQueue) HandleData(t uint32, id CartId, data CallResult) { +func (p *CartPacketQueue) HandleData(t CartMessage, id CartId, data CallResult) { p.mu.Lock() defer p.mu.Unlock() pl, ok := p.expectedPackages[t] @@ -90,7 +90,7 @@ func (p *CartPacketQueue) HandleData(t uint32, id CartId, data CallResult) { } } -func (p *CartPacketQueue) Expect(messageType uint32, id CartId) <-chan CallResult { +func (p *CartPacketQueue) Expect(messageType CartMessage, id CartId) <-chan CallResult { p.mu.Lock() defer p.mu.Unlock() l, ok := p.expectedPackages[messageType] diff --git a/packet-queue.go b/packet-queue.go index 18a14be..c16a80f 100644 --- a/packet-queue.go +++ b/packet-queue.go @@ -9,7 +9,7 @@ import ( type PacketQueue struct { mu sync.RWMutex - expectedPackages map[uint32]*Listener + expectedPackages map[PoolMessage]*Listener } type CallResult struct { @@ -24,7 +24,7 @@ type Listener struct { func NewPacketQueue(connection *PersistentConnection) *PacketQueue { queue := &PacketQueue{ - expectedPackages: make(map[uint32]*Listener), + expectedPackages: make(map[PoolMessage]*Listener), } go queue.HandleConnection(connection) return queue @@ -36,7 +36,7 @@ func (p *PacketQueue) RemoveListeners() { for _, l := range p.expectedPackages { close(l.Chan) } - p.expectedPackages = make(map[uint32]*Listener) + p.expectedPackages = make(map[PoolMessage]*Listener) } func (p *PacketQueue) HandleConnection(connection *PersistentConnection) error { @@ -74,7 +74,7 @@ func (p *PacketQueue) HandleConnection(connection *PersistentConnection) error { } } -func (p *PacketQueue) HandleData(t uint32, data CallResult) { +func (p *PacketQueue) HandleData(t PoolMessage, data CallResult) { p.mu.Lock() defer p.mu.Unlock() l, ok := p.expectedPackages[t] @@ -89,7 +89,7 @@ func (p *PacketQueue) HandleData(t uint32, data CallResult) { } } -func (p *PacketQueue) Expect(messageType uint32) <-chan CallResult { +func (p *PacketQueue) Expect(messageType PoolMessage) <-chan CallResult { p.mu.Lock() defer p.mu.Unlock() l, ok := p.expectedPackages[messageType] diff --git a/packet.go b/packet.go index 8438f60..e2273b0 100644 --- a/packet.go +++ b/packet.go @@ -5,25 +5,28 @@ import ( "io" ) +type CartMessage uint32 +type PackageVersion uint32 + const ( - RemoteGetState = uint32(0x01) - RemoteHandleMutation = uint32(0x02) - ResponseBody = uint32(0x03) - RemoteGetStateReply = uint32(0x04) - RemoteHandleMutationReply = uint32(0x05) + RemoteGetState = CartMessage(0x01) + RemoteHandleMutation = CartMessage(0x02) + ResponseBody = CartMessage(0x03) + RemoteGetStateReply = CartMessage(0x04) + RemoteHandleMutationReply = CartMessage(0x05) ) type CartPacket struct { - Version uint32 - MessageType uint32 + Version PackageVersion + MessageType CartMessage DataLength uint32 StatusCode uint32 Id CartId } type Packet struct { - Version uint32 - MessageType uint32 + Version PackageVersion + MessageType PoolMessage DataLength uint32 StatusCode uint32 } diff --git a/rpc-server.go b/rpc-server.go index 2de1be3..037431d 100644 --- a/rpc-server.go +++ b/rpc-server.go @@ -34,7 +34,7 @@ func (h *GrainHandler) IsHealthy() bool { return len(h.pool.grains) < h.pool.PoolSize } -func (h *GrainHandler) RemoteHandleMessageHandler(id CartId, data []byte) (uint32, []byte, error) { +func (h *GrainHandler) RemoteHandleMessageHandler(id CartId, data []byte) (CartMessage, []byte, error) { var msg Message err := ReadMessage(bytes.NewReader(data), &msg) if err != nil { @@ -53,7 +53,7 @@ func (h *GrainHandler) RemoteHandleMessageHandler(id CartId, data []byte) (uint3 return RemoteHandleMutationReply, replyData, nil } -func (h *GrainHandler) RemoteGetStateHandler(id CartId, data []byte) (uint32, []byte, error) { +func (h *GrainHandler) RemoteGetStateHandler(id CartId, data []byte) (CartMessage, []byte, error) { reply, err := h.pool.Get(id) if err != nil { return RemoteGetStateReply, nil, err diff --git a/synced-pool.go b/synced-pool.go index 30c055d..385cbf0 100644 --- a/synced-pool.go +++ b/synced-pool.go @@ -61,11 +61,11 @@ var ( }) ) -func (p *SyncedPool) PongHandler(data []byte) (uint32, []byte, error) { +func (p *SyncedPool) PongHandler(data []byte) (PoolMessage, []byte, error) { return Pong, data, nil } -func (p *SyncedPool) GetCartIdHandler(data []byte) (uint32, []byte, error) { +func (p *SyncedPool) GetCartIdHandler(data []byte) (PoolMessage, []byte, error) { ids := make([]string, 0, len(p.local.grains)) for id := range p.local.grains { if p.local.grains[id] == nil { @@ -81,7 +81,7 @@ func (p *SyncedPool) GetCartIdHandler(data []byte) (uint32, []byte, error) { return CartIdsResponse, []byte(strings.Join(ids, ";")), nil } -func (p *SyncedPool) NegotiateHandler(data []byte) (uint32, []byte, error) { +func (p *SyncedPool) NegotiateHandler(data []byte) (PoolMessage, []byte, error) { negotiationCount.Inc() log.Printf("Handling negotiation\n") for _, host := range p.ExcludeKnown(strings.Split(string(data), ";")) { @@ -95,7 +95,7 @@ func (p *SyncedPool) NegotiateHandler(data []byte) (uint32, []byte, error) { return RemoteNegotiateResponse, []byte("ok"), nil } -func (p *SyncedPool) GrainOwnerChangeHandler(data []byte) (uint32, []byte, error) { +func (p *SyncedPool) GrainOwnerChangeHandler(data []byte) (PoolMessage, []byte, error) { grainSyncCount.Inc() idAndHostParts := strings.Split(string(data), ";") @@ -276,16 +276,18 @@ func (p *SyncedPool) RemoveHostMappedCarts(host *RemoteHost) { } } +type PoolMessage uint32 + const ( - RemoteNegotiate = uint32(3) - RemoteGrainChanged = uint32(4) - AckChange = uint32(5) - //AckError = uint32(6) - Ping = uint32(7) - Pong = uint32(8) - GetCartIds = uint32(9) - CartIdsResponse = uint32(10) - RemoteNegotiateResponse = uint32(11) + RemoteNegotiate = PoolMessage(3) + RemoteGrainChanged = PoolMessage(4) + AckChange = PoolMessage(5) + //AckError = PoolMessage(6) + Ping = PoolMessage(7) + Pong = PoolMessage(8) + GetCartIds = PoolMessage(9) + CartIdsResponse = PoolMessage(10) + RemoteNegotiateResponse = PoolMessage(11) ) func (p *SyncedPool) Negotiate() { diff --git a/tcp-cart-client.go b/tcp-cart-client.go index 09ac905..36458c1 100644 --- a/tcp-cart-client.go +++ b/tcp-cart-client.go @@ -47,7 +47,7 @@ func NewCartTCPClient(address string) (*CartTCPClient, error) { }, nil } -func (m *CartTCPClient) SendPacket(messageType uint32, id CartId, data []byte) error { +func (m *CartTCPClient) SendPacket(messageType CartMessage, id CartId, data []byte) error { err := binary.Write(m.Conn, binary.LittleEndian, CartPacket{ Version: CurrentPacketVersion, @@ -62,7 +62,7 @@ func (m *CartTCPClient) SendPacket(messageType uint32, id CartId, data []byte) e return m.HandleConnectionError(err) } -func (m *CartTCPClient) Call(messageType uint32, id CartId, responseType uint32, data []byte) (*CallResult, error) { +func (m *CartTCPClient) Call(messageType CartMessage, id CartId, responseType CartMessage, data []byte) (*CallResult, error) { packetChan := m.Expect(responseType, id) err := m.SendPacket(messageType, id, data) if err != nil { diff --git a/tcp-cart-mux-server.go b/tcp-cart-mux-server.go index a9f27aa..7a58f48 100644 --- a/tcp-cart-mux-server.go +++ b/tcp-cart-mux-server.go @@ -37,21 +37,21 @@ func CartListen(address string) (*CartServer, error) { type TCPCartServerMux struct { mu sync.RWMutex - listeners map[uint32]func(CartId, []byte) error - functions map[uint32]func(CartId, []byte) (uint32, []byte, error) + listeners map[CartMessage]func(CartId, []byte) error + functions map[CartMessage]func(CartId, []byte) (CartMessage, []byte, error) } func NewCartTCPServerMux() *TCPCartServerMux { m := &TCPCartServerMux{ mu: sync.RWMutex{}, - listeners: make(map[uint32]func(CartId, []byte) error), - functions: make(map[uint32]func(CartId, []byte) (uint32, []byte, error)), + listeners: make(map[CartMessage]func(CartId, []byte) error), + functions: make(map[CartMessage]func(CartId, []byte) (CartMessage, []byte, error)), } return m } -func (m *TCPCartServerMux) handleListener(messageType uint32, id CartId, data []byte) (bool, error) { +func (m *TCPCartServerMux) handleListener(messageType CartMessage, id CartId, data []byte) (bool, error) { m.mu.RLock() handler, ok := m.listeners[messageType] m.mu.RUnlock() @@ -64,7 +64,7 @@ func (m *TCPCartServerMux) handleListener(messageType uint32, id CartId, data [] return false, nil } -func (m *TCPCartServerMux) handleFunction(connection net.Conn, messageType uint32, id CartId, data []byte) (bool, error) { +func (m *TCPCartServerMux) handleFunction(connection net.Conn, messageType CartMessage, id CartId, data []byte) (bool, error) { m.mu.RLock() fn, ok := m.functions[messageType] m.mu.RUnlock() @@ -126,7 +126,7 @@ func (m *TCPCartServerMux) HandleConnection(connection net.Conn) error { } } -func (m *TCPCartServerMux) HandleData(connection net.Conn, t uint32, id CartId, data []byte) { +func (m *TCPCartServerMux) HandleData(connection net.Conn, t CartMessage, id CartId, data []byte) { status, err := m.handleListener(t, id, data) if err != nil { log.Printf("Error handling listener: %v\n", err) @@ -142,13 +142,13 @@ func (m *TCPCartServerMux) HandleData(connection net.Conn, t uint32, id CartId, } } -func (m *TCPCartServerMux) ListenFor(messageType uint32, handler func(CartId, []byte) error) { +func (m *TCPCartServerMux) ListenFor(messageType CartMessage, handler func(CartId, []byte) error) { m.mu.Lock() m.listeners[messageType] = handler m.mu.Unlock() } -func (m *TCPCartServerMux) HandleCall(messageType uint32, handler func(CartId, []byte) (uint32, []byte, error)) { +func (m *TCPCartServerMux) HandleCall(messageType CartMessage, handler func(CartId, []byte) (CartMessage, []byte, error)) { m.mu.Lock() m.functions[messageType] = handler m.mu.Unlock() diff --git a/tcp-cart_test.go b/tcp-cart_test.go index e43b672..5794d0c 100644 --- a/tcp-cart_test.go +++ b/tcp-cart_test.go @@ -22,17 +22,17 @@ func TestCartTcpHelpers(t *testing.T) { messageData = string(data) return nil }) - server.HandleCall(666, func(id CartId, data []byte) (uint32, []byte, error) { + server.HandleCall(666, func(id CartId, data []byte) (CartMessage, []byte, error) { log.Printf("Received call: %s\n", string(data)) return 3, []byte("Hello, client!"), fmt.Errorf("Det blev fel") }) - server.HandleCall(2, func(id CartId, data []byte) (uint32, []byte, error) { + server.HandleCall(2, func(id CartId, data []byte) (CartMessage, []byte, error) { log.Printf("Received call: %s\n", string(data)) return 3, []byte("Hello, client!"), nil }) - server.HandleCall(Ping, func(id CartId, data []byte) (uint32, []byte, error) { - return Pong, nil, nil - }) + // server.HandleCall(Ping, func(id CartId, data []byte) (CartMessage, []byte, error) { + // return Pong, nil, nil + // }) id := ToCartId("kalle") client.SendPacket(1, id, []byte("Hello, world!")) answer, err := client.Call(2, id, 3, []byte("Hello, server!")) @@ -46,16 +46,7 @@ func TestCartTcpHelpers(t *testing.T) { if s.StatusCode != 500 { t.Errorf("Expected 500, got %d\n", s.StatusCode) } - for i := 0; i < 100; i++ { - _, err = client.Call(Ping, id, Pong, nil) - if err != nil { - t.Errorf("Error calling: %v\n", err) - } - } - _, err = client.Call(Ping, id, Pong, nil) - if err != nil { - t.Errorf("Error calling: %v\n", err) - } + if string(answer.Data) != "Hello, client!" { t.Errorf("Expected answer 'Hello, client!', got %s\n", string(answer.Data)) } diff --git a/tcp-client.go b/tcp-client.go index d01679d..6dcf41b 100644 --- a/tcp-client.go +++ b/tcp-client.go @@ -92,7 +92,7 @@ func NewTCPClient(address string) (*TCPClient, error) { }, nil } -func (m *TCPClient) SendPacket(messageType uint32, data []byte) error { +func (m *TCPClient) SendPacket(messageType PoolMessage, data []byte) error { err := binary.Write(m.Conn, binary.LittleEndian, Packet{ Version: CurrentPacketVersion, @@ -107,7 +107,7 @@ func (m *TCPClient) SendPacket(messageType uint32, data []byte) error { return m.HandleConnectionError(err) } -func (m *TCPClient) Call(messageType uint32, responseType uint32, data []byte) (*CallResult, error) { +func (m *TCPClient) Call(messageType PoolMessage, responseType PoolMessage, data []byte) (*CallResult, error) { packetChan := m.Expect(responseType) err := m.SendPacket(messageType, data) if err != nil { diff --git a/tcp-mux-server.go b/tcp-mux-server.go index d2077b4..fa43cfc 100644 --- a/tcp-mux-server.go +++ b/tcp-mux-server.go @@ -37,21 +37,21 @@ func Listen(address string) (*Server, error) { type TCPServerMux struct { mu sync.RWMutex - listeners map[uint32]func(data []byte) error - functions map[uint32]func(data []byte) (uint32, []byte, error) + listeners map[PoolMessage]func(data []byte) error + functions map[PoolMessage]func(data []byte) (PoolMessage, []byte, error) } func NewTCPServerMux() *TCPServerMux { m := &TCPServerMux{ mu: sync.RWMutex{}, - listeners: make(map[uint32]func(data []byte) error), - functions: make(map[uint32]func(data []byte) (uint32, []byte, error)), + listeners: make(map[PoolMessage]func(data []byte) error), + functions: make(map[PoolMessage]func(data []byte) (PoolMessage, []byte, error)), } return m } -func (m *TCPServerMux) handleListener(messageType uint32, data []byte) (bool, error) { +func (m *TCPServerMux) handleListener(messageType PoolMessage, data []byte) (bool, error) { m.mu.RLock() handler, ok := m.listeners[messageType] m.mu.RUnlock() @@ -64,7 +64,7 @@ func (m *TCPServerMux) handleListener(messageType uint32, data []byte) (bool, er return false, nil } -func (m *TCPServerMux) handleFunction(connection net.Conn, messageType uint32, data []byte) (bool, error) { +func (m *TCPServerMux) handleFunction(connection net.Conn, messageType PoolMessage, data []byte) (bool, error) { m.mu.RLock() function, ok := m.functions[messageType] m.mu.RUnlock() @@ -124,7 +124,7 @@ func (m *TCPServerMux) HandleConnection(connection net.Conn) error { } } -func (m *TCPServerMux) HandleData(connection net.Conn, t uint32, data []byte) { +func (m *TCPServerMux) HandleData(connection net.Conn, t PoolMessage, data []byte) { // listener := m.listeners[t] // handler := m.functions[t] status, err := m.handleListener(t, data) @@ -142,13 +142,13 @@ func (m *TCPServerMux) HandleData(connection net.Conn, t uint32, data []byte) { } } -func (m *TCPServerMux) ListenFor(messageType uint32, handler func(data []byte) error) { +func (m *TCPServerMux) ListenFor(messageType PoolMessage, handler func(data []byte) error) { m.mu.Lock() m.listeners[messageType] = handler m.mu.Unlock() } -func (m *TCPServerMux) HandleCall(messageType uint32, handler func(data []byte) (uint32, []byte, error)) { +func (m *TCPServerMux) HandleCall(messageType PoolMessage, handler func(data []byte) (PoolMessage, []byte, error)) { m.mu.Lock() m.functions[messageType] = handler m.mu.Unlock() diff --git a/tcp_test.go b/tcp_test.go index 7cb331d..849c507 100644 --- a/tcp_test.go +++ b/tcp_test.go @@ -21,11 +21,11 @@ func TestTcpHelpers(t *testing.T) { messageData = string(data) return nil }) - server.HandleCall(2, func(data []byte) (uint32, []byte, error) { + server.HandleCall(2, func(data []byte) (PoolMessage, []byte, error) { log.Printf("Received call: %s\n", string(data)) return 3, []byte("Hello, client!"), nil }) - server.HandleCall(Ping, func(data []byte) (uint32, []byte, error) { + server.HandleCall(Ping, func(data []byte) (PoolMessage, []byte, error) { return Pong, nil, nil }) @@ -34,6 +34,12 @@ func TestTcpHelpers(t *testing.T) { if err != nil { t.Errorf("Error calling: %v\n", err) } + for i := 0; i < 100; i++ { + _, err = client.Call(Ping, Pong, nil) + if err != nil { + t.Errorf("Error calling: %v\n", err) + } + } _, err = client.Call(Ping, Pong, nil) if err != nil { t.Errorf("Error calling: %v\n", err)