From 10d85350d03f814a01e8787b46fc319bbb3968e0 Mon Sep 17 00:00:00 2001 From: matst80 Date: Sun, 10 Nov 2024 16:40:52 +0100 Subject: [PATCH] tcp mux and stuff --- packet.go | 17 +- synced-pool.go | 489 ++++++++++++++++++-------------------------- synced-pool_test.go | 10 + tcp-mux-client.go | 76 +++++++ tcp-mux-server.go | 133 ++++++++++++ tcp-mux_test.go | 40 ++++ 6 files changed, 470 insertions(+), 295 deletions(-) create mode 100644 tcp-mux-client.go create mode 100644 tcp-mux-server.go create mode 100644 tcp-mux_test.go diff --git a/packet.go b/packet.go index 94a3c31..a1e12be 100644 --- a/packet.go +++ b/packet.go @@ -63,12 +63,15 @@ func SendPacket(conn io.Writer, messageType uint16, datafn func(w io.Writer) err } func SendRawResponse(conn io.Writer, data []byte) error { - binary.Write(conn, binary.LittleEndian, Packet{ + err := binary.Write(conn, binary.LittleEndian, Packet{ Version: 1, MessageType: ResponseBody, DataLength: uint16(len(data)), }) - _, err := conn.Write(data) + if err != nil { + return err + } + _, err = conn.Write(data) return err } @@ -90,6 +93,12 @@ func ReceivePacket(conn io.Reader) (uint16, []byte, error) { return packet.MessageType, nil, err } data := make([]byte, packet.DataLength) - _, err = conn.Read(data) - return packet.MessageType, data, err + l, err := conn.Read(data) + if err != nil { + return packet.MessageType, nil, err + } + if l != int(packet.DataLength) { + return packet.MessageType, nil, fmt.Errorf("expected %d bytes, got %d", packet.DataLength, l) + } + return packet.MessageType, data, nil } diff --git a/synced-pool.go b/synced-pool.go index d2922cc..2a77ed7 100644 --- a/synced-pool.go +++ b/synced-pool.go @@ -1,11 +1,8 @@ package main import ( - "encoding/binary" "fmt" - "io" "log" - "net" "strings" "sync" "time" @@ -20,149 +17,22 @@ type Quorum interface { } type RemoteHost struct { - net.Conn - *PacketQueue + *Client Host string MissedPings int Pool *RemoteGrainPool } type SyncedPool struct { - mu sync.RWMutex - Discovery Discovery - listener net.Listener + *Server + mu sync.RWMutex + //Discovery Discovery Hostname string local *GrainLocalPool remotes []*RemoteHost remoteIndex map[CartId]*RemoteGrainPool } -func NewSyncedPool(local *GrainLocalPool, hostname string, d Discovery) (*SyncedPool, error) { - listen := fmt.Sprintf("%s:1338", hostname) - l, err := net.Listen("tcp", listen) - log.Printf("Listening on %s", listen) - if err != nil { - return nil, err - } - pool := &SyncedPool{ - Discovery: d, - Hostname: hostname, - local: local, - listener: l, - remotes: make([]*RemoteHost, 0), - remoteIndex: make(map[CartId]*RemoteGrainPool), - } - - go func() { - for { - for range time.Tick(time.Second * 2) { - for _, r := range pool.remotes { - err := DoPing(r) - if err != nil { - r.MissedPings++ - log.Printf("Error pinging remote %s: %v\n, missed pings: %d", r.Host, err, r.MissedPings) - if r.MissedPings > 3 { - log.Printf("Removing remote %s\n", r.Host) - go pool.RemoveHost(r) - //pool.remotes = append(pool.remotes[:i], pool.remotes[i+1:]...) - - } - } else { - r.MissedPings = 0 - } - } - connectedRemotes.Set(float64(len(pool.remotes))) - } - } - }() - - if d != nil { - go func() { - ch, err := d.Watch() - if err != nil { - log.Printf("Error discovering hosts: %v", err) - return - } - for host := range ch { - if pool.IsKnown(host) { - continue - } - go func(h string) { - log.Printf("Discovered host %s, waiting for startup", h) - time.Sleep(time.Second) - err := pool.AddRemote(h) - if err != nil { - log.Printf("Error adding remote %s: %v", h, err) - } - }(host) - } - }() - } else { - log.Printf("No discovery, waiting for remotes to connect") - } - go func() { - for { - conn, err := l.Accept() - if err != nil { - log.Printf("Error accepting connection: %v\n", err) - continue - } - log.Printf("Got connection from %s", conn.RemoteAddr()) - - go pool.handleConnection(conn) - } - }() - return pool, nil -} - -func (p *SyncedPool) IsKnown(host string) bool { - for _, r := range p.remotes { - if r.Host == host { - return true - } - } - return host != p.Hostname -} - -func (p *SyncedPool) ExcludeKnown(hosts []string) []string { - ret := make([]string, 0, len(hosts)) - for _, h := range hosts { - found := false - for _, r := range p.remotes { - if r.Host == h { - found = true - break - } - } - if !found && h != p.Hostname { - ret = append(ret, h) - } - } - return ret -} - -func (p *SyncedPool) RemoveHost(host *RemoteHost) { - - for i, r := range p.remotes { - if r == host { - p.RemoveHostMappedCarts(r) - p.remotes = append(p.remotes[:i], p.remotes[i+1:]...) - connectedRemotes.Set(float64(len(p.remotes))) - return - } - } -} - -func (p *SyncedPool) RemoveHostMappedCarts(host *RemoteHost) { - p.mu.Lock() - defer p.mu.Unlock() - for id, r := range p.remoteIndex { - if r == host.Pool { - delete(p.remoteIndex, id) - } - } -} - var ( negotiationCount = promauto.NewCounter(prometheus.CounterOpts{ Name: "cart_remote_negotiation_total", @@ -194,148 +64,207 @@ var ( }) ) +func (p *SyncedPool) PongHandler(data []byte) (uint16, []byte, error) { + return Pong, data, nil +} + +func (p *SyncedPool) GetCartIdHandler(data []byte) (uint16, []byte, error) { + ids := make([]string, 0, len(p.local.grains)) + for id := range p.local.grains { + ids = append(ids, id.String()) + } + return CartIdsResponse, []byte(strings.Join(ids, ";")), nil +} + +func (p *SyncedPool) NegotiateHandler(data []byte) (uint16, []byte, error) { + negotiationCount.Inc() + log.Printf("Handling negotiation\n") + for _, host := range p.ExcludeKnown(strings.Split(string(data), ";")) { + err := p.AddRemote(host) + if err != nil { + log.Printf("Error adding remote %s: %v\n", host, err) + } + } + + return RemoteNegotiateResponse, []byte("ok"), nil +} + +func (p *SyncedPool) GrainOwnerChangeHandler(data []byte) (uint16, []byte, error) { + grainSyncCount.Inc() + + idAndHostParts := strings.Split(string(data), ";") + if len(idAndHostParts) != 2 { + log.Printf("Invalid remote grain change message\n") + return AckChange, []byte("incorrect"), nil + } + + for _, r := range p.remotes { + if r.Host == string(idAndHostParts[1]) { + + log.Printf("Remote grain %s changed to %s\n", idAndHostParts[0], idAndHostParts[1]) + p.mu.Lock() + if p.local.grains[ToCartId(idAndHostParts[0])] != nil { + log.Printf("Grain %s already exists locally, deleting\n", idAndHostParts[0]) + delete(p.local.grains, ToCartId(idAndHostParts[0])) + } + p.remoteIndex[ToCartId(idAndHostParts[0])] = r.Pool + p.mu.Unlock() + return AckChange, []byte("ok"), nil + } + } + return AckChange, []byte("not found"), nil +} + +func NewSyncedPool(local *GrainLocalPool, hostname string, discovery Discovery) (*SyncedPool, error) { + listen := fmt.Sprintf("%s:1338", hostname) + + server, err := Listen(listen) + if err != nil { + return nil, err + } + + log.Printf("Listening on %s", listen) + + pool := &SyncedPool{ + Server: server, + //Discovery: discovery, + Hostname: hostname, + local: local, + + remotes: make([]*RemoteHost, 0), + remoteIndex: make(map[CartId]*RemoteGrainPool), + } + + server.HandleCall(Ping, pool.PongHandler) + server.HandleCall(GetCartIds, pool.GetCartIdHandler) + server.HandleCall(RemoteNegotiate, pool.NegotiateHandler) + server.HandleCall(RemoteGrainChanged, pool.GrainOwnerChangeHandler) + + // // TODO FIX THIS, ONLY CLIENT OR SERVER SHOULD PING + // go func() { + // for { + // for range time.Tick(time.Second * 2) { + // for _, r := range pool.remotes { + // err := DoPing(r) + // if err != nil { + // r.MissedPings++ + // log.Printf("Error pinging remote %s: %v\n, missed pings: %d", r.Host, err, r.MissedPings) + // if r.MissedPings > 3 { + // log.Printf("Removing remote %s\n", r.Host) + // go pool.RemoveHost(r) + // //pool.remotes = append(pool.remotes[:i], pool.remotes[i+1:]...) + + // } + // } else { + // r.MissedPings = 0 + // } + // } + // connectedRemotes.Set(float64(len(pool.remotes))) + // } + // } + // }() + + if discovery != nil { + go func() { + ch, err := discovery.Watch() + if err != nil { + log.Printf("Error discovering hosts: %v", err) + return + } + for host := range ch { + if pool.IsKnown(host) { + continue + } + go func(h string) { + log.Printf("Discovered host %s, waiting for startup", h) + time.Sleep(time.Second) + err := pool.AddRemote(h) + if err != nil { + log.Printf("Error adding remote %s: %v", h, err) + } + }(host) + } + }() + } else { + log.Printf("No discovery, waiting for remotes to connect") + } + + return pool, nil +} + +func (p *SyncedPool) IsKnown(host string) bool { + for _, r := range p.remotes { + if r.Host == host { + return true + } + } + return host != p.Hostname +} + +func (p *SyncedPool) ExcludeKnown(hosts []string) []string { + ret := make([]string, 0, len(hosts)) + for _, h := range hosts { + if !p.IsKnown(h) { + ret = append(ret, h) + } + } + return ret +} + +func (p *SyncedPool) RemoveHost(host *RemoteHost) { + for i, r := range p.remotes { + if r == host { + p.RemoveHostMappedCarts(r) + p.remotes = append(p.remotes[:i], p.remotes[i+1:]...) + connectedRemotes.Set(float64(len(p.remotes))) + return + } + } +} + +func (p *SyncedPool) RemoveHostMappedCarts(host *RemoteHost) { + p.mu.Lock() + defer p.mu.Unlock() + for id, r := range p.remoteIndex { + if r == host.Pool { + delete(p.remoteIndex, id) + } + } +} + const ( RemoteNegotiate = uint16(3) RemoteGrainChanged = uint16(4) AckChange = uint16(5) //AckError = uint16(6) - Ping = uint16(7) - Pong = uint16(8) - GetCartIds = uint16(9) - CartIdsResponse = uint16(10) + Ping = uint16(7) + Pong = uint16(8) + GetCartIds = uint16(9) + CartIdsResponse = uint16(10) + RemoteNegotiateResponse = uint16(11) ) -func (p *SyncedPool) handleConnection(conn net.Conn) { - defer conn.Close() - var packet Packet - for { - err := binary.Read(conn, binary.LittleEndian, &packet) - if err != nil { - if err == io.EOF { - break - } - log.Printf("Error in connection: %v\n", err) - } - // if packet.Version != 1 { - // log.Printf("Invalid version %d\n", packet.Version) - // return - // } - switch packet.MessageType { - case Ping: - err = SendPacket(conn, Pong, func(w io.Writer) error { - return nil - }) - if err != nil { - log.Printf("Error sending pong: %v\n", err) - } - case RemoteNegotiate: - negotiationCount.Inc() - data := make([]byte, packet.DataLength) - conn.Read(data) - knownHosts := strings.Split(string(data), ";") - log.Printf("Negotiated with remote, found %v hosts\n", knownHosts) - - SendPacket(conn, RemoteNegotiate, func(w io.Writer) error { - hostnames := make([]string, 0, len(p.remotes)) - for _, r := range p.remotes { - hostnames = append(hostnames, r.Host) - } - w.Write([]byte(strings.Join(hostnames, ";"))) - return nil - }) - for _, h := range knownHosts { - err = p.AddRemote(h) - if err != nil { - log.Printf("Error adding remote %s: %v\n", h, err) - } - } - case RemoteGrainChanged: - // remote grain changed - grainSyncCount.Inc() - - idAndHost := make([]byte, packet.DataLength) - _, err = conn.Read(idAndHost) - if err != nil { - break - } - idAndHostParts := strings.Split(string(idAndHost), ";") - if len(idAndHostParts) != 2 { - log.Printf("Invalid remote grain change message\n") - break - } - found := false - for _, r := range p.remotes { - if r.Host == string(idAndHostParts[1]) { - found = true - log.Printf("Remote grain %s changed to %s\n", idAndHostParts[0], idAndHostParts[1]) - p.mu.Lock() - if p.local.grains[ToCartId(idAndHostParts[0])] != nil { - log.Printf("Grain %s already exists locally, deleting\n", idAndHostParts[0]) - delete(p.local.grains, ToCartId(idAndHostParts[0])) - } - p.remoteIndex[ToCartId(idAndHostParts[0])] = r.Pool - p.mu.Unlock() - } - } - - if !found { - log.Printf("Remote host %s not found\n", idAndHostParts[1]) - } else { - SendPacket(conn, AckChange, func(w io.Writer) error { - _, err := w.Write([]byte("ok")) - return err - }) - } - - case GetCartIds: - ids := make([]string, 0, len(p.local.grains)) - for id := range p.local.grains { - ids = append(ids, id.String()) - } - SendPacket(conn, CartIdsResponse, func(w io.Writer) error { - _, err := w.Write([]byte(strings.Join(ids, ";"))) - return err - }) - } - } -} - func (h *RemoteHost) Negotiate(knownHosts []string) ([]string, error) { - err := SendPacket(h.connection, RemoteNegotiate, func(w io.Writer) error { - w.Write([]byte(strings.Join(knownHosts, ";"))) - return nil - }) - if err != nil { - return nil, err - } - packet, err := h.Expect(RemoteNegotiate, time.Second) + data, err := h.Call(RemoteNegotiate, RemoteNegotiateResponse, []byte(strings.Join(knownHosts, ";"))) if err != nil { return nil, err } - return strings.Split(string(packet.Data), ";"), nil + return strings.Split(string(data), ";"), nil } -func (g *RemoteHost) GetCartMappings() []CartId { - err := SendPacket(g.connection, GetCartIds, func(w io.Writer) error { - return nil - }) +func (g *RemoteHost) GetCartMappings() ([]CartId, error) { + data, err := g.Call(GetCartIds, CartIdsResponse, nil) if err != nil { - log.Printf("Error getting mappings: %v\n", err) - return nil + return nil, err } - packet, err := g.Expect(CartIdsResponse, time.Second*3) - if err != nil { - log.Printf("Error getting mappings: %v\n", err) - return nil - } - parts := strings.Split(string(packet.Data), ";") + parts := strings.Split(string(data), ";") ids := make([]CartId, 0, len(parts)) for _, p := range parts { ids = append(ids, ToCartId(p)) } - return ids + return ids, nil } func (p *SyncedPool) Negotiate(knownHosts []string) ([]string, error) { @@ -357,18 +286,14 @@ func (p *SyncedPool) Negotiate(knownHosts []string) ([]string, error) { } func (r *RemoteHost) ConfirmChange(id CartId, host string) error { - err := SendPacket(r.connection, RemoteGrainChanged, func(w io.Writer) error { - _, err := w.Write([]byte(fmt.Sprintf("%s;%s", id, host))) - return err - }) - if err != nil { - return err - } - _, err = r.Expect(AckChange, time.Second) + data, err := r.Call(RemoteGrainChanged, AckChange, []byte(fmt.Sprintf("%s;%s", id, host))) if err != nil { return err } + if string(data) != "ok" { + return fmt.Errorf("remote grain change failed %s", string(data)) + } return nil } @@ -386,23 +311,6 @@ func (p *SyncedPool) RequestOwnership(id CartId) error { return nil } -func DoPing(host *RemoteHost) error { - - err := SendPacket(host, Ping, func(w io.Writer) error { - return nil - }) - if err != nil { - return err - } - _, err = host.Expect(Pong, time.Second) - - if err != nil { - return err - } - - return nil -} - func (p *SyncedPool) addRemoteHost(address string, remote *RemoteHost) error { known := make([]string, 0, len(p.remotes)) for _, r := range p.remotes { @@ -413,18 +321,17 @@ func (p *SyncedPool) addRemoteHost(address string, remote *RemoteHost) error { } } - err := DoPing(remote) - if err != nil { - log.Printf("Error pinging remote %s: %v\n", address, err) - } - p.remotes = append(p.remotes, remote) connectedRemotes.Set(float64(len(p.remotes))) log.Printf("Added remote %s\n", remote.Host) go func() { p.Negotiate(known) - ids := remote.GetCartMappings() + ids, err := remote.GetCartMappings() + if err != nil { + log.Printf("Error getting remote mappings: %v\n", err) + return + } p.mu.Lock() for _, id := range ids { if p.local.grains[id] != nil { @@ -442,7 +349,8 @@ func (p *SyncedPool) AddRemote(address string) error { if address == "" || p.IsKnown(address) { return nil } - connection, err := net.Dial("tcp", fmt.Sprintf("%s:1338", address)) + client, err := Dial(fmt.Sprintf("%s:1338", address)) + if err != nil { log.Printf("Error connecting to remote %s: %v\n", address, err) return err @@ -450,10 +358,9 @@ func (p *SyncedPool) AddRemote(address string) error { pool := NewRemoteGrainPool(address) remote := RemoteHost{ - Conn: connection, - PacketQueue: NewPacketQueue(connection), - Pool: pool, - Host: address, + Client: client, + Pool: pool, + Host: address, } return p.addRemoteHost(address, &remote) diff --git a/synced-pool_test.go b/synced-pool_test.go index 89d1ec2..147cb0c 100644 --- a/synced-pool_test.go +++ b/synced-pool_test.go @@ -32,4 +32,14 @@ func TestConnection(t *testing.T) { if len(allHosts) != 1 { t.Errorf("Expected 1 host, got %d", len(allHosts)) } + + data, err := pool.Get(ToCartId("kalle")) + if err != nil { + t.Errorf("Error getting data: %v", err) + } + if data == nil { + t.Errorf("Expected data, got nil") + } + time.Sleep(2 * time.Millisecond) + } diff --git a/tcp-mux-client.go b/tcp-mux-client.go new file mode 100644 index 0000000..e1816d7 --- /dev/null +++ b/tcp-mux-client.go @@ -0,0 +1,76 @@ +package main + +import ( + "encoding/binary" + "io" + "net" + "time" +) + +type Client struct { + *TCPClientMux +} + +func Dial(address string) (*Client, error) { + conn, err := net.Dial("tcp", address) + if err != nil { + return nil, err + } + client := &Client{ + TCPClientMux: NewTCPClientMux(conn), + } + return client, nil +} + +func (c *Client) Close() { + c.Conn.Close() +} + +type TCPClientMux struct { + net.Conn + *PacketQueue +} + +func NewTCPClientMux(connection net.Conn) *TCPClientMux { + return &TCPClientMux{ + Conn: connection, + PacketQueue: NewPacketQueue(connection), + } +} + +func (m *TCPClientMux) Close() { + m.Conn.Close() +} + +func (m *TCPClientMux) SendPacket(messageType uint16, data []byte) error { + err := binary.Write(m.Conn, binary.LittleEndian, Packet{ + Version: 1, + MessageType: messageType, + DataLength: uint16(len(data)), + }) + if err != nil { + return err + } + _, err = m.Conn.Write(data) + return err +} + +func (m *TCPClientMux) SendPacketFn(messageType uint16, datafn func(w io.Writer) error) error { + data, err := GetData(datafn) + if err != nil { + return err + } + return m.SendPacket(messageType, data) +} + +func (m *TCPClientMux) Call(messageType uint16, responseType uint16, data []byte) ([]byte, error) { + err := m.SendPacket(messageType, data) + if err != nil { + return nil, err + } + packet, err := m.Expect(responseType, time.Second) + if err != nil { + return nil, err + } + return packet.Data, nil +} diff --git a/tcp-mux-server.go b/tcp-mux-server.go new file mode 100644 index 0000000..089e0c3 --- /dev/null +++ b/tcp-mux-server.go @@ -0,0 +1,133 @@ +package main + +import ( + "encoding/binary" + "io" + "log" + "net" + "sync" +) + +type Server struct { + *TCPServerMux +} + +func Listen(address string) (*Server, error) { + listener, err := net.Listen("tcp", address) + server := &Server{ + NewTCPServerMux(100), + } + + if err != nil { + return nil, err + } + go func() { + for { + conn, err := listener.Accept() + if err != nil { + log.Printf("Error accepting connection: %v\n", err) + continue + } + go server.HandleConnection(conn) + } + }() + return server, nil +} + +type TCPServerMux struct { + mu sync.RWMutex + listeners map[uint16]func(data []byte) error + functions map[uint16]func(data []byte) (uint16, []byte, error) + connections []net.Conn +} + +func NewTCPServerMux(maxClients int) *TCPServerMux { + m := &TCPServerMux{ + connections: make([]net.Conn, 0, maxClients), + mu: sync.RWMutex{}, + listeners: make(map[uint16]func(data []byte) error), + functions: make(map[uint16]func(data []byte) (uint16, []byte, error)), + } + + return m +} + +func (m *TCPServerMux) handleListener(messageType uint16, data []byte) (bool, error) { + m.mu.RLock() + handler, ok := m.listeners[messageType] + m.mu.RUnlock() + if ok { + err := handler(data) + if err != nil { + return true, err + } + } + return false, nil +} + +func (m *TCPServerMux) handleFunction(connection net.Conn, messageType uint16, data []byte) (bool, error) { + m.mu.RLock() + function, ok := m.functions[messageType] + m.mu.RUnlock() + if ok { + responseType, responseData, err := function(data) + if err != nil { + return true, err + } + err = binary.Write(connection, binary.LittleEndian, Packet{ + Version: 1, + MessageType: responseType, + DataLength: uint16(len(responseData)), + }) + if err != nil { + return true, err + } + packetsSent.Inc() + _, err = connection.Write(responseData) + return true, err + } + return false, nil +} + +func (m *TCPServerMux) HandleConnection(connection net.Conn) error { + m.mu.Lock() + m.connections = append(m.connections, connection) + m.mu.Unlock() + defer connection.Close() + for { + messageType, data, err := ReceivePacket(connection) + if err != nil { + if err == io.EOF { + return nil + } + log.Printf("Error receiving packet: %v\n", err) + return err + } + + status, err := m.handleListener(messageType, data) + if err != nil { + log.Printf("Error handling listener: %v\n", err) + } + if !status { + status, err = m.handleFunction(connection, messageType, data) + if err != nil { + log.Printf("Error handling function: %v\n", err) + } + if !status { + log.Printf("Unknown message type: %d\n", messageType) + } + } + } +} + +func (m *TCPServerMux) ListenFor(messageType uint16, handler func(data []byte) error) { + m.mu.Lock() + m.listeners[messageType] = handler + m.mu.Unlock() +} + +func (m *TCPServerMux) HandleCall(messageType uint16, handler func(data []byte) (uint16, []byte, error)) { + m.mu.Lock() + m.functions[messageType] = handler + m.mu.Unlock() +} diff --git a/tcp-mux_test.go b/tcp-mux_test.go new file mode 100644 index 0000000..e58a600 --- /dev/null +++ b/tcp-mux_test.go @@ -0,0 +1,40 @@ +package main + +import ( + "log" + "testing" +) + +func TestTcpHelpers(t *testing.T) { + + server, err := Listen(":1337") + if err != nil { + t.Errorf("Error listening: %v\n", err) + } + client, err := Dial("localhost:1337") + if err != nil { + t.Errorf("Error dialing: %v\n", err) + } + var messageData string + server.ListenFor(1, func(data []byte) error { + log.Printf("Received message: %s\n", string(data)) + messageData = string(data) + return nil + }) + server.HandleCall(2, func(data []byte) (uint16, []byte, error) { + log.Printf("Received call: %s\n", string(data)) + return 3, []byte("Hello, client!"), nil + }) + + client.SendPacket(1, []byte("Hello, world!")) + answer, err := client.Call(2, 3, []byte("Hello, server!")) + if err != nil { + t.Errorf("Error calling: %v\n", err) + } + if string(answer) != "Hello, client!" { + t.Errorf("Expected answer 'Hello, client!', got %s\n", string(answer)) + } + if messageData != "Hello, world!" { + t.Errorf("Expected message 'Hello, world!', got %s\n", messageData) + } +}