tcp mux and stuff
All checks were successful
Build and Publish / BuildAndDeploy (push) Successful in 1m49s

This commit is contained in:
matst80
2024-11-10 16:40:52 +01:00
parent 547c32d4a7
commit 10d85350d0
6 changed files with 470 additions and 295 deletions

View File

@@ -1,11 +1,8 @@
package main
import (
"encoding/binary"
"fmt"
"io"
"log"
"net"
"strings"
"sync"
"time"
@@ -20,149 +17,22 @@ type Quorum interface {
}
// RemoteHost describes one peer node of the pool together with the transport
// used to talk to it.
//
// NOTE(review): this listing embeds net.Conn, *PacketQueue and *Client at the
// same time, which looks like the transports of two revisions shown together.
// Confirm which embedded field(s) are meant to remain.
type RemoteHost struct {
net.Conn
*PacketQueue
*Client
// Host is the peer's address; it is compared against incoming host names.
Host string
// MissedPings counts failed pings since the last successful one.
MissedPings int
// Pool is the remote grain pool recorded in SyncedPool.remoteIndex for carts owned by this host.
Pool *RemoteGrainPool
}
// SyncedPool combines the local grain pool with the set of known remote hosts
// and tracks which remote pool owns which cart.
type SyncedPool struct {
	// mu is taken around updates of remoteIndex and of the local grain map
	// when cart ownership changes. It is declared exactly once: a second
	// field with the same name does not compile.
	mu        sync.RWMutex
	Discovery Discovery
	listener  net.Listener
	*Server
	// Hostname is this node's own name; it must never be added as a remote.
	Hostname string
	local    *GrainLocalPool
	// remotes lists the currently connected peers.
	remotes []*RemoteHost
	// remoteIndex maps a cart id to the remote pool that owns it.
	remoteIndex map[CartId]*RemoteGrainPool
}
// NewSyncedPool opens a TCP listener on hostname:1338 and returns a pool that
// keeps cart ownership in sync with remote hosts.
//
// It starts three background goroutines: one pings every known remote every
// two seconds and removes hosts that missed more than three pings, one adds
// hosts reported by d (skipped when d is nil), and one accepts inbound
// connections and serves them with handleConnection.
//
// It returns an error when the listener cannot be created.
func NewSyncedPool(local *GrainLocalPool, hostname string, d Discovery) (*SyncedPool, error) {
	listen := fmt.Sprintf("%s:1338", hostname)
	l, err := net.Listen("tcp", listen)
	if err != nil {
		return nil, err
	}
	// Announce the listener only once it actually exists.
	log.Printf("Listening on %s", listen)

	pool := &SyncedPool{
		Discovery:   d,
		Hostname:    hostname,
		local:       local,
		listener:    l,
		remotes:     make([]*RemoteHost, 0),
		remoteIndex: make(map[CartId]*RemoteGrainPool),
	}

	go func() {
		// The tick channel is never closed, so this single loop runs for the
		// lifetime of the process; no outer loop is needed.
		for range time.Tick(time.Second * 2) {
			for _, r := range pool.remotes {
				err := DoPing(r)
				if err == nil {
					r.MissedPings = 0
					continue
				}
				r.MissedPings++
				log.Printf("Error pinging remote %s: %v, missed pings: %d\n", r.Host, err, r.MissedPings)
				if r.MissedPings > 3 {
					log.Printf("Removing remote %s\n", r.Host)
					// Removal rewrites pool.remotes, so it is done outside this range loop.
					go pool.RemoveHost(r)
				}
			}
			connectedRemotes.Set(float64(len(pool.remotes)))
		}
	}()

	if d != nil {
		go func() {
			ch, err := d.Watch()
			if err != nil {
				log.Printf("Error discovering hosts: %v", err)
				return
			}
			for host := range ch {
				if pool.IsKnown(host) {
					continue
				}
				go func(h string) {
					log.Printf("Discovered host %s, waiting for startup", h)
					// Give the freshly discovered host a moment to start listening.
					time.Sleep(time.Second)
					if err := pool.AddRemote(h); err != nil {
						log.Printf("Error adding remote %s: %v", h, err)
					}
				}(host)
			}
		}()
	} else {
		log.Printf("No discovery, waiting for remotes to connect")
	}

	go func() {
		for {
			conn, err := l.Accept()
			if err != nil {
				log.Printf("Error accepting connection: %v\n", err)
				continue
			}
			log.Printf("Got connection from %s", conn.RemoteAddr())
			go pool.handleConnection(conn)
		}
	}()

	return pool, nil
}
// IsKnown reports whether host is already part of the pool: either one of the
// connected remotes or this node itself.
func (p *SyncedPool) IsKnown(host string) bool {
	for _, r := range p.remotes {
		if r.Host == host {
			return true
		}
	}
	// The local host counts as known so the pool never dials itself; every
	// other host that is not in remotes is unknown.
	return host == p.Hostname
}
// ExcludeKnown returns the entries of hosts that are neither this node's own
// hostname nor the host of an already connected remote, in input order.
func (p *SyncedPool) ExcludeKnown(hosts []string) []string {
	unknown := make([]string, 0, len(hosts))
candidates:
	for _, candidate := range hosts {
		if candidate == p.Hostname {
			continue
		}
		for _, remote := range p.remotes {
			if remote.Host == candidate {
				continue candidates
			}
		}
		unknown = append(unknown, candidate)
	}
	return unknown
}
// RemoveHost drops host from the list of remotes, after first clearing every
// cart mapping that points at its pool, and updates the connected-remotes
// gauge. Hosts that are not in the list are ignored.
func (p *SyncedPool) RemoveHost(host *RemoteHost) {
	for idx := range p.remotes {
		if p.remotes[idx] != host {
			continue
		}
		p.RemoveHostMappedCarts(host)
		p.remotes = append(p.remotes[:idx], p.remotes[idx+1:]...)
		connectedRemotes.Set(float64(len(p.remotes)))
		return
	}
}
// RemoveHostMappedCarts deletes every remoteIndex entry whose owner is the
// pool of host. The whole sweep runs under p.mu.
func (p *SyncedPool) RemoveHostMappedCarts(host *RemoteHost) {
	p.mu.Lock()
	defer p.mu.Unlock()
	for cart, owner := range p.remoteIndex {
		if owner != host.Pool {
			continue
		}
		delete(p.remoteIndex, cart)
	}
}
var (
negotiationCount = promauto.NewCounter(prometheus.CounterOpts{
Name: "cart_remote_negotiation_total",
@@ -194,148 +64,207 @@ var (
})
)
// PongHandler answers a ping: it replies with message type Pong and echoes
// the request payload back unchanged. It never returns an error.
func (p *SyncedPool) PongHandler(data []byte) (uint16, []byte, error) {
	echoed := data
	return Pong, echoed, nil
}
// GetCartIdHandler replies with message type CartIdsResponse and the ids of
// all locally held grains, joined with ";". The request payload is ignored
// and the order of the ids follows map iteration, i.e. it is unspecified.
func (p *SyncedPool) GetCartIdHandler(data []byte) (uint16, []byte, error) {
	var joined strings.Builder
	first := true
	for id := range p.local.grains {
		if !first {
			joined.WriteByte(';')
		}
		first = false
		joined.WriteString(id.String())
	}
	return CartIdsResponse, []byte(joined.String()), nil
}
// NegotiateHandler treats the payload as a ";"-separated list of host names,
// tries to add every host that is not yet known as a remote, and always
// replies with RemoteNegotiateResponse and the body "ok". Failures to add a
// host are logged and do not affect the reply.
func (p *SyncedPool) NegotiateHandler(data []byte) (uint16, []byte, error) {
	negotiationCount.Inc()
	log.Printf("Handling negotiation\n")
	announced := strings.Split(string(data), ";")
	unknown := p.ExcludeKnown(announced)
	for i := range unknown {
		if err := p.AddRemote(unknown[i]); err != nil {
			log.Printf("Error adding remote %s: %v\n", unknown[i], err)
		}
	}
	return RemoteNegotiateResponse, []byte("ok"), nil
}
// GrainOwnerChangeHandler processes an "<cart id>;<host>" message announcing
// that host now owns the cart. When host is a connected remote, any local
// grain with that id is dropped and the cart is mapped to the remote's pool.
//
// The reply always has message type AckChange; its body is "ok" on success,
// "incorrect" for a malformed payload and "not found" for an unknown host.
func (p *SyncedPool) GrainOwnerChangeHandler(data []byte) (uint16, []byte, error) {
	grainSyncCount.Inc()
	fields := strings.Split(string(data), ";")
	if len(fields) != 2 {
		log.Printf("Invalid remote grain change message\n")
		return AckChange, []byte("incorrect"), nil
	}
	cart, owner := fields[0], fields[1]
	for _, remote := range p.remotes {
		if remote.Host != owner {
			continue
		}
		log.Printf("Remote grain %s changed to %s\n", cart, owner)
		p.mu.Lock()
		id := ToCartId(cart)
		if p.local.grains[id] != nil {
			log.Printf("Grain %s already exists locally, deleting\n", cart)
			delete(p.local.grains, id)
		}
		p.remoteIndex[id] = remote.Pool
		p.mu.Unlock()
		return AckChange, []byte("ok"), nil
	}
	return AckChange, []byte("not found"), nil
}
// NewSyncedPool starts a server on hostname:1338, registers the pool's
// handlers for Ping, GetCartIds, RemoteNegotiate and RemoteGrainChanged, and
// returns the pool. When discovery is non-nil a goroutine watches it and adds
// every host that is not yet known, one second after it was reported.
//
// It returns an error when the server cannot be started.
func NewSyncedPool(local *GrainLocalPool, hostname string, discovery Discovery) (*SyncedPool, error) {
	address := fmt.Sprintf("%s:1338", hostname)
	server, err := Listen(address)
	if err != nil {
		return nil, err
	}
	log.Printf("Listening on %s", address)

	pool := &SyncedPool{
		Server: server,
		//Discovery: discovery,
		Hostname:    hostname,
		local:       local,
		remotes:     make([]*RemoteHost, 0),
		remoteIndex: make(map[CartId]*RemoteGrainPool),
	}

	server.HandleCall(Ping, pool.PongHandler)
	server.HandleCall(GetCartIds, pool.GetCartIdHandler)
	server.HandleCall(RemoteNegotiate, pool.NegotiateHandler)
	server.HandleCall(RemoteGrainChanged, pool.GrainOwnerChangeHandler)

	// // TODO FIX THIS, ONLY CLIENT OR SERVER SHOULD PING
	// go func() {
	// 	for {
	// 		for range time.Tick(time.Second * 2) {
	// 			for _, r := range pool.remotes {
	// 				err := DoPing(r)
	// 				if err != nil {
	// 					r.MissedPings++
	// 					log.Printf("Error pinging remote %s: %v\n, missed pings: %d", r.Host, err, r.MissedPings)
	// 					if r.MissedPings > 3 {
	// 						log.Printf("Removing remote %s\n", r.Host)
	// 						go pool.RemoveHost(r)
	// 						//pool.remotes = append(pool.remotes[:i], pool.remotes[i+1:]...)
	// 					}
	// 				} else {
	// 					r.MissedPings = 0
	// 				}
	// 			}
	// 			connectedRemotes.Set(float64(len(pool.remotes)))
	// 		}
	// 	}
	// }()

	if discovery == nil {
		log.Printf("No discovery, waiting for remotes to connect")
		return pool, nil
	}

	go func() {
		hosts, watchErr := discovery.Watch()
		if watchErr != nil {
			log.Printf("Error discovering hosts: %v", watchErr)
			return
		}
		for found := range hosts {
			if pool.IsKnown(found) {
				continue
			}
			go func(h string) {
				log.Printf("Discovered host %s, waiting for startup", h)
				// Give the discovered host a moment to start before dialing it.
				time.Sleep(time.Second)
				if addErr := pool.AddRemote(h); addErr != nil {
					log.Printf("Error adding remote %s: %v", h, addErr)
				}
			}(found)
		}
	}()
	return pool, nil
}
// IsKnown reports whether host is already part of the pool: either one of the
// connected remotes or this node itself.
func (p *SyncedPool) IsKnown(host string) bool {
	for _, r := range p.remotes {
		if r.Host == host {
			return true
		}
	}
	// The local host counts as known so the pool never dials itself; every
	// other host that is not in remotes is unknown.
	return host == p.Hostname
}
// ExcludeKnown returns, in input order, the entries of hosts for which
// IsKnown reports false. The result is never nil.
func (p *SyncedPool) ExcludeKnown(hosts []string) []string {
	unknown := make([]string, 0, len(hosts))
	for i := range hosts {
		if p.IsKnown(hosts[i]) {
			continue
		}
		unknown = append(unknown, hosts[i])
	}
	return unknown
}
// RemoveHost drops host from the list of remotes, after first clearing every
// cart mapping that points at its pool, and updates the connected-remotes
// gauge. Hosts that are not in the list are ignored.
func (p *SyncedPool) RemoveHost(host *RemoteHost) {
	for idx := range p.remotes {
		if p.remotes[idx] != host {
			continue
		}
		p.RemoveHostMappedCarts(host)
		p.remotes = append(p.remotes[:idx], p.remotes[idx+1:]...)
		connectedRemotes.Set(float64(len(p.remotes)))
		return
	}
}
// RemoveHostMappedCarts deletes every remoteIndex entry whose owner is the
// pool of host. The whole sweep runs under p.mu.
func (p *SyncedPool) RemoveHostMappedCarts(host *RemoteHost) {
	p.mu.Lock()
	defer p.mu.Unlock()
	for cart, owner := range p.remoteIndex {
		if owner != host.Pool {
			continue
		}
		delete(p.remoteIndex, cart)
	}
}
// Message type codes exchanged between pool nodes. Each constant is declared
// exactly once; redeclaring a name inside the block does not compile.
const (
	// RemoteNegotiate carries a ";"-separated list of known host names.
	RemoteNegotiate = uint16(3)
	// RemoteGrainChanged carries "<cart id>;<host>" for an ownership change.
	RemoteGrainChanged = uint16(4)
	// AckChange is the reply type for RemoteGrainChanged.
	AckChange = uint16(5)
	//AckError = uint16(6)

	// Ping and Pong form the liveness check.
	Ping = uint16(7)
	Pong = uint16(8)
	// GetCartIds requests the ids of all locally held carts;
	// CartIdsResponse carries them joined with ";".
	GetCartIds      = uint16(9)
	CartIdsResponse = uint16(10)
	// RemoteNegotiateResponse is the reply type for RemoteNegotiate.
	RemoteNegotiateResponse = uint16(11)
)
// handleConnection serves one inbound peer connection. It reads packet
// headers in a loop and dispatches on the message type:
//
//   - Ping: replies with an empty Pong.
//   - RemoteNegotiate: reads the ";"-separated host list, replies with the
//     hosts of the currently connected remotes and then tries to add every
//     announced host.
//   - RemoteGrainChanged: reads "<cart id>;<host>", remaps the cart to the
//     named remote's pool and acknowledges with AckChange "ok".
//   - GetCartIds: replies with the ids of all local grains joined with ";".
//
// The connection is closed when the function returns, which happens on EOF or
// on any read error, because the stream cannot be re-synchronised after one.
func (p *SyncedPool) handleConnection(conn net.Conn) {
	defer conn.Close()
	var packet Packet
	for {
		err := binary.Read(conn, binary.LittleEndian, &packet)
		if err != nil {
			// A failed header read leaves packet holding stale values, so
			// stop here instead of dispatching on it again.
			if err != io.EOF {
				log.Printf("Error in connection: %v\n", err)
			}
			return
		}
		// if packet.Version != 1 {
		// 	log.Printf("Invalid version %d\n", packet.Version)
		// 	return
		// }
		switch packet.MessageType {
		case Ping:
			err = SendPacket(conn, Pong, func(w io.Writer) error {
				return nil
			})
			if err != nil {
				log.Printf("Error sending pong: %v\n", err)
			}
		case RemoteNegotiate:
			negotiationCount.Inc()
			data := make([]byte, packet.DataLength)
			// A single Read may return fewer bytes than requested; ReadFull
			// waits for the whole payload.
			if _, err = io.ReadFull(conn, data); err != nil {
				log.Printf("Error reading negotiation payload: %v\n", err)
				return
			}
			knownHosts := strings.Split(string(data), ";")
			log.Printf("Negotiated with remote, found %v hosts\n", knownHosts)
			err = SendPacket(conn, RemoteNegotiate, func(w io.Writer) error {
				hostnames := make([]string, 0, len(p.remotes))
				for _, r := range p.remotes {
					hostnames = append(hostnames, r.Host)
				}
				_, err := w.Write([]byte(strings.Join(hostnames, ";")))
				return err
			})
			if err != nil {
				log.Printf("Error sending negotiation response: %v\n", err)
			}
			for _, h := range knownHosts {
				err = p.AddRemote(h)
				if err != nil {
					log.Printf("Error adding remote %s: %v\n", h, err)
				}
			}
		case RemoteGrainChanged:
			// remote grain changed
			grainSyncCount.Inc()
			idAndHost := make([]byte, packet.DataLength)
			if _, err = io.ReadFull(conn, idAndHost); err != nil {
				log.Printf("Error reading grain change payload: %v\n", err)
				return
			}
			idAndHostParts := strings.Split(string(idAndHost), ";")
			if len(idAndHostParts) != 2 {
				log.Printf("Invalid remote grain change message\n")
				break
			}
			found := false
			for _, r := range p.remotes {
				if r.Host == idAndHostParts[1] {
					found = true
					log.Printf("Remote grain %s changed to %s\n", idAndHostParts[0], idAndHostParts[1])
					p.mu.Lock()
					if p.local.grains[ToCartId(idAndHostParts[0])] != nil {
						log.Printf("Grain %s already exists locally, deleting\n", idAndHostParts[0])
						delete(p.local.grains, ToCartId(idAndHostParts[0]))
					}
					p.remoteIndex[ToCartId(idAndHostParts[0])] = r.Pool
					p.mu.Unlock()
				}
			}
			if !found {
				log.Printf("Remote host %s not found\n", idAndHostParts[1])
				break
			}
			err = SendPacket(conn, AckChange, func(w io.Writer) error {
				_, err := w.Write([]byte("ok"))
				return err
			})
			if err != nil {
				log.Printf("Error sending change ack: %v\n", err)
			}
		case GetCartIds:
			ids := make([]string, 0, len(p.local.grains))
			for id := range p.local.grains {
				ids = append(ids, id.String())
			}
			err = SendPacket(conn, CartIdsResponse, func(w io.Writer) error {
				_, err := w.Write([]byte(strings.Join(ids, ";")))
				return err
			})
			if err != nil {
				log.Printf("Error sending cart ids: %v\n", err)
			}
		}
	}
}
// Negotiate sends the ";"-joined knownHosts to the remote as a single
// RemoteNegotiate call and returns the RemoteNegotiateResponse body split on
// ";". The request is issued exactly once; any call error is returned as is.
//
// NOTE(review): NegotiateHandler in this file answers with the literal body
// "ok", so against that handler the result is []string{"ok"} rather than a
// list of host names. Confirm what callers expect.
func (h *RemoteHost) Negotiate(knownHosts []string) ([]string, error) {
	data, err := h.Call(RemoteNegotiate, RemoteNegotiateResponse, []byte(strings.Join(knownHosts, ";")))
	if err != nil {
		return nil, err
	}
	return strings.Split(string(data), ";"), nil
}
// GetCartMappings asks the remote host for its cart ids (GetCartIds request,
// CartIdsResponse reply), splits the reply on ";" and converts every part
// with ToCartId.
//
// NOTE(review): this listing splices two revisions of the function together:
// one returning []CartId that uses SendPacket/Expect and logs errors, and one
// returning ([]CartId, error) that uses Call. As shown it does not parse;
// confirm which signature and transport are meant to remain.
// NOTE(review): strings.Split on an empty reply yields one empty string, so a
// remote without carts would presumably produce a single ToCartId("") entry.
// Verify against ToCartId and the callers.
func (g *RemoteHost) GetCartMappings() []CartId {
err := SendPacket(g.connection, GetCartIds, func(w io.Writer) error {
return nil
})
func (g *RemoteHost) GetCartMappings() ([]CartId, error) {
data, err := g.Call(GetCartIds, CartIdsResponse, nil)
if err != nil {
log.Printf("Error getting mappings: %v\n", err)
return nil
return nil, err
}
packet, err := g.Expect(CartIdsResponse, time.Second*3)
if err != nil {
log.Printf("Error getting mappings: %v\n", err)
return nil
}
parts := strings.Split(string(packet.Data), ";")
parts := strings.Split(string(data), ";")
ids := make([]CartId, 0, len(parts))
for _, p := range parts {
ids = append(ids, ToCartId(p))
}
return ids
return ids, nil
}
func (p *SyncedPool) Negotiate(knownHosts []string) ([]string, error) {
@@ -357,18 +286,14 @@ func (p *SyncedPool) Negotiate(knownHosts []string) ([]string, error) {
}
// ConfirmChange tells the remote that cart id is now owned by host. It sends
// one RemoteGrainChanged call with the body "<id>;<host>" and expects an
// AckChange reply whose body is "ok".
//
// It returns the call error, or an error quoting the reply body when the
// remote answered with anything other than "ok".
func (r *RemoteHost) ConfirmChange(id CartId, host string) error {
	data, err := r.Call(RemoteGrainChanged, AckChange, []byte(fmt.Sprintf("%s;%s", id, host)))
	if err != nil {
		return err
	}
	if string(data) != "ok" {
		return fmt.Errorf("remote grain change failed %s", string(data))
	}
	return nil
}
@@ -386,23 +311,6 @@ func (p *SyncedPool) RequestOwnership(id CartId) error {
return nil
}
// DoPing sends an empty Ping packet to host and waits up to one second for
// the matching Pong. It returns the send or wait error, or nil on success.
func DoPing(host *RemoteHost) error {
	noPayload := func(w io.Writer) error {
		return nil
	}
	if err := SendPacket(host, Ping, noPayload); err != nil {
		return err
	}
	_, err := host.Expect(Pong, time.Second)
	return err
}
func (p *SyncedPool) addRemoteHost(address string, remote *RemoteHost) error {
known := make([]string, 0, len(p.remotes))
for _, r := range p.remotes {
@@ -413,18 +321,17 @@ func (p *SyncedPool) addRemoteHost(address string, remote *RemoteHost) error {
}
}
err := DoPing(remote)
if err != nil {
log.Printf("Error pinging remote %s: %v\n", address, err)
}
p.remotes = append(p.remotes, remote)
connectedRemotes.Set(float64(len(p.remotes)))
log.Printf("Added remote %s\n", remote.Host)
go func() {
p.Negotiate(known)
ids := remote.GetCartMappings()
ids, err := remote.GetCartMappings()
if err != nil {
log.Printf("Error getting remote mappings: %v\n", err)
return
}
p.mu.Lock()
for _, id := range ids {
if p.local.grains[id] != nil {
@@ -442,7 +349,8 @@ func (p *SyncedPool) AddRemote(address string) error {
if address == "" || p.IsKnown(address) {
return nil
}
connection, err := net.Dial("tcp", fmt.Sprintf("%s:1338", address))
client, err := Dial(fmt.Sprintf("%s:1338", address))
if err != nil {
log.Printf("Error connecting to remote %s: %v\n", address, err)
return err
@@ -450,10 +358,9 @@ func (p *SyncedPool) AddRemote(address string) error {
pool := NewRemoteGrainPool(address)
remote := RemoteHost{
Conn: connection,
PacketQueue: NewPacketQueue(connection),
Pool: pool,
Host: address,
Client: client,
Pool: pool,
Host: address,
}
return p.addRemoteHost(address, &remote)