Files
go-cart-actor/synced-pool.go
matst80 411b91252b
All checks were successful
Build and Publish / BuildAndDeploy (push) Successful in 1m47s
implement queue
2024-11-09 20:54:23 +01:00

480 lines
11 KiB
Go

package main
import (
"context"
"encoding/binary"
"fmt"
"io"
"log"
"net"
"strings"
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
)
type Discovery interface {
Discover() ([]string, error)
}
type K8sDiscovery struct {
ctx context.Context
client *kubernetes.Clientset
}
func (k *K8sDiscovery) Discover() ([]string, error) {
return k.DiscoverInNamespace("")
}
func (k *K8sDiscovery) DiscoverInNamespace(namespace string) ([]string, error) {
pods, err := k.client.CoreV1().Pods(namespace).List(k.ctx, metav1.ListOptions{
LabelSelector: "actor-pool=cart",
})
if err != nil {
return nil, err
}
hosts := make([]string, 0, len(pods.Items))
for _, pod := range pods.Items {
hosts = append(hosts, pod.Status.PodIP)
}
return hosts, nil
}
func NewK8sDiscovery(client *kubernetes.Clientset) *K8sDiscovery {
return &K8sDiscovery{
ctx: context.Background(),
client: client,
}
}
type Quorum interface {
Negotiate(knownHosts []string) ([]string, error)
OwnerChanged(CartId, host string) error
}
type RemoteHost struct {
Host string
MissedPings int
Pool *RemoteGrainPool
connection net.Conn
queue *PacketQueue
}
type SyncedPool struct {
Discovery Discovery
listener net.Listener
Hostname string
local *GrainLocalPool
remotes []*RemoteHost
remoteIndex map[CartId]*RemoteGrainPool
}
func NewSyncedPool(local *GrainLocalPool, hostname string, d Discovery) (*SyncedPool, error) {
listen := fmt.Sprintf("%s:1338", hostname)
l, err := net.Listen("tcp", listen)
log.Printf("Listening on %s", listen)
if err != nil {
return nil, err
}
pool := &SyncedPool{
Discovery: d,
Hostname: hostname,
local: local,
listener: l,
remotes: make([]*RemoteHost, 0),
remoteIndex: make(map[CartId]*RemoteGrainPool),
}
pingTimer := time.NewTicker(time.Second)
go func() {
for {
<-pingTimer.C
for i, r := range pool.remotes {
err := DoPing(r)
if err != nil {
r.MissedPings++
log.Printf("Error pinging remote %s: %v\n, missed pings: %d", r.Host, err, r.MissedPings)
if r.MissedPings > 3 {
log.Printf("Removing remote %s\n", r.Host)
pool.remotes = append(pool.remotes[:i], pool.remotes[i+1:]...)
}
} else {
r.MissedPings = 0
}
}
}
}()
if d != nil {
discoveryTimer := time.NewTicker(time.Second * 5)
go func() {
<-discoveryTimer.C
hosts, err := d.Discover()
if err != nil {
log.Printf("Error discovering hosts: %v\n", err)
return
}
for _, h := range hosts {
if h == hostname {
continue
}
log.Printf("Discovered host %s\n", h)
err := pool.AddRemote(h)
if err != nil {
log.Printf("Error adding remote %s: %v\n", h, err)
}
}
otherHosts, err := pool.Negotiate(hosts)
if err != nil {
log.Printf("Error negotiating: %v\n", err)
}
for _, h := range otherHosts {
if h == hostname {
continue
}
found := false
for _, r := range pool.remotes {
if r.Host == h {
found = true
}
}
if found {
continue
}
err := pool.AddRemote(h)
if err != nil {
log.Printf("Error adding undiscovered remote %s: %v\n", h, err)
}
}
}()
}
go func() {
for {
conn, err := l.Accept()
if err != nil {
log.Printf("Error accepting connection: %v\n", err)
continue
}
log.Printf("Got connection from %s", conn.RemoteAddr().String())
go pool.handleConnection(conn)
}
}()
return pool, nil
}
var (
negotiationCount = promauto.NewCounter(prometheus.CounterOpts{
Name: "cart_remote_negotiation_total",
Help: "The total number of remote negotiations",
})
grainSyncCount = promauto.NewCounter(prometheus.CounterOpts{
Name: "cart_grain_sync_total",
Help: "The total number of grain owner changes",
})
connectedRemotes = promauto.NewGauge(prometheus.GaugeOpts{
Name: "cart_connected_remotes",
Help: "The number of connected remotes",
})
)
const (
RemoteNegotiate = uint16(3)
RemoteGrainChanged = uint16(4)
AckChange = uint16(5)
//AckError = uint16(6)
Ping = uint16(7)
Pong = uint16(8)
)
type PacketWithData struct {
MessageType uint16
Data []byte
}
type PacketQueue struct {
mu sync.Mutex
Packets []PacketWithData
connection net.Conn
}
func NewPacketQueue(connection net.Conn) *PacketQueue {
queue := &PacketQueue{
Packets: make([]PacketWithData, 0),
connection: connection,
}
go func() {
for {
messageType, data, err := ReceivePacket(queue.connection)
if err != nil {
log.Printf("Error receiving packet: %v\n", err)
return
}
queue.mu.Lock()
queue.Packets = append(queue.Packets, PacketWithData{
MessageType: messageType,
Data: data,
})
queue.mu.Unlock()
}
}()
return queue
}
func (p *PacketQueue) Expect(messageType uint16, timeToWait time.Duration) (PacketWithData, error) {
start := time.Now()
for {
if time.Since(start) > timeToWait {
return PacketWithData{}, fmt.Errorf("timeout waiting for message type %d", messageType)
}
for i, packet := range p.Packets {
if packet.MessageType == messageType {
p.mu.Lock()
p.Packets = append(p.Packets[:i], p.Packets[i+1:]...)
p.mu.Unlock()
return packet, nil
}
}
time.Sleep(time.Millisecond * 50)
}
}
func (p *SyncedPool) handleConnection(conn net.Conn) {
defer conn.Close()
var packet Packet
for {
err := binary.Read(conn, binary.LittleEndian, &packet)
if err != nil {
if err == io.EOF {
break
}
log.Printf("Error in connection: %v\n", err)
}
// if packet.Version != 1 {
// log.Printf("Invalid version %d\n", packet.Version)
// return
// }
switch packet.MessageType {
case Ping:
err = SendPacket(conn, Pong, func(w io.Writer) error {
return nil
})
if err != nil {
log.Printf("Error sending pong: %v\n", err)
}
case RemoteNegotiate:
negotiationCount.Inc()
data := make([]byte, packet.DataLength)
conn.Read(data)
knownHosts := strings.Split(string(data), ";")
log.Printf("Negotiated with remote, found %v hosts\n", knownHosts)
for _, h := range knownHosts {
err = p.AddRemoteWithConnection(h, conn)
if err != nil {
log.Printf("Error adding remote %s: %v\n", h, err)
}
}
SendPacket(conn, RemoteNegotiate, func(w io.Writer) error {
hostnames := make([]string, 0, len(p.remotes))
for _, r := range p.remotes {
hostnames = append(hostnames, r.Host)
}
w.Write([]byte(strings.Join(hostnames, ";")))
return nil
})
case RemoteGrainChanged:
// remote grain changed
grainSyncCount.Inc()
log.Printf("Remote grain changed\n")
for err == nil {
idAndHost := make([]byte, packet.DataLength)
_, err = conn.Read(idAndHost)
log.Printf("Remote grain %s changed\n", idAndHost)
if err != nil {
break
}
idAndHostParts := strings.Split(string(idAndHost), ";")
if len(idAndHostParts) != 2 {
log.Printf("Invalid remote grain change message\n")
break
}
found := false
for _, r := range p.remotes {
if r.Host == string(idAndHostParts[1]) {
found = true
log.Printf("Remote grain %s changed to %s\n", idAndHostParts[0], idAndHostParts[1])
p.remoteIndex[ToCartId(idAndHostParts[0])] = r.Pool
}
}
if !found {
log.Printf("Remote host %s not found\n", idAndHostParts[1])
log.Printf("Remotes %v\n", p.remotes)
} else {
err = SendPacket(conn, AckChange, func(w io.Writer) error {
_, err := w.Write([]byte("ok"))
return err
})
}
}
}
}
}
func (h *RemoteHost) Negotiate(knownHosts []string) ([]string, error) {
SendPacket(h.connection, RemoteNegotiate, func(w io.Writer) error {
w.Write([]byte(strings.Join(knownHosts, ";")))
return nil
})
packet, err := h.queue.Expect(RemoteNegotiate, time.Second)
if err != nil {
return nil, err
}
return strings.Split(string(packet.Data), ";"), nil
}
func (p *SyncedPool) Negotiate(knownHosts []string) ([]string, error) {
allHosts := make(map[string]struct{}, 0)
for _, r := range p.remotes {
hosts, err := r.Negotiate(knownHosts)
if err != nil {
return nil, err
}
for _, h := range hosts {
allHosts[h] = struct{}{}
}
}
ret := make([]string, 0, len(allHosts))
for h := range allHosts {
ret = append(ret, h)
}
return ret, nil
}
func (r *RemoteHost) ConfirmChange(id CartId, host string) error {
SendPacket(r.connection, RemoteGrainChanged, func(w io.Writer) error {
_, err := w.Write([]byte(fmt.Sprintf("%s;%s", id, host)))
return err
})
_, err := r.queue.Expect(AckChange, time.Second)
if err != nil {
return err
}
return nil
}
func (p *SyncedPool) OwnerChanged(id CartId, host string) error {
for _, r := range p.remotes {
err := r.ConfirmChange(id, host)
if err != nil {
log.Printf("Error confirming change: %v\n", err)
return err
}
}
return nil
}
func (p *SyncedPool) AddRemoteWithConnection(address string, connection net.Conn) error {
pool := NewRemoteGrainPool(fmt.Sprintf(address, 1337))
remote := RemoteHost{
connection: connection,
queue: NewPacketQueue(connection),
Pool: pool,
Host: address,
}
return p.addRemoteHost(address, &remote)
}
func DoPing(host *RemoteHost) error {
SendPacket(host.connection, Ping, func(w io.Writer) error {
return nil
})
_, err := host.queue.Expect(Pong, time.Second)
if err != nil {
return err
}
return nil
}
func (p *SyncedPool) addRemoteHost(address string, remote *RemoteHost) error {
for _, r := range p.remotes {
if r.Host == address {
log.Printf("Remote %s already exists\n", address)
return fmt.Errorf("remote %s already exists", address)
}
}
err := DoPing(remote)
if err != nil {
log.Printf("Error pinging remote %s: %v\n", address, err)
}
p.remotes = append(p.remotes, remote)
connectedRemotes.Set(float64(len(p.remotes)))
log.Printf("Added remote %s\n", remote.Host)
return nil
}
func (p *SyncedPool) AddRemote(address string) error {
connection, err := net.Dial("tcp", fmt.Sprintf("%s:1338", address))
if err != nil {
log.Printf("Error connecting to remote %s: %v\n", address, err)
return err
}
pool := NewRemoteGrainPool(fmt.Sprintf(address, 1337))
remote := RemoteHost{
connection: connection,
queue: NewPacketQueue(connection),
Pool: pool,
Host: address,
}
return p.addRemoteHost(address, &remote)
}
func (p *SyncedPool) Process(id CartId, messages ...Message) ([]byte, error) {
// check if local grain exists
_, ok := p.local.grains[id]
if !ok {
// check if remote grain exists
remoteGrain, ok := p.remoteIndex[id]
if ok {
return remoteGrain.Process(id, messages...)
}
}
err := p.OwnerChanged(id, p.Hostname)
if err != nil {
return nil, err
}
return p.local.Process(id, messages...)
}
func (p *SyncedPool) Get(id CartId) ([]byte, error) {
// check if local grain exists
_, ok := p.local.grains[id]
if !ok {
// check if remote grain exists
remoteGrain, ok := p.remoteIndex[id]
if ok {
return remoteGrain.Get(id)
}
}
return p.local.Get(id)
}