package main import ( "context" "encoding/binary" "fmt" "io" "log" "net" "strings" "time" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes" ) type Discovery interface { Discover() ([]string, error) } type K8sDiscovery struct { ctx context.Context client *kubernetes.Clientset } func (k *K8sDiscovery) Discover() ([]string, error) { return k.DiscoverInNamespace("") } func (k *K8sDiscovery) DiscoverInNamespace(namespace string) ([]string, error) { pods, err := k.client.CoreV1().Pods(namespace).List(k.ctx, metav1.ListOptions{ LabelSelector: "actor-pool=cart", }) if err != nil { return nil, err } hosts := make([]string, 0, len(pods.Items)) for _, pod := range pods.Items { hosts = append(hosts, pod.Status.PodIP) } return hosts, nil } func NewK8sDiscovery(client *kubernetes.Clientset) *K8sDiscovery { return &K8sDiscovery{ ctx: context.Background(), client: client, } } type Quorum interface { Negotiate(knownHosts []string) ([]string, error) OwnerChanged(CartId, host string) error } type RemoteHost struct { Host string Pool *RemoteGrainPool connection net.Conn } type SyncedPool struct { Discovery Discovery listener net.Listener Hostname string local *GrainLocalPool remotes []RemoteHost remoteIndex map[CartId]*RemoteGrainPool } func NewSyncedPool(local *GrainLocalPool, hostname string, d Discovery) (*SyncedPool, error) { listen := fmt.Sprintf("%s:1338", hostname) l, err := net.Listen("tcp", listen) if err != nil { return nil, err } pool := &SyncedPool{ Discovery: d, Hostname: hostname, local: local, listener: l, remotes: make([]RemoteHost, 0), remoteIndex: make(map[CartId]*RemoteGrainPool), } if d != nil { discoveryTimer := time.NewTicker(time.Second * 5) go func() { <-discoveryTimer.C hosts, err := d.Discover() if err != nil { log.Printf("Error discovering hosts: %v\n", err) return } for _, h := range hosts { if h == hostname { continue } log.Printf("Discovered host %s\n", h) err := pool.AddRemote(h) if err != nil { log.Printf("Error adding remote %s: %v\n", h, err) } } }() } go func() { for { conn, err := l.Accept() if err != nil { log.Printf("Error accepting connection: %v\n", err) continue } log.Printf("Got connection from %s", conn.RemoteAddr().String()) go pool.handleConnection(conn) } }() return pool, nil } const ( RemoteNegotiate = uint16(3) RemoteGrainChanged = uint16(4) AckChange = uint16(5) AckError = uint16(6) ) func (p *SyncedPool) handleConnection(conn net.Conn) { defer conn.Close() var packet Packet for { err := binary.Read(conn, binary.LittleEndian, &packet) if err != nil { if err == io.EOF { break } log.Printf("Error in connection: %v\n", err) } // if packet.Version != 1 { // log.Printf("Invalid version %d\n", packet.Version) // return // } switch packet.MessageType { case RemoteNegotiate: data := make([]byte, packet.DataLength) conn.Read(data) knownHosts := strings.Split(string(data), ";") log.Printf("Negotiated with remote, found %v hosts\n", knownHosts) for _, h := range knownHosts { err = p.AddRemote(h) if err != nil { log.Printf("Error adding remote %s: %v\n", h, err) } } SendPacket(conn, RemoteNegotiate, func(w io.Writer) error { hostnames := make([]string, 0, len(p.remotes)) for _, r := range p.remotes { hostnames = append(hostnames, r.Host) } w.Write([]byte(strings.Join(hostnames, ";"))) return nil }) case RemoteGrainChanged: // remote grain changed log.Printf("Remote grain changed\n") for err == nil { idAndHost := make([]byte, packet.DataLength) _, err = conn.Read(idAndHost) log.Printf("Remote grain %s changed\n", idAndHost) if err != nil { break } idAndHostParts := strings.Split(string(idAndHost), ";") if len(idAndHostParts) != 2 { log.Printf("Invalid remote grain change message\n") break } found := false for _, r := range p.remotes { if r.Host == string(idAndHostParts[1]) { found = true log.Printf("Remote grain %s changed to %s\n", idAndHostParts[0], idAndHostParts[1]) p.remoteIndex[ToCartId(idAndHostParts[0])] = r.Pool } } if !found { log.Printf("Remote host %s not found\n", idAndHostParts[1]) SendPacket(conn, AckError, func(w io.Writer) error { w.Write([]byte("remote host not found")) return nil }) } else { SendPacket(conn, AckChange, func(w io.Writer) error { w.Write([]byte("ok")) return nil }) } } } } } func (h *RemoteHost) Negotiate(knownHosts []string) ([]string, error) { SendPacket(h.connection, RemoteNegotiate, func(w io.Writer) error { w.Write([]byte(strings.Join(knownHosts, ";"))) return nil }) t, data, err := ReceivePacket(h.connection) if err != nil { return nil, err } if t != RemoteNegotiate { return nil, fmt.Errorf("unexpected message type %d", t) } return strings.Split(string(data), ";"), nil } func (p *SyncedPool) Negotiate(knownHosts []string) ([]string, error) { allHosts := make(map[string]struct{}, 0) for _, r := range p.remotes { hosts, err := r.Negotiate(knownHosts) if err != nil { return nil, err } for _, h := range hosts { allHosts[h] = struct{}{} } } ret := make([]string, 0, len(allHosts)) for h := range allHosts { ret = append(ret, h) } return ret, nil } func (r *RemoteHost) ConfirmChange(id CartId, host string) error { SendPacket(r.connection, RemoteGrainChanged, func(w io.Writer) error { w.Write([]byte(fmt.Sprintf("%s;%s", id, host))) return nil }) t, _, err := ReceivePacket(r.connection) if err != nil { return err } if t != AckChange { return fmt.Errorf("unexpected message type %d", t) } return nil } func (p *SyncedPool) OwnerChanged(id CartId, host string) error { for _, r := range p.remotes { err := r.ConfirmChange(id, host) if err != nil { log.Printf("Error confirming change: %v\n", err) return err } } return nil } func (p *SyncedPool) AddRemote(address string) error { for _, r := range p.remotes { if r.Host == address { log.Printf("Remote %s already exists\n", address) return fmt.Errorf("remote %s already exists", address) } } connection, err := net.Dial("tcp", fmt.Sprintf("%s:1338", address)) if err != nil { return err } pool := NewRemoteGrainPool(fmt.Sprintf(address, 1337)) remote := RemoteHost{ connection: connection, Pool: pool, Host: address, } p.remotes = append(p.remotes, remote) log.Printf("Added remote %s\n", remote.Host) return nil } func (p *SyncedPool) Process(id CartId, messages ...Message) ([]byte, error) { // check if local grain exists _, ok := p.local.grains[id] if !ok { // check if remote grain exists remoteGrain, ok := p.remoteIndex[id] if ok { return remoteGrain.Process(id, messages...) } } err := p.OwnerChanged(id, p.Hostname) if err != nil { return nil, err } return p.local.Process(id, messages...) } func (p *SyncedPool) Get(id CartId) ([]byte, error) { // check if local grain exists _, ok := p.local.grains[id] if !ok { // check if remote grain exists remoteGrain, ok := p.remoteIndex[id] if ok { return remoteGrain.Get(id) } } return p.local.Get(id) }