521 lines
12 KiB
Go
521 lines
12 KiB
Go
package main
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
|
)
|
|
|
|
type Quorum interface {
|
|
Negotiate(knownHosts []string) ([]string, error)
|
|
OwnerChanged(CartId, host string) error
|
|
}
|
|
|
|
type RemoteHost struct {
|
|
Host string
|
|
MissedPings int
|
|
Pool *RemoteGrainPool
|
|
connection net.Conn
|
|
queue *PacketQueue
|
|
}
|
|
|
|
type SyncedPool struct {
|
|
mu sync.RWMutex
|
|
Discovery Discovery
|
|
listener net.Listener
|
|
Hostname string
|
|
local *GrainLocalPool
|
|
remotes []*RemoteHost
|
|
remoteIndex map[CartId]*RemoteGrainPool
|
|
}
|
|
|
|
func NewSyncedPool(local *GrainLocalPool, hostname string, d Discovery) (*SyncedPool, error) {
|
|
listen := fmt.Sprintf("%s:1338", hostname)
|
|
l, err := net.Listen("tcp", listen)
|
|
log.Printf("Listening on %s", listen)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
pool := &SyncedPool{
|
|
Discovery: d,
|
|
Hostname: hostname,
|
|
local: local,
|
|
listener: l,
|
|
remotes: make([]*RemoteHost, 0),
|
|
remoteIndex: make(map[CartId]*RemoteGrainPool),
|
|
}
|
|
|
|
go func() {
|
|
for {
|
|
for range time.Tick(time.Second * 2) {
|
|
for _, r := range pool.remotes {
|
|
err := DoPing(r)
|
|
if err != nil {
|
|
r.MissedPings++
|
|
log.Printf("Error pinging remote %s: %v\n, missed pings: %d", r.Host, err, r.MissedPings)
|
|
if r.MissedPings > 3 {
|
|
log.Printf("Removing remote %s\n", r.Host)
|
|
go pool.RemoveHost(r)
|
|
//pool.remotes = append(pool.remotes[:i], pool.remotes[i+1:]...)
|
|
|
|
}
|
|
} else {
|
|
r.MissedPings = 0
|
|
}
|
|
}
|
|
connectedRemotes.Set(float64(len(pool.remotes)))
|
|
}
|
|
}
|
|
}()
|
|
|
|
if d != nil {
|
|
go func() {
|
|
for range time.Tick(time.Second * 5) {
|
|
log.Printf("Looking for new nodes")
|
|
hosts, err := d.Discover()
|
|
if err != nil {
|
|
log.Printf("Error discovering hosts: %v", err)
|
|
}
|
|
for _, h := range hosts {
|
|
if h == hostname {
|
|
continue
|
|
}
|
|
log.Printf("Discovered host %s", h)
|
|
|
|
err := pool.AddRemote(h)
|
|
if err != nil {
|
|
log.Printf("Error adding remote %s: %v", h, err)
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
} else {
|
|
log.Printf("No discovery, waiting for remotes to connect")
|
|
}
|
|
go func() {
|
|
for {
|
|
conn, err := l.Accept()
|
|
if err != nil {
|
|
log.Printf("Error accepting connection: %v\n", err)
|
|
continue
|
|
}
|
|
log.Printf("Got connection from %s", conn.RemoteAddr().Network())
|
|
|
|
go pool.handleConnection(conn)
|
|
}
|
|
}()
|
|
return pool, nil
|
|
}
|
|
|
|
func (p *SyncedPool) RemoveHost(host *RemoteHost) {
|
|
|
|
for i, r := range p.remotes {
|
|
if r == host {
|
|
p.RemoveHostMappedCarts(r)
|
|
p.remotes = append(p.remotes[:i], p.remotes[i+1:]...)
|
|
connectedRemotes.Set(float64(len(p.remotes)))
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *SyncedPool) RemoveHostMappedCarts(host *RemoteHost) {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
for id, r := range p.remoteIndex {
|
|
if r == host.Pool {
|
|
delete(p.remoteIndex, id)
|
|
}
|
|
}
|
|
}
|
|
|
|
var (
|
|
negotiationCount = promauto.NewCounter(prometheus.CounterOpts{
|
|
Name: "cart_remote_negotiation_total",
|
|
Help: "The total number of remote negotiations",
|
|
})
|
|
grainSyncCount = promauto.NewCounter(prometheus.CounterOpts{
|
|
Name: "cart_grain_sync_total",
|
|
Help: "The total number of grain owner changes",
|
|
})
|
|
connectedRemotes = promauto.NewGauge(prometheus.GaugeOpts{
|
|
Name: "cart_connected_remotes",
|
|
Help: "The number of connected remotes",
|
|
})
|
|
remoteLookupCount = promauto.NewCounter(prometheus.CounterOpts{
|
|
Name: "cart_remote_lookup_total",
|
|
Help: "The total number of remote lookups",
|
|
})
|
|
packetQueue = promauto.NewGauge(prometheus.GaugeOpts{
|
|
Name: "cart_packet_queue_size",
|
|
Help: "The total number of packets in the queue",
|
|
})
|
|
)
|
|
|
|
// Message types exchanged between hosts. The numeric values are part of the
// wire format and must not be changed.
const (
	// RemoteNegotiate carries a ';'-separated host list, in both directions.
	RemoteNegotiate uint16 = 3
	// RemoteGrainChanged carries "<cart id>;<host>".
	RemoteGrainChanged uint16 = 4
	// AckChange acknowledges a RemoteGrainChanged message.
	AckChange uint16 = 5
	//AckError = uint16(6)

	// Ping requests a Pong.
	Ping uint16 = 7
	// Pong answers a Ping.
	Pong uint16 = 8
	// GetCartIds requests the ids of the grains held by the receiver.
	GetCartIds uint16 = 9
	// CartIdsResponse answers GetCartIds with a ';'-separated id list.
	CartIdsResponse uint16 = 10
)
|
|
|
|
// PacketWithData is a received packet together with its arrival time.
type PacketWithData struct {
	// MessageType is the message type read from the packet.
	MessageType uint16
	// Added is the time at which the packet was received.
	Added time.Time
	// Data is the packet payload.
	Data []byte
}
|
|
|
|
type PacketQueue struct {
|
|
mu sync.RWMutex
|
|
Packets []PacketWithData
|
|
connection net.Conn
|
|
}
|
|
|
|
func NewPacketQueue(connection net.Conn) *PacketQueue {
|
|
queue := &PacketQueue{
|
|
Packets: make([]PacketWithData, 0),
|
|
connection: connection,
|
|
}
|
|
go func() {
|
|
for {
|
|
messageType, data, err := ReceivePacket(queue.connection)
|
|
ts := time.Now()
|
|
if err != nil {
|
|
log.Printf("Error receiving packet: %v\n", err)
|
|
if err == io.EOF {
|
|
return
|
|
}
|
|
|
|
//return
|
|
}
|
|
packetQueue.Inc()
|
|
queue.mu.Lock()
|
|
for i, packet := range queue.Packets {
|
|
if time.Since(packet.Added) < time.Second*5 {
|
|
queue.Packets = queue.Packets[i:]
|
|
packetQueue.Set(float64(len(queue.Packets)))
|
|
break
|
|
}
|
|
}
|
|
queue.Packets = append(queue.Packets, PacketWithData{
|
|
MessageType: messageType,
|
|
Added: ts,
|
|
Data: data,
|
|
})
|
|
queue.mu.Unlock()
|
|
}
|
|
}()
|
|
return queue
|
|
}
|
|
|
|
func (p *PacketQueue) Expect(messageType uint16, timeToWait time.Duration) (*PacketWithData, error) {
|
|
start := time.Now().Add(-time.Millisecond)
|
|
|
|
for {
|
|
if time.Since(start) > timeToWait {
|
|
return nil, fmt.Errorf("timeout waiting for message type %d", messageType)
|
|
}
|
|
p.mu.RLock()
|
|
for _, packet := range p.Packets {
|
|
if packet.MessageType == messageType && packet.Added.After(start) {
|
|
p.mu.RUnlock()
|
|
return &packet, nil
|
|
}
|
|
}
|
|
p.mu.RUnlock()
|
|
time.Sleep(time.Millisecond * 5)
|
|
}
|
|
}
|
|
|
|
func (p *SyncedPool) handleConnection(conn net.Conn) {
|
|
defer conn.Close()
|
|
var packet Packet
|
|
for {
|
|
err := binary.Read(conn, binary.LittleEndian, &packet)
|
|
if err != nil {
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
log.Printf("Error in connection: %v\n", err)
|
|
}
|
|
// if packet.Version != 1 {
|
|
// log.Printf("Invalid version %d\n", packet.Version)
|
|
// return
|
|
// }
|
|
switch packet.MessageType {
|
|
case Ping:
|
|
err = SendPacket(conn, Pong, func(w io.Writer) error {
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
log.Printf("Error sending pong: %v\n", err)
|
|
}
|
|
case RemoteNegotiate:
|
|
negotiationCount.Inc()
|
|
data := make([]byte, packet.DataLength)
|
|
conn.Read(data)
|
|
knownHosts := strings.Split(string(data), ";")
|
|
log.Printf("Negotiated with remote, found %v hosts\n", knownHosts)
|
|
|
|
SendPacket(conn, RemoteNegotiate, func(w io.Writer) error {
|
|
hostnames := make([]string, 0, len(p.remotes))
|
|
for _, r := range p.remotes {
|
|
hostnames = append(hostnames, r.Host)
|
|
}
|
|
w.Write([]byte(strings.Join(hostnames, ";")))
|
|
return nil
|
|
})
|
|
for _, h := range knownHosts {
|
|
err = p.AddRemote(h)
|
|
if err != nil {
|
|
log.Printf("Error adding remote %s: %v\n", h, err)
|
|
}
|
|
}
|
|
case RemoteGrainChanged:
|
|
// remote grain changed
|
|
grainSyncCount.Inc()
|
|
log.Printf("Remote grain changed\n")
|
|
|
|
idAndHost := make([]byte, packet.DataLength)
|
|
_, err = conn.Read(idAndHost)
|
|
log.Printf("Remote grain %s changed\n", idAndHost)
|
|
if err != nil {
|
|
break
|
|
}
|
|
idAndHostParts := strings.Split(string(idAndHost), ";")
|
|
if len(idAndHostParts) != 2 {
|
|
log.Printf("Invalid remote grain change message\n")
|
|
break
|
|
}
|
|
found := false
|
|
for _, r := range p.remotes {
|
|
if r.Host == string(idAndHostParts[1]) {
|
|
found = true
|
|
log.Printf("Remote grain %s changed to %s\n", idAndHostParts[0], idAndHostParts[1])
|
|
p.mu.Lock()
|
|
p.remoteIndex[ToCartId(idAndHostParts[0])] = r.Pool
|
|
p.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
if !found {
|
|
log.Printf("Remote host %s not found\n", idAndHostParts[1])
|
|
log.Printf("Remotes %v\n", p.remotes)
|
|
} else {
|
|
SendPacket(conn, AckChange, func(w io.Writer) error {
|
|
_, err := w.Write([]byte("ok"))
|
|
return err
|
|
})
|
|
}
|
|
|
|
case GetCartIds:
|
|
ids := make([]string, 0, len(p.local.grains))
|
|
for id := range p.local.grains {
|
|
ids = append(ids, id.String())
|
|
}
|
|
SendPacket(conn, CartIdsResponse, func(w io.Writer) error {
|
|
_, err := w.Write([]byte(strings.Join(ids, ";")))
|
|
return err
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *RemoteHost) Negotiate(knownHosts []string) ([]string, error) {
|
|
err := SendPacket(h.connection, RemoteNegotiate, func(w io.Writer) error {
|
|
w.Write([]byte(strings.Join(knownHosts, ";")))
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
packet, err := h.queue.Expect(RemoteNegotiate, time.Second)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return strings.Split(string(packet.Data), ";"), nil
|
|
}
|
|
|
|
func (g *RemoteHost) GetCartMappings() []CartId {
|
|
err := SendPacket(g.connection, GetCartIds, func(w io.Writer) error {
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
log.Printf("Error getting mappings: %v\n", err)
|
|
return nil
|
|
}
|
|
packet, err := g.queue.Expect(CartIdsResponse, time.Second*3)
|
|
if err != nil {
|
|
log.Printf("Error getting mappings: %v\n", err)
|
|
return nil
|
|
}
|
|
parts := strings.Split(string(packet.Data), ";")
|
|
ids := make([]CartId, 0, len(parts))
|
|
for _, p := range parts {
|
|
ids = append(ids, ToCartId(p))
|
|
}
|
|
return ids
|
|
}
|
|
|
|
func (p *SyncedPool) Negotiate(knownHosts []string) ([]string, error) {
|
|
allHosts := make(map[string]struct{}, 0)
|
|
for _, r := range p.remotes {
|
|
hosts, err := r.Negotiate(knownHosts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for _, h := range hosts {
|
|
allHosts[h] = struct{}{}
|
|
}
|
|
}
|
|
ret := make([]string, 0, len(allHosts))
|
|
for h := range allHosts {
|
|
ret = append(ret, h)
|
|
}
|
|
return ret, nil
|
|
}
|
|
|
|
func (r *RemoteHost) ConfirmChange(id CartId, host string) error {
|
|
err := SendPacket(r.connection, RemoteGrainChanged, func(w io.Writer) error {
|
|
_, err := w.Write([]byte(fmt.Sprintf("%s;%s", id, host)))
|
|
return err
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = r.queue.Expect(AckChange, time.Second)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *SyncedPool) OwnerChanged(id CartId, host string) error {
|
|
for _, r := range p.remotes {
|
|
err := r.ConfirmChange(id, host)
|
|
|
|
if err != nil {
|
|
log.Printf("Error confirming change: %v from %s\n", err, host)
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func DoPing(host *RemoteHost) error {
|
|
SendPacket(host.connection, Ping, func(w io.Writer) error {
|
|
return nil
|
|
})
|
|
_, err := host.queue.Expect(Pong, time.Second)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *SyncedPool) addRemoteHost(address string, remote *RemoteHost) error {
|
|
for _, r := range p.remotes {
|
|
if r.Host == address {
|
|
log.Printf("Remote %s already exists\n", address)
|
|
return fmt.Errorf("remote %s already exists", address)
|
|
}
|
|
}
|
|
|
|
err := DoPing(remote)
|
|
if err != nil {
|
|
log.Printf("Error pinging remote %s: %v\n", address, err)
|
|
}
|
|
|
|
p.remotes = append(p.remotes, remote)
|
|
connectedRemotes.Set(float64(len(p.remotes)))
|
|
log.Printf("Added remote %s\n", remote.Host)
|
|
go func() {
|
|
ids := remote.GetCartMappings()
|
|
p.mu.Lock()
|
|
for _, id := range ids {
|
|
p.remoteIndex[id] = remote.Pool
|
|
}
|
|
p.mu.Unlock()
|
|
}()
|
|
return nil
|
|
}
|
|
|
|
func (p *SyncedPool) AddRemote(address string) error {
|
|
|
|
connection, err := net.Dial("tcp", fmt.Sprintf("%s:1338", address))
|
|
if err != nil {
|
|
log.Printf("Error connecting to remote %s: %v\n", address, err)
|
|
return err
|
|
}
|
|
|
|
pool := NewRemoteGrainPool(fmt.Sprintf(address, 1337))
|
|
remote := RemoteHost{
|
|
connection: connection,
|
|
queue: NewPacketQueue(connection),
|
|
Pool: pool,
|
|
Host: address,
|
|
}
|
|
|
|
return p.addRemoteHost(address, &remote)
|
|
}
|
|
|
|
func (p *SyncedPool) Process(id CartId, messages ...Message) ([]byte, error) {
|
|
// check if local grain exists
|
|
_, ok := p.local.grains[id]
|
|
if !ok {
|
|
// check if remote grain exists
|
|
p.mu.RLock()
|
|
remoteGrain, ok := p.remoteIndex[id]
|
|
p.mu.RUnlock()
|
|
if ok {
|
|
remoteLookupCount.Inc()
|
|
return remoteGrain.Process(id, messages...)
|
|
}
|
|
err := p.OwnerChanged(id, p.Hostname)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return p.local.Process(id, messages...)
|
|
}
|
|
|
|
func (p *SyncedPool) Get(id CartId) ([]byte, error) {
|
|
// check if local grain exists
|
|
_, ok := p.local.grains[id]
|
|
if !ok {
|
|
// check if remote grain exists
|
|
p.mu.RLock()
|
|
remoteGrain, ok := p.remoteIndex[id]
|
|
p.mu.RUnlock()
|
|
if ok {
|
|
remoteLookupCount.Inc()
|
|
return remoteGrain.Get(id)
|
|
}
|
|
err := p.OwnerChanged(id, p.Hostname)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return p.local.Get(id)
|
|
}
|