Files
go-cart-actor/synced-pool.go
matst80 a7c5332db0
All checks were successful
Build and Publish / BuildAndDeploy (push) Successful in 1m50s
update
2024-11-09 13:57:09 +01:00

256 lines
5.7 KiB
Go

package main
import (
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"strings"
	"time"

	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/client-go/kubernetes"
)
// Discovery yields the candidate peer hosts of the cart actor pool.
type Discovery interface {
	// Discover returns the host names currently reported by the backend,
	// or an error if the backend could not be queried.
	Discover() ([]string, error)
}
// K8sDiscovery implements Discovery by listing pods through the Kubernetes API.
type K8sDiscovery struct {
	// ctx is passed to every API call made by this discovery.
	// NOTE(review): storing a context in a struct is discouraged in Go;
	// consider taking it as a parameter instead.
	ctx context.Context
	// client is the Kubernetes clientset used to list pods.
	client *kubernetes.Clientset
}
// Discover lists cart actor pods in every namespace by delegating to
// DiscoverInNamespace with an empty namespace.
func (k *K8sDiscovery) Discover() ([]string, error) {
	const allNamespaces = ""
	return k.DiscoverInNamespace(allNamespaces)
}
// DiscoverInNamespace lists the pods labelled actor-pool=cart in namespace
// (client-go treats "" as all namespaces) and returns their names, logging
// each one. The API error, if any, is returned unchanged.
func (k *K8sDiscovery) DiscoverInNamespace(namespace string) ([]string, error) {
	opts := metav1.ListOptions{LabelSelector: "actor-pool=cart"}
	list, err := k.client.CoreV1().Pods(namespace).List(k.ctx, opts)
	if err != nil {
		return nil, err
	}
	names := make([]string, 0, len(list.Items))
	for i := range list.Items {
		name := list.Items[i].Name
		names = append(names, name)
		log.Printf("Found pod %s\n", name)
	}
	return names, nil
}
// NewK8sDiscovery returns a K8sDiscovery that queries through client and uses
// context.Background() for its API calls.
func NewK8sDiscovery(client *kubernetes.Clientset) *K8sDiscovery {
	d := &K8sDiscovery{client: client}
	d.ctx = context.Background()
	return d
}
// Quorum is the peer-coordination contract of a pool.
type Quorum interface {
	// Negotiate exchanges knownHosts with peers and returns the hosts they report.
	Negotiate(knownHosts []string) ([]string, error)
	// ListChanged notifies peers about the given cart ids.
	ListChanged(ids []CartId) error
}
// RemoteHost is one peer known to a SyncedPool.
type RemoteHost struct {
	// Host is the peer's address as it was passed to AddRemote.
	Host string
	// Pool forwards grain operations to the peer.
	Pool *RemoteGrainPool
	// connection is the control connection dialled to the peer's port 1338.
	connection net.Conn
}
// SyncedPool combines a local grain pool with a set of remote peers and
// routes each cart to whichever side holds it.
type SyncedPool struct {
	// Discovery, when non-nil, supplies candidate peer hosts.
	Discovery Discovery
	// listener accepts inbound peer control connections.
	listener net.Listener
	// Hostname is the host this pool listens on.
	Hostname string
	// local holds the grains owned by this process.
	local *GrainLocalPool
	// remotes are the peers added via AddRemote.
	remotes []RemoteHost
	// remoteIndex maps a cart id to the remote pool that owns it.
	remoteIndex map[CartId]*RemoteGrainPool
}
// NewSyncedPool creates a SyncedPool that fronts local and listens for peer
// control connections on port 1338 of hostname. The listen error, if any, is
// returned unchanged.
//
// It starts two background goroutines:
//   - when d is non-nil, a discovery loop that calls d.Discover every five
//     seconds (the first call happens after five seconds) and adds each newly
//     seen host as a remote;
//   - an accept loop that serves every inbound connection with
//     handleConnection.
//
// NOTE(review): neither goroutine can be stopped by the caller, and both
// reach AddRemote, which mutates pool.remotes without synchronization.
func NewSyncedPool(local *GrainLocalPool, hostname string, d Discovery) (*SyncedPool, error) {
	l, err := net.Listen("tcp", net.JoinHostPort(hostname, "1338"))
	if err != nil {
		return nil, err
	}
	pool := &SyncedPool{
		Discovery:   d,
		Hostname:    hostname,
		local:       local,
		listener:    l,
		remotes:     make([]RemoteHost, 0),
		remoteIndex: make(map[CartId]*RemoteGrainPool),
	}
	if d != nil {
		go pool.runDiscovery(d, 5*time.Second)
	}
	go pool.acceptPeers()
	return pool, nil
}

// runDiscovery calls d.Discover on every tick of interval, for the life of
// the process. Each returned host that is non-empty, is not this pool's own
// Hostname, and is not already a remote is passed to AddRemote. A failed
// Discover is logged and retried on the next tick instead of ending discovery.
func (p *SyncedPool) runDiscovery(d Discovery, interval time.Duration) {
	ticker := time.NewTicker(interval)
	for range ticker.C {
		hosts, err := d.Discover()
		if err != nil {
			log.Printf("Error discovering hosts: %v\n", err)
			continue
		}
		for _, h := range hosts {
			// Discovery re-reports known peers on every tick and presumably
			// lists this pod itself too; skip those instead of re-dialling.
			if h == "" || h == p.Hostname || p.hasRemote(h) {
				continue
			}
			if err := p.AddRemote(h); err != nil {
				log.Printf("Error adding remote %s: %v\n", h, err)
			}
		}
	}
}

// acceptPeers accepts inbound peer connections on p.listener and serves each
// one on its own goroutine. It returns once the listener has been closed;
// any other Accept error is logged and accepting continues.
func (p *SyncedPool) acceptPeers() {
	for {
		conn, err := p.listener.Accept()
		if err != nil {
			// A closed listener fails every Accept; looping would spin forever.
			if errors.Is(err, net.ErrClosed) {
				return
			}
			log.Printf("Error accepting connection: %v\n", err)
			continue
		}
		log.Printf("Got connection from %s", conn.RemoteAddr().String())
		go p.handleConnection(conn)
	}
}

// hasRemote reports whether host is already registered in p.remotes.
func (p *SyncedPool) hasRemote(host string) bool {
	for _, r := range p.remotes {
		if r.Host == host {
			return true
		}
	}
	return false
}
const (
RemoteNegotiate = uint16(3)
RemoteGrainChanged = uint16(4)
)
// handleConnection serves one inbound peer connection until the peer closes
// it or a read fails, then closes conn. Each message is a Packet header,
// decoded with binary.Read in little-endian order, followed by exactly
// packet.DataLength payload bytes. The payload is always read in full, so an
// unrecognised message type cannot desynchronise the stream.
//
//   - RemoteNegotiate: the payload is a ";"-separated host list. Each entry
//     that is non-empty and is not this pool's own Hostname is passed to
//     AddRemote. The pool then replies with the hosts of all its remotes.
//   - RemoteGrainChanged: the payload is logged in 16-byte chunks. This
//     presumably means one CartId per chunk; TODO confirm against the sender,
//     since ListChanged is not implemented yet.
//
// NOTE(review): DataLength comes from the peer and is allocated as-is. An
// upper bound may be warranted if peers are not trusted.
func (p *SyncedPool) handleConnection(conn net.Conn) {
	defer conn.Close()
	var packet Packet
	for {
		err := binary.Read(conn, binary.LittleEndian, &packet)
		if err != nil {
			// A clean EOF means the peer hung up. Any other error leaves the
			// stream in an unknown state, so stop rather than re-read it.
			if err != io.EOF {
				log.Printf("Error in connection: %v\n", err)
			}
			return
		}
		// if packet.Version != 1 {
		// 	log.Printf("Invalid version %d\n", packet.Version)
		// 	return
		// }
		data := make([]byte, packet.DataLength)
		// conn.Read may return fewer bytes than requested; ReadFull does not.
		if _, err := io.ReadFull(conn, data); err != nil {
			log.Printf("Error in connection: %v\n", err)
			return
		}
		switch packet.MessageType {
		case RemoteNegotiate:
			knownHosts := strings.Split(string(data), ";")
			log.Printf("Negotiated with remote, found %v hosts\n", knownHosts)
			for _, h := range knownHosts {
				// Split yields "" for an empty payload, and the peer may list
				// this pool. Dialling either one would connect back to ourselves.
				if h == "" || h == p.Hostname {
					continue
				}
				if err := p.AddRemote(h); err != nil {
					log.Printf("Error adding remote %s: %v\n", h, err)
				}
			}
			SendPacket(conn, RemoteNegotiate, func(w io.Writer) error {
				hostnames := make([]string, 0, len(p.remotes))
				for _, r := range p.remotes {
					hostnames = append(hostnames, r.Host)
				}
				_, err := w.Write([]byte(strings.Join(hostnames, ";")))
				return err
			})
		case RemoteGrainChanged:
			// remote grain changed
			log.Printf("Remote grain changed\n")
			for len(data) >= 16 {
				id := data[:16]
				data = data[16:]
				log.Printf("Remote grain %s changed\n", id)
			}
		}
	}
}
// Negotiate sends knownHosts, joined with ";", to the peer as a
// RemoteNegotiate packet and waits on the same connection for the reply.
// It returns the hosts named in the reply with empty entries removed, so an
// empty reply yields an empty slice rather than [""].
//
// Errors: a receive failure is returned wrapped with the peer's Host. A reply
// of any other message type is reported as an "unexpected message type" error.
func (h *RemoteHost) Negotiate(knownHosts []string) ([]string, error) {
	SendPacket(h.connection, RemoteNegotiate, func(w io.Writer) error {
		_, err := w.Write([]byte(strings.Join(knownHosts, ";")))
		return err
	})
	t, data, err := ReceivePacket(h.connection)
	if err != nil {
		return nil, fmt.Errorf("negotiating with %s: %w", h.Host, err)
	}
	if t != RemoteNegotiate {
		return nil, fmt.Errorf("unexpected message type %d", t)
	}
	// FieldsFunc drops the empty strings that Split would keep.
	hosts := strings.FieldsFunc(string(data), func(r rune) bool { return r == ';' })
	return hosts, nil
}
// Negotiate sends knownHosts to every remote in turn. It returns the
// de-duplicated union of their replies, in no particular order. The first
// remote that fails aborts the whole negotiation, and its error is returned.
func (p *SyncedPool) Negotiate(knownHosts []string) ([]string, error) {
	seen := map[string]struct{}{}
	remotes := p.remotes
	for i := range remotes {
		replies, err := remotes[i].Negotiate(knownHosts)
		if err != nil {
			return nil, err
		}
		for _, host := range replies {
			seen[host] = struct{}{}
		}
	}
	union := make([]string, 0, len(seen))
	for host := range seen {
		union = append(union, host)
	}
	return union, nil
}
// ListChanged is intended to tell peers that the given carts changed.
// It is currently a no-op that always succeeds.
// TODO: send a RemoteGrainChanged packet to each remote.
func (p *SyncedPool) ListChanged(ids []CartId) error {
	var err error
	return err
}
// AddRemote dials the peer's control port (address:1338) and registers it as
// a remote. The peer's grains are reached through a RemoteGrainPool at
// address:1337.
//
// It returns an error if address is empty, is already registered, or cannot
// be dialled within five seconds. On error, nothing is registered.
//
// NOTE(review): p.remotes is appended to without a lock, while AddRemote is
// reachable from several goroutines (discovery and connection handlers).
func (p *SyncedPool) AddRemote(address string) error {
	// An empty host would make the dial below target the local machine.
	if address == "" {
		return errors.New("remote address is empty")
	}
	for _, r := range p.remotes {
		if r.Host == address {
			return fmt.Errorf("remote %s already exists", address)
		}
	}
	connection, err := net.DialTimeout("tcp", net.JoinHostPort(address, "1338"), 5*time.Second)
	if err != nil {
		return fmt.Errorf("connecting to remote %s: %w", address, err)
	}
	pool := NewRemoteGrainPool(net.JoinHostPort(address, "1337"))
	p.remotes = append(p.remotes, RemoteHost{
		connection: connection,
		Pool:       pool,
		Host:       address,
	})
	log.Printf("Added remote %s\n", address)
	return nil
}
// Process applies messages to cart id and returns the result.
// A grain already present in the local pool is handled locally. Otherwise,
// if remoteIndex maps id to a remote pool, the call is forwarded there. In all
// remaining cases the local pool handles it.
func (p *SyncedPool) Process(id CartId, messages ...Message) ([]byte, error) {
	if _, isLocal := p.local.grains[id]; isLocal {
		return p.local.Process(id, messages...)
	}
	if remote, found := p.remoteIndex[id]; found {
		return remote.Process(id, messages...)
	}
	return p.local.Process(id, messages...)
}
// Get returns cart id, following the same routing as Process. A grain held
// locally is read locally; otherwise a remote pool listed in remoteIndex is
// asked; failing both, the local pool answers.
func (p *SyncedPool) Get(id CartId) ([]byte, error) {
	if _, isLocal := p.local.grains[id]; isLocal {
		return p.local.Get(id)
	}
	if remote, found := p.remoteIndex[id]; found {
		return remote.Get(id)
	}
	return p.local.Get(id)
}