diff --git a/rpc-pool.go b/rpc-pool.go index eb61bf2..6d1e16e 100644 --- a/rpc-pool.go +++ b/rpc-pool.go @@ -5,6 +5,7 @@ import ( "io" "net" "strings" + "time" ) type RemoteGrainPool struct { @@ -23,9 +24,10 @@ func ToCartId(id string) CartId { } type RemoteGrain struct { - client net.Conn - Id CartId - Address string + connection net.Conn + queue *PacketQueue + Id CartId + Address string } func NewRemoteGrain(id CartId, address string) *RemoteGrain { @@ -36,23 +38,27 @@ func NewRemoteGrain(id CartId, address string) *RemoteGrain { } func (g *RemoteGrain) Connect() error { - if g.client == nil { + if g.connection == nil { client, err := net.Dial("tcp", g.Address) if err != nil { return err } - g.client = client + g.connection = client + g.queue = NewPacketQueue(client) } return nil } func (g *RemoteGrain) HandleMessage(message *Message, isReplay bool) ([]byte, error) { - err := SendCartPacket(g.client, g.Id, RemoteHandleMessage, message.Write) + err := SendCartPacket(g.connection, g.Id, RemoteHandleMessage, message.Write) if err != nil { return nil, err } - _, data, err := ReceivePacket(g.client) - return data, err + packet, err := g.queue.Expect(ResponseBody, time.Second) + if err != nil { + return nil, err + } + return packet.Data, err } func (g *RemoteGrain) GetId() CartId { @@ -60,14 +66,17 @@ func (g *RemoteGrain) GetId() CartId { } func (g *RemoteGrain) GetCurrentState() ([]byte, error) { - err := SendCartPacket(g.client, g.Id, RemoteGetState, func(w io.Writer) error { + err := SendCartPacket(g.connection, g.Id, RemoteGetState, func(w io.Writer) error { return nil }) if err != nil { return nil, err } - _, data, err := ReceivePacket(g.client) - return data, err + packet, err := g.queue.Expect(ResponseBody, time.Second) + if err != nil { + return nil, err + } + return packet.Data, nil } func NewRemoteGrainPool(addr string) *RemoteGrainPool { diff --git a/synced-pool.go b/synced-pool.go index 6167c4f..8277fcb 100644 --- a/synced-pool.go +++ b/synced-pool.go @@ -8,6 +8,7 @@ import ( "log" "net" "strings" + "sync" "time" "github.com/prometheus/client_golang/prometheus" @@ -59,6 +60,7 @@ type RemoteHost struct { MissedPings int Pool *RemoteGrainPool connection net.Conn + queue *PacketQueue } type SyncedPool struct { @@ -90,7 +92,7 @@ func NewSyncedPool(local *GrainLocalPool, hostname string, d Discovery) (*Synced for { <-pingTimer.C for i, r := range pool.remotes { - err := DoPing(r.connection) + err := DoPing(r) if err != nil { r.MissedPings++ log.Printf("Error pinging remote %s: %v\n, missed pings: %d", r.Host, err, r.MissedPings) @@ -184,11 +186,63 @@ const ( RemoteNegotiate = uint16(3) RemoteGrainChanged = uint16(4) AckChange = uint16(5) - AckError = uint16(6) - Ping = uint16(7) - Pong = uint16(8) + //AckError = uint16(6) + Ping = uint16(7) + Pong = uint16(8) ) +type PacketWithData struct { + MessageType uint16 + Data []byte +} + +type PacketQueue struct { + mu sync.Mutex + Packets []PacketWithData + connection net.Conn +} + +func NewPacketQueue(connection net.Conn) *PacketQueue { + queue := &PacketQueue{ + Packets: make([]PacketWithData, 0), + connection: connection, + } + go func() { + for { + messageType, data, err := ReceivePacket(queue.connection) + if err != nil { + log.Printf("Error receiving packet: %v\n", err) + return + } + queue.mu.Lock() + queue.Packets = append(queue.Packets, PacketWithData{ + MessageType: messageType, + Data: data, + }) + queue.mu.Unlock() + } + }() + return queue +} + +func (p *PacketQueue) Expect(messageType uint16, timeToWait time.Duration) (PacketWithData, error) { + start := time.Now() + for { + if time.Since(start) > timeToWait { + return PacketWithData{}, fmt.Errorf("timeout waiting for message type %d", messageType) + } + for i, packet := range p.Packets { + if packet.MessageType == messageType { + p.mu.Lock() + p.Packets = append(p.Packets[:i], p.Packets[i+1:]...) + p.mu.Unlock() + return packet, nil + } + } + time.Sleep(time.Millisecond * 50) + } +} + func (p *SyncedPool) handleConnection(conn net.Conn) { defer conn.Close() var packet Packet @@ -262,10 +316,6 @@ func (p *SyncedPool) handleConnection(conn net.Conn) { if !found { log.Printf("Remote host %s not found\n", idAndHostParts[1]) log.Printf("Remotes %v\n", p.remotes) - err = SendPacket(conn, AckError, func(w io.Writer) error { - w.Write([]byte("remote host not found")) - return nil - }) } else { err = SendPacket(conn, AckChange, func(w io.Writer) error { _, err := w.Write([]byte("ok")) @@ -282,14 +332,13 @@ func (h *RemoteHost) Negotiate(knownHosts []string) ([]string, error) { w.Write([]byte(strings.Join(knownHosts, ";"))) return nil }) - t, data, err := ReceivePacket(h.connection) + packet, err := h.queue.Expect(RemoteNegotiate, time.Second) + if err != nil { return nil, err } - if t != RemoteNegotiate { - return nil, fmt.Errorf("unexpected message type %d", t) - } - return strings.Split(string(data), ";"), nil + + return strings.Split(string(packet.Data), ";"), nil } func (p *SyncedPool) Negotiate(knownHosts []string) ([]string, error) { @@ -315,16 +364,12 @@ func (r *RemoteHost) ConfirmChange(id CartId, host string) error { _, err := w.Write([]byte(fmt.Sprintf("%s;%s", id, host))) return err }) - t, data, err := ReceivePacket(r.connection) + _, err := r.queue.Expect(AckChange, time.Second) + if err != nil { return err } - if t == AckError { - return fmt.Errorf("error from remote: %s, from %s", string(data), r.Host) - } - if t != AckChange { - return fmt.Errorf("unexpected message type %d", t) - } + return nil } @@ -344,23 +389,23 @@ func (p *SyncedPool) AddRemoteWithConnection(address string, connection net.Conn pool := NewRemoteGrainPool(fmt.Sprintf(address, 1337)) remote := RemoteHost{ connection: connection, + queue: NewPacketQueue(connection), Pool: pool, Host: address, } return p.addRemoteHost(address, &remote) } -func DoPing(connection net.Conn) error { - SendPacket(connection, Ping, func(w io.Writer) error { +func DoPing(host *RemoteHost) error { + SendPacket(host.connection, Ping, func(w io.Writer) error { return nil }) - t, _, err := ReceivePacket(connection) + _, err := host.queue.Expect(Pong, time.Second) + if err != nil { return err } - if t != Pong { - return fmt.Errorf("unexpected message type %d", t) - } + return nil } @@ -372,7 +417,7 @@ func (p *SyncedPool) addRemoteHost(address string, remote *RemoteHost) error { } } - err := DoPing(remote.connection) + err := DoPing(remote) if err != nil { log.Printf("Error pinging remote %s: %v\n", address, err) } @@ -395,6 +440,7 @@ func (p *SyncedPool) AddRemote(address string) error { pool := NewRemoteGrainPool(fmt.Sprintf(address, 1337)) remote := RemoteHost{ connection: connection, + queue: NewPacketQueue(connection), Pool: pool, Host: address, }