diff --git a/go.mod b/go.mod
index 9f69bec..28dee33 100644
--- a/go.mod
+++ b/go.mod
@@ -41,6 +41,7 @@ require (
 	github.com/prometheus/procfs v0.15.1 // indirect
 	github.com/spf13/pflag v1.0.5 // indirect
 	github.com/x448/float16 v0.8.4 // indirect
+	github.com/yudhasubki/netpool v0.0.0-20230717065341-3c1353ca328e
 	golang.org/x/net v0.26.0 // indirect
 	golang.org/x/oauth2 v0.21.0 // indirect
 	golang.org/x/sys v0.22.0 // indirect
diff --git a/go.sum b/go.sum
index 4cdb221..709846b 100644
--- a/go.sum
+++ b/go.sum
@@ -100,6 +100,8 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
 github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
 github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
 github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
+github.com/yudhasubki/netpool v0.0.0-20230717065341-3c1353ca328e h1:fAzVSmKQkWflN25ED65CH/C1T3iVWq2BQfN7eQsg4E4=
+github.com/yudhasubki/netpool v0.0.0-20230717065341-3c1353ca328e/go.mod h1:gQsFrHrY6nviQu+VX7zKWDyhtLPNzngtYZ+C+7cywdk=
 github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
 github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
diff --git a/main.go b/main.go
index f8fb35d..ef33359 100644
--- a/main.go
+++ b/main.go
@@ -233,6 +233,8 @@ func main() {
 	go func() {
 		sig := <-sigs
 		fmt.Println("Shutting down due to signal:", sig)
+		// Run Close on this goroutine: started with `go`, it might never be scheduled before the process exits.
+		syncedPool.Close()
 		app.Save()
 		done <- true
 	}()
diff --git a/remote-grain-pool._go b/remote-grain-pool._go
new file mode 100644
index 0000000..bc13b41
--- /dev/null
+++ b/remote-grain-pool._go
@@ -0,0 +1,67 @@
+// package main
+
+// import "sync"
+
+// type RemoteGrainPool struct {
+// 	mu     sync.RWMutex
+// 	Host   string
+// 	grains map[CartId]*RemoteGrain
+// }
+
+// func NewRemoteGrainPool(addr string) *RemoteGrainPool {
+// 	return &RemoteGrainPool{
+// 		Host:   addr,
+// 		grains: make(map[CartId]*RemoteGrain),
+// 	}
+// }
+
+// func (p *RemoteGrainPool) findRemoteGrain(id CartId) *RemoteGrain {
+// 	p.mu.RLock()
+// 	grain, ok := p.grains[id]
+// 	p.mu.RUnlock()
+// 	if !ok {
+// 		return nil
+// 	}
+// 	return grain
+// }
+
+// func (p *RemoteGrainPool) findOrCreateGrain(id CartId) (*RemoteGrain, error) {
+// 	grain := p.findRemoteGrain(id)
+
+// 	if grain == nil {
+// 		grain, err := NewRemoteGrain(id, p.Host)
+// 		if err != nil {
+// 			return nil, err
+// 		}
+// 		p.mu.Lock()
+// 		p.grains[id] = grain
+// 		p.mu.Unlock()
+// 	}
+// 	return grain, nil
+// }
+
+// func (p *RemoteGrainPool) Delete(id CartId) {
+// 	p.mu.Lock()
+// 	delete(p.grains, id)
+// 	p.mu.Unlock()
+// }
+
+// func (p *RemoteGrainPool) Process(id CartId, messages ...Message) (*FrameWithPayload, error) {
+// 	var result *FrameWithPayload
+// 	grain, err := p.findOrCreateGrain(id)
+// 	if err != nil {
+// 		return nil, err
+// 	}
+// 	for _, message := range messages {
+// 		result, err = grain.HandleMessage(&message, false)
+// 	}
+// 	return result, err
+// }
+
+// func (p *RemoteGrainPool) Get(id CartId) (*FrameWithPayload, error) {
+// 	grain, err := p.findOrCreateGrain(id)
+// 	if err != nil {
+// 		return nil, err
+// 	}
+// 	return grain.GetCurrentState()
+// }
diff --git a/remote-grain-pool.go b/remote-grain-pool.go
deleted file mode 100644
index 3be28ff..0000000
--- a/remote-grain-pool.go
+++ /dev/null
@@ -1,67 +0,0 @@
-package main
-
-import "sync"
-
-type RemoteGrainPool struct {
-	mu     sync.RWMutex
-	Host   string
-	grains map[CartId]*RemoteGrain
-}
-
-func NewRemoteGrainPool(addr string) *RemoteGrainPool {
-	return &RemoteGrainPool{
-		Host:   addr,
-		grains: make(map[CartId]*RemoteGrain),
-	}
-}
-
-func (p *RemoteGrainPool) findRemoteGrain(id CartId) *RemoteGrain {
-	p.mu.RLock()
-	grain, ok := p.grains[id]
-	p.mu.RUnlock()
-	if !ok {
-		return nil
-	}
-	return grain
-}
-
-func (p *RemoteGrainPool) findOrCreateGrain(id CartId) (*RemoteGrain, error) {
-	grain := p.findRemoteGrain(id)
-
-	if grain == nil {
-		grain, err := NewRemoteGrain(id, p.Host)
-		if err != nil {
-			return nil, err
-		}
-		p.mu.Lock()
-		p.grains[id] = grain
-		p.mu.Unlock()
-	}
-	return grain, nil
-}
-
-func (p *RemoteGrainPool) Delete(id CartId) {
-	p.mu.Lock()
-	delete(p.grains, id)
-	p.mu.Unlock()
-}
-
-func (p *RemoteGrainPool) Process(id CartId, messages ...Message) (*FrameWithPayload, error) {
-	var result *FrameWithPayload
-	grain, err := p.findOrCreateGrain(id)
-	if err != nil {
-		return nil, err
-	}
-	for _, message := range messages {
-		result, err = grain.HandleMessage(&message, false)
-	}
-	return result, err
-}
-
-func (p *RemoteGrainPool) Get(id CartId) (*FrameWithPayload, error) {
-	grain, err := p.findOrCreateGrain(id)
-	if err != nil {
-		return nil, err
-	}
-	return grain.GetCurrentState()
-}
diff --git a/remote-grain.go b/remote-grain.go
index 5dad844..e6cfffb 100644
--- a/remote-grain.go
+++ b/remote-grain.go
@@ -3,6 +3,8 @@ package main
 import (
 	"fmt"
 	"strings"
+
+	"github.com/yudhasubki/netpool"
 )
 
 func (id CartId) String() string {
@@ -43,12 +45,15 @@ type RemoteGrain struct {
 	Host string
 }
 
-func NewRemoteGrain(id CartId, host string) (*RemoteGrain, error) {
+// NewRemoteGrain returns a handle for cart id on host. Its Connection targets
+// port 1337 of that host and is handed pool for its outgoing sockets.
+func NewRemoteGrain(id CartId, host string, pool netpool.Netpooler) *RemoteGrain {
+	addr := fmt.Sprintf("%s:1337", host)
 	return &RemoteGrain{
 		Id:         id,
 		Host:       host,
-		Connection: NewConnection(fmt.Sprintf("%s:1337", host)),
-	}, nil
+		Connection: NewConnection(addr, pool),
+	}
 }
 
 func (g *RemoteGrain) HandleMessage(message *Message, isReplay bool) (*FrameWithPayload, error) {
diff --git a/remote-host.go b/remote-host.go
index 385dfba..5ca06c8 100644
--- a/remote-host.go
+++ b/remote-host.go
@@ -4,10 +4,13 @@ import (
 	"fmt"
 	"log"
 	"strings"
+
+	"github.com/yudhasubki/netpool"
 )
 
 type RemoteHost struct {
 	*Connection
+	HostPool    netpool.Netpooler
 	Host        string
 	MissedPings int
 }
diff --git a/rpc-server.go b/rpc-server.go
index ef9d419..7492a85 100644
--- a/rpc-server.go
+++ b/rpc-server.go
@@ -20,7 +20,7 @@ func (h *GrainHandler) GetState(id CartId, reply *Grain) error {
 }
 
 func NewGrainHandler(pool *GrainLocalPool, listen string) (*GrainHandler, error) {
-	conn := NewConnection(listen)
+	conn := NewConnection(listen, nil)
 	server, err := conn.Listen()
 	handler := &GrainHandler{
 		GenericListener: server,
diff --git a/synced-pool.go b/synced-pool.go
index 884d1ac..2d6bf76 100644
--- a/synced-pool.go
+++ b/synced-pool.go
@@ -3,12 +3,14 @@ package main
 import (
 	"fmt"
 	"log"
+	"net"
 	"strings"
 	"sync"
 	"time"
 
 	"github.com/prometheus/client_golang/prometheus"
 	"github.com/prometheus/client_golang/prometheus/promauto"
+	"github.com/yudhasubki/netpool"
 	"k8s.io/apimachinery/pkg/watch"
 )
 
@@ -154,14 +156,23 @@ func (p *SyncedPool) SpawnRemoteGrain(id CartId, host string) {
 	}
 
 	go func(i CartId, h string) {
-		remote, err := NewRemoteGrain(i, h)
-		if err != nil {
-			log.Printf("Error creating remote grain %v", err)
+		var pool netpool.Netpooler
+		p.mu.RLock()
+		for _, r := range p.remotes {
+			if r.Host == h {
+				pool = r.HostPool
+				break
+			}
+		}
+		p.mu.RUnlock()
+		if pool == nil {
+			log.Printf("Error spawning remote grain, no pool for %s", h)
 			return
 		}
+		remoteGrain := NewRemoteGrain(i, h, pool)
 
 		p.mu.Lock()
-		p.remoteIndex[i] = remote
+		p.remoteIndex[i] = remoteGrain
 		p.mu.Unlock()
 	}(id, host)
 }
@@ -181,7 +192,7 @@ func (p *SyncedPool) HandleHostError(host string) {
 
 func NewSyncedPool(local *GrainLocalPool, hostname string, discovery Discovery) (*SyncedPool, error) {
 	listen := fmt.Sprintf("%s:1338", hostname)
-	conn := NewConnection(listen)
+	conn := NewConnection(listen, nil)
 	server, err := conn.Listen()
 	if err != nil {
 		return nil, err
@@ -202,6 +213,7 @@ func NewSyncedPool(local *GrainLocalPool, hostname string, discovery Discovery)
 	server.AddHandler(GetCartIds, pool.GetCartIdHandler)
 	server.AddHandler(RemoteNegotiate, pool.NegotiateHandler)
 	server.AddHandler(RemoteGrainChanged, pool.GrainOwnerChangeHandler)
+	server.AddHandler(Closing, pool.HostTerminatingHandler)
 
 	if discovery != nil {
 		go func() {
@@ -241,6 +253,24 @@ func NewSyncedPool(local *GrainLocalPool, hostname string, discovery Discovery)
 	return pool, nil
 }
 
+// HostTerminatingHandler handles a Closing frame. The payload is read as the
+// hostname of the peer that is going away; RemoveHost is started for the
+// matching remote, if any, and an "ok" Pong frame is written to resultChan.
+func (p *SyncedPool) HostTerminatingHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
+	log.Printf("Remote host terminating")
+	host := string(data.Payload)
+	p.mu.RLock()
+	defer p.mu.RUnlock()
+	for _, r := range p.remotes {
+		if r.Host == host {
+			go p.RemoveHost(r)
+			break
+		}
+	}
+	resultChan <- MakeFrameWithPayload(Pong, 200, []byte("ok"))
+	return nil
+}
+
 func (p *SyncedPool) IsHealthy() bool {
 	for _, r := range p.remotes {
 		if !r.IsHealthy() {
@@ -299,6 +329,7 @@ const (
 	GetCartIds              = FrameType(9)
 	CartIdsResponse         = FrameType(10)
 	RemoteNegotiateResponse = FrameType(11)
+	Closing                 = FrameType(12)
 )
 
 func (p *SyncedPool) Negotiate() {
@@ -324,8 +355,8 @@ func (p *SyncedPool) Negotiate() {
 }
 
 func (p *SyncedPool) GetHealthyRemotes() []*RemoteHost {
-	// p.mu.RLock()
-	// defer p.mu.RUnlock()
+	p.mu.RLock()
+	defer p.mu.RUnlock()
 	remotes := make([]*RemoteHost, 0, len(p.remotes))
 	for _, r := range p.remotes {
 		if r.IsHealthy() {
@@ -340,10 +371,7 @@ func (p *SyncedPool) RequestOwnership(id CartId) error {
 	all := 0
 
 	for _, r := range p.GetHealthyRemotes() {
-		if !r.IsHealthy() {
-			continue
-		}
-		//log.Printf("Asking for confirmation change of %s to %s (me) with %s\n", id, p.Hostname, r.Host)
+
 		err := r.ConfirmChange(id, p.Hostname)
 		all++
 		if err != nil {
@@ -381,9 +409,17 @@ func (p *SyncedPool) AddRemote(host string) {
 		return
 	}
 
-	client := NewConnection(fmt.Sprintf("%s:1338", host))
+	hostPool, err := netpool.New(func() (net.Conn, error) {
+		return net.Dial("tcp", fmt.Sprintf("%s:1338", host))
+	}, netpool.WithMaxPool(1024), netpool.WithMinPool(5))
+
+	if err != nil {
+		log.Printf("Error creating pool: %v\n", err)
+		return
+	}
+
+	client := NewConnection(fmt.Sprintf("%s:1338", host), hostPool)
 
-	var err error
 	pings := 3
 	for pings >= 0 {
 		_, err = client.Call(Ping, nil)
@@ -397,7 +433,17 @@ func (p *SyncedPool) AddRemote(host string) {
 	}
 	log.Printf("Connected to remote %s", host)
 
+	cartPool, err := netpool.New(func() (net.Conn, error) {
+		return net.Dial("tcp", fmt.Sprintf("%s:1337", host))
+	}, netpool.WithMaxPool(1024), netpool.WithMinPool(5))
+	if err != nil {
+		// SpawnRemoteGrain needs this pool, so do not register the host without it.
+		log.Printf("Error creating grain pool: %v\n", err)
+		return
+	}
+
 	remote := RemoteHost{
+		HostPool:    cartPool,
 		Connection:  client,
 		MissedPings: 0,
 		Host:        host,
@@ -458,6 +504,37 @@ func (p *SyncedPool) getGrain(id CartId) (Grain, error) {
 	return localGrain, nil
 }
 
+// Close tells every known remote that this host is shutting down by sending a
+// Closing frame carrying p.Hostname. It waits for the calls to finish, but at
+// most twice MaxCallDuration, so a silent peer cannot stall shutdown.
+func (p *SyncedPool) Close() {
+	payload := []byte(p.Hostname)
+	var wg sync.WaitGroup
+	p.mu.RLock()
+	for _, r := range p.remotes {
+		r := r // per-iteration copy for the goroutine (needed before Go 1.22)
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			if _, err := r.Call(Closing, payload); err != nil {
+				log.Printf("Error notifying %s of shutdown: %v\n", r.Host, err)
+			}
+		}()
+	}
+	p.mu.RUnlock()
+
+	notified := make(chan struct{})
+	go func() {
+		wg.Wait()
+		close(notified)
+	}()
+	select {
+	case <-notified:
+	case <-time.After(2 * MaxCallDuration):
+		log.Printf("Timeout notifying remotes of shutdown\n")
+	}
+}
+
 func (p *SyncedPool) Process(id CartId, messages ...Message) (*FrameWithPayload, error) {
 	pool, err := p.getGrain(id)
 	var res *FrameWithPayload
diff --git a/tcp-connection.go b/tcp-connection.go
index 6487cef..0078256 100644
--- a/tcp-connection.go
+++ b/tcp-connection.go
@@ -1,15 +1,20 @@
 package main
 
 import (
+	"bufio"
 	"encoding/binary"
 	"fmt"
+	"io"
 	"log"
 	"net"
 	"time"
+
+	"github.com/yudhasubki/netpool"
 )
 
 type Connection struct {
 	address string
+	pool    netpool.Netpooler
 	count   uint64
 }
 
@@ -56,9 +61,12 @@ type FrameData interface {
 	FromBytes([]byte) error
 }
 
-func NewConnection(address string) *Connection {
+// NewConnection returns a Connection for address. pool supplies the sockets
+// for outgoing calls; callers that only Listen pass nil.
+func NewConnection(address string, pool netpool.Netpooler) *Connection {
 	return &Connection{
 		count:   0,
+		pool:    pool,
 		address: address,
 	}
 }
@@ -75,7 +83,11 @@ func SendFrame(conn net.Conn, data *FrameWithPayload) error {
 }
 
 func (c *Connection) CallAsync(msg FrameType, payload []byte, ch chan<- FrameWithPayload) (net.Conn, error) {
-	conn, err := net.Dial("tcp", c.address)
+	if c.pool == nil {
+		// Connections built with a nil pool (the listeners) cannot dial out.
+		return nil, fmt.Errorf("no connection pool for %s", c.address)
+	}
+	conn, err := c.pool.Get()
 	if err != nil {
 		return conn, err
 	}
@@ -91,7 +103,7 @@ func (c *Connection) CallAsync(msg FrameType, payload []byte, ch chan<- FrameWit
 	}(MakeFrameWithPayload(msg, 1, payload))
 
 	c.count++
-	return conn, nil
+	return conn, err
 }
 
 func (c *Connection) Call(msg FrameType, data []byte) (*FrameWithPayload, error) {
@@ -102,15 +114,19 @@ func (c *Connection) Call(msg FrameType, data []byte) (*FrameWithPayload, error)
 		return nil, err
 	}
 
-	defer conn.Close()
-	defer close(ch)
-
+	// ch is deliberately left open: CallAsync received it as a send-only
+	// channel, and a late send on a closed channel would panic.
 	select {
 	case ret := <-ch:
+		c.pool.Put(conn, nil)
 		return &ret, nil
 	case <-time.After(MaxCallDuration):
-		return nil, fmt.Errorf("timeout")
+		// NOTE(review): assumes netpool discards a connection returned with a
+		// non-nil error, so a socket with a stale reply is not reused - confirm.
+		err = fmt.Errorf("timeout waiting for frame")
+		c.pool.Put(conn, err)
+		return nil, err
 	}
 }
 
 func WaitForFrame(conn net.Conn, resultChan chan<- FrameWithPayload) error {
@@ -125,6 +141,8 @@ func WaitForFrame(conn net.Conn, resultChan chan<- FrameWithPayload) error {
 	payload := make([]byte, frame.Length)
-	_, err = conn.Read(payload)
+	// ReadFull: a single Read may deliver only part of the payload.
+	_, err = io.ReadFull(conn, payload)
 	if err != nil {
+		conn.Close()
 		return err
 	}
 	resultChan <- FrameWithPayload{
@@ -154,7 +172,8 @@ func (c *Connection) Listen() (*GenericListener, error) {
 		for !ret.StopListener {
 			connection, err := l.Accept()
 			if err != nil {
-				log.Fatalf("Error accepting connection: %v\n", err)
+				log.Printf("Error accepting connection: %v\n", err)
+				continue
 			}
 			go ret.HandleConnection(connection)
 		}
@@ -163,22 +182,42 @@ func (c *Connection) Listen() (*GenericListener, error) {
 }
 
 const (
-	MaxCallDuration = 500 * time.Millisecond
+	MaxCallDuration   = 300 * time.Millisecond
+	ListenerKeepalive = 5 * time.Second
 )
 
+// HandleConnection serves one accepted connection: it reads frames from conn
+// and passes each valid one to HandleFrame until a read fails, then closes
+// conn.
 func (l *GenericListener) HandleConnection(conn net.Conn) {
-	ch := make(chan FrameWithPayload, 1)
-	conn.SetReadDeadline(time.Now().Add(MaxCallDuration))
-	go WaitForFrame(conn, ch)
-	select {
-	case frame := <-ch:
-		err := l.HandleFrame(conn, &frame)
-		if err != nil {
-			log.Fatalf("Error in handler: %v\n", err)
-		}
-	case <-time.After(MaxCallDuration):
-		close(ch)
-		log.Printf("Timeout waiting for frame\n")
-	}
+	defer conn.Close()
+	var frame Frame
+	b := bufio.NewReader(conn)
+	for {
+		if err := binary.Read(b, binary.LittleEndian, &frame); err != nil {
+			// Every read error ends the session, not only io.EOF: retrying
+			// on a reset or truncated stream would spin on the same error.
+			if err != io.EOF {
+				log.Printf("Error reading frame: %v\n", err)
+			}
+			return
+		}
+		if !frame.IsValid() {
+			continue
+		}
+		payload := make([]byte, frame.Length)
+		// ReadFull: a single Read may deliver only part of the payload.
+		if _, err := io.ReadFull(b, payload); err != nil {
+			log.Printf("Error reading payload: %v\n", err)
+			return
+		}
+		err := l.HandleFrame(conn, &FrameWithPayload{
+			Frame:   frame,
+			Payload: payload,
+		})
+		if err != nil {
+			log.Printf("Error in handler: %v\n", err)
+		}
+	}
 }
 
@@ -194,13 +233,17 @@ func (l *GenericListener) HandleFrame(conn net.Conn, frame *FrameWithPayload) er
 			defer close(resultChan)
 			err := handler(frame, resultChan)
 			if err != nil {
-				resultChan <- MakeFrameWithPayload(frame.Type, 500, []byte(err.Error()))
+				errFrame := MakeFrameWithPayload(frame.Type, 500, []byte(err.Error()))
+				if sendErr := SendFrame(conn, &errFrame); sendErr != nil {
+					log.Printf("Error sending error frame: %s", sendErr)
+				}
 				log.Printf("Handler returned error: %s", err)
+				return
 			}
 			result := <-resultChan
 			err = SendFrame(conn, &result)
 			if err != nil {
-				log.Fatalf("Error sending frame: %v\n", err)
+				log.Printf("Error sending frame: %s", err)
 			}
 		}()
 	} else {
diff --git a/tcp-connection_test.go b/tcp-connection_test.go
index 331295d..7f72270 100644
--- a/tcp-connection_test.go
+++ b/tcp-connection_test.go
@@ -2,15 +2,27 @@ package main
 
 import (
 	"fmt"
+	"net"
 	"testing"
+
+	"github.com/yudhasubki/netpool"
 )
 
 func TestGenericConnection(t *testing.T) {
-	conn := NewConnection("localhost:51337")
-	listener, err := conn.Listen()
+
+	listenConn := NewConnection("127.0.0.1:51337", nil)
+	listener, err := listenConn.Listen()
 	if err != nil {
-		t.Errorf("Error listening: %v\n", err)
+		t.Fatalf("Error listening: %v\n", err)
 	}
+	pool, err := netpool.New(func() (net.Conn, error) {
+		return net.Dial("tcp", "127.0.0.1:51337")
+	}, netpool.WithMaxPool(512), netpool.WithMinPool(5))
+	if err != nil {
+		t.Fatalf("Error creating pool: %v\n", err)
+	}
+	conn := NewConnection("127.0.0.1:51337", pool)
+
 	datta := []byte("Hello, world!")
 	listener.AddHandler(Ping, func(input *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
 		resultChan <- MakeFrameWithPayload(Pong, 200, nil)
@@ -30,11 +42,12 @@ func TestGenericConnection(t *testing.T) {
 	if r.Type != 2 {
 		t.Errorf("Expected type 2, got %d\n", r.Type)
 	}
-	response, err := conn.Call(Ping, nil)
-	if err != nil || response.StatusCode != 200 || response.Type != Pong {
-		t.Errorf("Error connecting to remote %v\n", response)
-	}
+
 	res, err := conn.Call(3, datta)
+	if err != nil {
+		t.Errorf("Did not expect error, got %v\n", err)
+		return
+	}
 	if res.StatusCode == 200 {
 		t.Errorf("Expected error, got %v\n", res)
 	}
@@ -53,4 +66,9 @@ func TestGenericConnection(t *testing.T) {
 		i++
 	}
 
+	response, err := conn.Call(Ping, nil)
+	if err != nil || response.StatusCode != 200 || response.Type != Pong {
+		t.Errorf("Error connecting to remote %v, err: %v\n", response, err)
+	}
+
 }