package main

import (
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"strings"
	"sync"
	"time"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
)

// Quorum is the contract for coordinating host lists and cart ownership
// between nodes.
//
// NOTE(review): in OwnerChanged both parameters are declared as string (the
// first one is merely *named* CartId), whereas SyncedPool.OwnerChanged takes a
// value of type CartId. Confirm which signature is intended before relying on
// SyncedPool satisfying this interface.
type Quorum interface {
	Negotiate(knownHosts []string) ([]string, error)
	OwnerChanged(CartId, host string) error
}

// RemoteHost is a connected peer: the connection to it, the queue of packets
// received on that connection, and the grain pool used for carts it owns.
type RemoteHost struct {
	net.Conn
	*PacketQueue
	Host        string
	MissedPings int
	Pool        *RemoteGrainPool
}

// SyncedPool combines the local grain pool with the set of connected remote
// hosts and an index of which remote pool owns which cart.
type SyncedPool struct {
	mu          sync.RWMutex // guards remotes and remoteIndex
	Discovery   Discovery
	listener    net.Listener
	Hostname    string
	local       *GrainLocalPool
	remotes     []*RemoteHost
	remoteIndex map[CartId]*RemoteGrainPool
}

// NewSyncedPool starts listening on hostname:1338 and returns a pool that
// pings its remotes every two seconds, accepts incoming peer connections and,
// when d is non-nil, connects to every host announced by d.Watch.
//
// It returns an error only when the listener cannot be created.
func NewSyncedPool(local *GrainLocalPool, hostname string, d Discovery) (*SyncedPool, error) {
	listen := net.JoinHostPort(hostname, "1338")
	l, err := net.Listen("tcp", listen)
	if err != nil {
		return nil, err
	}
	log.Printf("Listening on %s", listen)

	pool := &SyncedPool{
		Discovery:   d,
		Hostname:    hostname,
		local:       local,
		listener:    l,
		remotes:     make([]*RemoteHost, 0),
		remoteIndex: make(map[CartId]*RemoteGrainPool),
	}

	go pool.pingRemotes()
	if d != nil {
		go pool.watchDiscovery(d)
	} else {
		log.Printf("No discovery, waiting for remotes to connect")
	}
	go pool.acceptConnections()

	return pool, nil
}

// pingRemotes pings every remote each two seconds and removes a remote once
// it has missed more than three consecutive pings. It never returns.
func (p *SyncedPool) pingRemotes() {
	ticker := time.NewTicker(2 * time.Second)
	defer ticker.Stop()
	for range ticker.C {
		// Iterate over a copy: RemoveHost compacts p.remotes in place.
		for _, r := range p.snapshotRemotes() {
			if err := DoPing(r); err != nil {
				r.MissedPings++
				log.Printf("Error pinging remote %s: %v, missed pings: %d", r.Host, err, r.MissedPings)
				if r.MissedPings > 3 {
					log.Printf("Removing remote %s\n", r.Host)
					p.RemoveHost(r)
				}
				continue
			}
			r.MissedPings = 0
		}
		p.mu.RLock()
		count := len(p.remotes)
		p.mu.RUnlock()
		connectedRemotes.Set(float64(count))
	}
}

// watchDiscovery adds every not-yet-known host announced by d.Watch, waiting
// one second before connecting to give the host time to start.
func (p *SyncedPool) watchDiscovery(d Discovery) {
	ch, err := d.Watch()
	if err != nil {
		log.Printf("Error discovering hosts: %v", err)
		return
	}
	for host := range ch {
		if p.IsKnown(host) {
			continue
		}
		go func(h string) {
			log.Printf("Discovered host %s, waiting for startup", h)
			time.Sleep(time.Second)
			if err := p.AddRemote(h); err != nil {
				log.Printf("Error adding remote %s: %v", h, err)
			}
		}(host)
	}
}

// acceptConnections serves incoming peer connections until the listener is
// closed.
func (p *SyncedPool) acceptConnections() {
	for {
		conn, err := p.listener.Accept()
		if err != nil {
			log.Printf("Error accepting connection: %v\n", err)
			if errors.Is(err, net.ErrClosed) {
				// A closed listener fails forever; stop instead of spinning.
				return
			}
			continue
		}
		log.Printf("Got connection from %s", conn.RemoteAddr())
		go p.handleConnection(conn)
	}
}

// snapshotRemotes returns a copy of the current remote list so callers can
// iterate, and perform slow network calls, without holding the lock.
func (p *SyncedPool) snapshotRemotes() []*RemoteHost {
	p.mu.RLock()
	defer p.mu.RUnlock()
	out := make([]*RemoteHost, len(p.remotes))
	copy(out, p.remotes)
	return out
}

// IsKnown reports whether host is this node itself or an already connected
// remote.
func (p *SyncedPool) IsKnown(host string) bool {
	if host == p.Hostname {
		return true
	}
	p.mu.RLock()
	defer p.mu.RUnlock()
	for _, r := range p.remotes {
		if r.Host == host {
			return true
		}
	}
	return false
}

// ExcludeKnown returns the entries of hosts that are neither this node nor an
// already connected remote, preserving their order.
func (p *SyncedPool) ExcludeKnown(hosts []string) []string {
	p.mu.RLock()
	known := make(map[string]struct{}, len(p.remotes)+1)
	for _, r := range p.remotes {
		known[r.Host] = struct{}{}
	}
	p.mu.RUnlock()
	known[p.Hostname] = struct{}{}

	ret := make([]string, 0, len(hosts))
	for _, h := range hosts {
		if _, ok := known[h]; !ok {
			ret = append(ret, h)
		}
	}
	return ret
}

// RemoveHost drops host from the remote list, forgets every cart mapped to its
// pool and closes its connection. It is a no-op when host is not in the list.
func (p *SyncedPool) RemoveHost(host *RemoteHost) {
	p.mu.Lock()
	removed := false
	for i, r := range p.remotes {
		if r == host {
			last := len(p.remotes) - 1
			copy(p.remotes[i:], p.remotes[i+1:])
			p.remotes[last] = nil // do not keep the removed host reachable
			p.remotes = p.remotes[:last]
			removed = true
			break
		}
	}
	remaining := len(p.remotes)
	p.mu.Unlock()
	if !removed {
		return
	}

	// RemoveHostMappedCarts takes the lock itself, so call it after unlocking.
	p.RemoveHostMappedCarts(host)
	connectedRemotes.Set(float64(remaining))

	// Closing the connection also ends the packet queue's reader goroutine.
	if host.Conn != nil {
		if err := host.Conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
			log.Printf("Error closing connection to %s: %v\n", host.Host, err)
		}
	}
}

// RemoveHostMappedCarts deletes every remoteIndex entry that points at host's
// pool.
func (p *SyncedPool) RemoveHostMappedCarts(host *RemoteHost) {
	p.mu.Lock()
	defer p.mu.Unlock()
	for id, r := range p.remoteIndex {
		if r == host.Pool {
			delete(p.remoteIndex, id)
		}
	}
}

// Prometheus metrics describing peer synchronisation.
var (
	negotiationCount = promauto.NewCounter(prometheus.CounterOpts{
		Name: "cart_remote_negotiation_total",
		Help: "The total number of remote negotiations",
	})
	grainSyncCount = promauto.NewCounter(prometheus.CounterOpts{
		Name: "cart_grain_sync_total",
		Help: "The total number of grain owner changes",
	})
	connectedRemotes = promauto.NewGauge(prometheus.GaugeOpts{
		Name: "cart_connected_remotes",
		Help: "The number of connected remotes",
	})
	remoteLookupCount = promauto.NewCounter(prometheus.CounterOpts{
		Name: "cart_remote_lookup_total",
		Help: "The total number of remote lookups",
	})
	packetQueue = promauto.NewGauge(prometheus.GaugeOpts{
		Name: "cart_packet_queue_size",
		Help: "The total number of packets in the queue",
	})
	packetsSent = promauto.NewCounter(prometheus.CounterOpts{
		Name: "cart_pool_packets_sent_total",
		Help: "The total number of packets sent",
	})
	packetsReceived = promauto.NewCounter(prometheus.CounterOpts{
		Name: "cart_pool_packets_received_total",
		Help: "The total number of packets received",
	})
)

// Message types exchanged between peers.
const (
	RemoteNegotiate    = uint16(3)
	RemoteGrainChanged = uint16(4)
	AckChange          = uint16(5)
	//AckError = uint16(6)
	Ping            = uint16(7)
	Pong            = uint16(8)
	GetCartIds      = uint16(9)
	CartIdsResponse = uint16(10)
)

// PacketWithData is a received packet together with its arrival time.
type PacketWithData struct {
	MessageType uint16
	Added       time.Time
	Data        []byte
}

// PacketQueue buffers the packets received on one connection so that callers
// can wait for a reply of a particular type with Expect.
type PacketQueue struct {
	mu         sync.RWMutex
	Packets    []PacketWithData
	connection net.Conn
}

// NewPacketQueue returns a queue for connection and starts a goroutine that
// reads packets from it. The goroutine closes the connection when it stops.
func NewPacketQueue(connection net.Conn) *PacketQueue {
	queue := &PacketQueue{
		Packets:    make([]PacketWithData, 0),
		connection: connection,
	}
	go queue.receiveLoop()
	return queue
}

// isTerminalReadError reports whether err means the connection can no longer
// deliver packets: end of stream, a closed socket, or any network error that
// is not a timeout.
func isTerminalReadError(err error) bool {
	if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, net.ErrClosed) {
		return true
	}
	var netErr net.Error
	return errors.As(err, &netErr) && !netErr.Timeout()
}

// receiveLoop appends every successfully received packet to the queue until
// the connection fails, then closes the connection.
func (q *PacketQueue) receiveLoop() {
	defer q.connection.Close()
	for {
		messageType, data, err := ReceivePacket(q.connection)
		ts := time.Now()
		if err != nil {
			log.Printf("Error receiving packet: %v\n", err)
			if isTerminalReadError(err) {
				return
			}
			// Nothing valid was received, so there is nothing to enqueue.
			continue
		}

		q.mu.Lock()
		// Drop everything queued ahead of the first packet that is younger
		// than one second; if no packet is that young the queue is kept.
		for i := range q.Packets {
			if time.Since(q.Packets[i].Added) < time.Second {
				if i > 0 {
					log.Printf("DEBUG: Requeueing %d packets\n", len(q.Packets)-i)
				}
				q.Packets = q.Packets[i:]
				break
			}
		}
		q.Packets = append(q.Packets, PacketWithData{
			MessageType: messageType,
			Added:       ts,
			Data:        data,
		})
		size := len(q.Packets)
		q.mu.Unlock()

		packetsReceived.Inc()
		// The gauge is shared by all queues; it reports the size of the queue
		// that received most recently.
		packetQueue.Set(float64(size))
	}
}

// Expect polls the queue every five milliseconds for a packet of messageType
// that arrived after the call started (with one millisecond of slack) and
// returns a copy of it. It returns an error once timeToWait has elapsed.
func (q *PacketQueue) Expect(messageType uint16, timeToWait time.Duration) (*PacketWithData, error) {
	start := time.Now().Add(-time.Millisecond)
	for {
		if time.Since(start) > timeToWait {
			return nil, fmt.Errorf("timeout waiting for message type %d", messageType)
		}
		q.mu.RLock()
		for _, packet := range q.Packets {
			if packet.MessageType == messageType && packet.Added.After(start) {
				q.mu.RUnlock()
				return &packet, nil
			}
		}
		q.mu.RUnlock()
		time.Sleep(time.Millisecond * 5)
	}
}

// handleConnection serves requests from one incoming peer connection until the
// connection ends or a read fails, then closes it.
func (p *SyncedPool) handleConnection(conn net.Conn) {
	defer conn.Close()
	var packet Packet
	for {
		if err := binary.Read(conn, binary.LittleEndian, &packet); err != nil {
			if !errors.Is(err, io.EOF) {
				log.Printf("Error in connection: %v\n", err)
			}
			// After a failed header read the stream position is unknown, so
			// the connection cannot be used any further.
			return
		}
		// if packet.Version != 1 {
		// 	log.Printf("Invalid version %d\n", packet.Version)
		// 	return
		// }
		switch packet.MessageType {
		case Ping:
			if err := SendPacket(conn, Pong, func(w io.Writer) error { return nil }); err != nil {
				log.Printf("Error sending pong: %v\n", err)
			}

		case RemoteNegotiate:
			negotiationCount.Inc()
			data := make([]byte, packet.DataLength)
			// ReadFull: a single Read may return fewer bytes than requested.
			if _, err := io.ReadFull(conn, data); err != nil {
				log.Printf("Error reading negotiate payload: %v\n", err)
				return
			}
			knownHosts := strings.Split(string(data), ";")
			log.Printf("Negotiated with remote, found %v hosts\n", knownHosts)

			remotes := p.snapshotRemotes()
			hostnames := make([]string, 0, len(remotes))
			for _, r := range remotes {
				hostnames = append(hostnames, r.Host)
			}
			err := SendPacket(conn, RemoteNegotiate, func(w io.Writer) error {
				_, err := w.Write([]byte(strings.Join(hostnames, ";")))
				return err
			})
			if err != nil {
				log.Printf("Error sending negotiate response: %v\n", err)
			}

			for _, h := range knownHosts {
				if err := p.AddRemote(h); err != nil {
					log.Printf("Error adding remote %s: %v\n", h, err)
				}
			}

		case RemoteGrainChanged:
			// remote grain changed
			grainSyncCount.Inc()
			idAndHost := make([]byte, packet.DataLength)
			if _, err := io.ReadFull(conn, idAndHost); err != nil {
				log.Printf("Error reading grain change payload: %v\n", err)
				return
			}
			idAndHostParts := strings.Split(string(idAndHost), ";")
			if len(idAndHostParts) != 2 {
				log.Printf("Invalid remote grain change message\n")
				continue
			}
			id, owner := idAndHostParts[0], idAndHostParts[1]

			found := false
			for _, r := range p.snapshotRemotes() {
				if r.Host != owner {
					continue
				}
				found = true
				log.Printf("Remote grain %s changed to %s\n", id, owner)
				p.mu.Lock()
				p.remoteIndex[ToCartId(id)] = r.Pool
				p.mu.Unlock()
			}
			if !found {
				// No ack is sent, so the sender's wait for AckChange times out.
				log.Printf("Remote host %s not found\n", owner)
				continue
			}
			err := SendPacket(conn, AckChange, func(w io.Writer) error {
				_, err := w.Write([]byte("ok"))
				return err
			})
			if err != nil {
				log.Printf("Error sending change ack: %v\n", err)
			}

		case GetCartIds:
			ids := make([]string, 0, len(p.local.grains))
			for id := range p.local.grains {
				ids = append(ids, id.String())
			}
			err := SendPacket(conn, CartIdsResponse, func(w io.Writer) error {
				_, err := w.Write([]byte(strings.Join(ids, ";")))
				return err
			})
			if err != nil {
				log.Printf("Error sending cart ids: %v\n", err)
			}
		}
	}
}

// splitNonEmpty splits s on ";" and drops empty entries, so an empty payload
// yields an empty list rather than a list holding one empty string.
func splitNonEmpty(s string) []string {
	parts := strings.Split(s, ";")
	out := make([]string, 0, len(parts))
	for _, part := range parts {
		if part != "" {
			out = append(out, part)
		}
	}
	return out
}

// Negotiate sends knownHosts to the remote and returns the host names it
// answers with, waiting up to one second for the reply.
func (h *RemoteHost) Negotiate(knownHosts []string) ([]string, error) {
	err := SendPacket(h.connection, RemoteNegotiate, func(w io.Writer) error {
		_, err := w.Write([]byte(strings.Join(knownHosts, ";")))
		return err
	})
	if err != nil {
		return nil, err
	}
	packet, err := h.Expect(RemoteNegotiate, time.Second)
	if err != nil {
		return nil, err
	}
	return splitNonEmpty(string(packet.Data)), nil
}

// GetCartMappings asks the remote for its cart ids, waiting up to three
// seconds for the reply. It logs and returns nil on any failure.
func (h *RemoteHost) GetCartMappings() []CartId {
	err := SendPacket(h.connection, GetCartIds, func(w io.Writer) error { return nil })
	if err != nil {
		log.Printf("Error getting mappings: %v\n", err)
		return nil
	}
	packet, err := h.Expect(CartIdsResponse, time.Second*3)
	if err != nil {
		log.Printf("Error getting mappings: %v\n", err)
		return nil
	}
	parts := splitNonEmpty(string(packet.Data))
	ids := make([]CartId, 0, len(parts))
	for _, part := range parts {
		ids = append(ids, ToCartId(part))
	}
	return ids
}

// Negotiate sends knownHosts to every remote and returns the de-duplicated
// union of the host names they report. It stops at the first remote that
// fails.
func (p *SyncedPool) Negotiate(knownHosts []string) ([]string, error) {
	allHosts := make(map[string]struct{})
	for _, r := range p.snapshotRemotes() {
		hosts, err := r.Negotiate(knownHosts)
		if err != nil {
			return nil, err
		}
		for _, h := range hosts {
			allHosts[h] = struct{}{}
		}
	}
	ret := make([]string, 0, len(allHosts))
	for h := range allHosts {
		ret = append(ret, h)
	}
	return ret, nil
}

// ConfirmChange tells the remote that cart id is now owned by host and waits
// up to one second for its acknowledgement.
func (h *RemoteHost) ConfirmChange(id CartId, host string) error {
	err := SendPacket(h.connection, RemoteGrainChanged, func(w io.Writer) error {
		_, err := w.Write([]byte(fmt.Sprintf("%s;%s", id, host)))
		return err
	})
	if err != nil {
		return err
	}
	_, err = h.Expect(AckChange, time.Second)
	return err
}

// OwnerChanged announces to every remote that cart id is now owned by host.
// It stops at, and returns, the first error.
func (p *SyncedPool) OwnerChanged(id CartId, host string) error {
	for _, r := range p.snapshotRemotes() {
		if err := r.ConfirmChange(id, host); err != nil {
			log.Printf("Error confirming change: %v from %s\n", err, r.Host)
			return err
		}
	}
	return nil
}

// DoPing sends a Ping to host and waits up to one second for the Pong.
func DoPing(host *RemoteHost) error {
	err := SendPacket(host.connection, Ping, func(w io.Writer) error { return nil })
	if err != nil {
		return err
	}
	_, err = host.Expect(Pong, time.Second)
	return err
}

// addRemoteHost registers remote under address unless a remote with that
// address already exists, in which case it returns an error. After
// registering it negotiates host lists and fetches the remote's cart ids in
// the background.
func (p *SyncedPool) addRemoteHost(address string, remote *RemoteHost) error {
	// The ping is informational only: a failure is logged but does not prevent
	// the remote from being added. It runs before the lock is taken because it
	// can block for up to a second.
	if err := DoPing(remote); err != nil {
		log.Printf("Error pinging remote %s: %v\n", address, err)
	}

	// Check and append under one lock so two goroutines cannot both add the
	// same address.
	p.mu.Lock()
	known := make([]string, 0, len(p.remotes))
	for _, r := range p.remotes {
		if r.Host == address {
			p.mu.Unlock()
			return fmt.Errorf("remote %s already exists", address)
		}
		known = append(known, r.Host)
	}
	p.remotes = append(p.remotes, remote)
	count := len(p.remotes)
	p.mu.Unlock()

	connectedRemotes.Set(float64(count))
	log.Printf("Added remote %s\n", remote.Host)

	go func() {
		if _, err := p.Negotiate(known); err != nil {
			log.Printf("Error negotiating with remotes: %v\n", err)
		}
		ids := remote.GetCartMappings()
		p.mu.Lock()
		for _, id := range ids {
			p.remoteIndex[id] = remote.Pool
		}
		p.mu.Unlock()
	}()
	return nil
}

// AddRemote connects to address on port 1338 and registers it as a remote. It
// does nothing when address is empty or already known.
func (p *SyncedPool) AddRemote(address string) error {
	if address == "" || p.IsKnown(address) {
		return nil
	}
	connection, err := net.Dial("tcp", net.JoinHostPort(address, "1338"))
	if err != nil {
		return fmt.Errorf("connecting to remote %s: %w", address, err)
	}
	remote := &RemoteHost{
		Conn:        connection,
		PacketQueue: NewPacketQueue(connection),
		Pool:        NewRemoteGrainPool(address),
		Host:        address,
	}
	if err := p.addRemoteHost(address, remote); err != nil {
		// The remote was rejected: release the socket, which also stops the
		// packet queue's reader. The close error adds nothing to err.
		_ = connection.Close()
		return err
	}
	return nil
}

// remoteOwner returns the remote pool recorded as owner of id, if any.
func (p *SyncedPool) remoteOwner(id CartId) (*RemoteGrainPool, bool) {
	p.mu.RLock()
	defer p.mu.RUnlock()
	pool, ok := p.remoteIndex[id]
	return pool, ok
}

// Process applies messages to cart id. A cart that exists locally is handled
// locally; a cart mapped to a remote is forwarded there; otherwise this node
// announces itself as owner to all remotes and then handles the cart locally.
func (p *SyncedPool) Process(id CartId, messages ...Message) ([]byte, error) {
	// check if local grain exists
	if _, ok := p.local.grains[id]; !ok {
		// check if remote grain exists
		if remoteGrain, ok := p.remoteOwner(id); ok {
			remoteLookupCount.Inc()
			return remoteGrain.Process(id, messages...)
		}
		if err := p.OwnerChanged(id, p.Hostname); err != nil {
			return nil, err
		}
	}
	return p.local.Process(id, messages...)
}

// Get returns cart id, using the same local / remote / claim-ownership
// routing as Process.
func (p *SyncedPool) Get(id CartId) ([]byte, error) {
	// check if local grain exists
	if _, ok := p.local.grains[id]; !ok {
		// check if remote grain exists
		if remoteGrain, ok := p.remoteOwner(id); ok {
			remoteLookupCount.Inc()
			return remoteGrain.Get(id)
		}
		if err := p.OwnerChanged(id, p.Hostname); err != nil {
			return nil, err
		}
	}
	return p.local.Get(id)
}