tcp mux and stuff
All checks were successful
Build and Publish / BuildAndDeploy (push) Successful in 1m49s
All checks were successful
Build and Publish / BuildAndDeploy (push) Successful in 1m49s
This commit is contained in:
17
packet.go
17
packet.go
@@ -63,12 +63,15 @@ func SendPacket(conn io.Writer, messageType uint16, datafn func(w io.Writer) err
|
||||
}
|
||||
|
||||
func SendRawResponse(conn io.Writer, data []byte) error {
|
||||
binary.Write(conn, binary.LittleEndian, Packet{
|
||||
err := binary.Write(conn, binary.LittleEndian, Packet{
|
||||
Version: 1,
|
||||
MessageType: ResponseBody,
|
||||
DataLength: uint16(len(data)),
|
||||
})
|
||||
_, err := conn.Write(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = conn.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -90,6 +93,12 @@ func ReceivePacket(conn io.Reader) (uint16, []byte, error) {
|
||||
return packet.MessageType, nil, err
|
||||
}
|
||||
data := make([]byte, packet.DataLength)
|
||||
_, err = conn.Read(data)
|
||||
return packet.MessageType, data, err
|
||||
l, err := conn.Read(data)
|
||||
if err != nil {
|
||||
return packet.MessageType, nil, err
|
||||
}
|
||||
if l != int(packet.DataLength) {
|
||||
return packet.MessageType, nil, fmt.Errorf("expected %d bytes, got %d", packet.DataLength, l)
|
||||
}
|
||||
return packet.MessageType, data, nil
|
||||
}
|
||||
|
||||
489
synced-pool.go
489
synced-pool.go
@@ -1,11 +1,8 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -20,149 +17,22 @@ type Quorum interface {
|
||||
}
|
||||
|
||||
type RemoteHost struct {
|
||||
net.Conn
|
||||
*PacketQueue
|
||||
*Client
|
||||
Host string
|
||||
MissedPings int
|
||||
Pool *RemoteGrainPool
|
||||
}
|
||||
|
||||
type SyncedPool struct {
|
||||
mu sync.RWMutex
|
||||
Discovery Discovery
|
||||
listener net.Listener
|
||||
*Server
|
||||
mu sync.RWMutex
|
||||
//Discovery Discovery
|
||||
Hostname string
|
||||
local *GrainLocalPool
|
||||
remotes []*RemoteHost
|
||||
remoteIndex map[CartId]*RemoteGrainPool
|
||||
}
|
||||
|
||||
func NewSyncedPool(local *GrainLocalPool, hostname string, d Discovery) (*SyncedPool, error) {
|
||||
listen := fmt.Sprintf("%s:1338", hostname)
|
||||
l, err := net.Listen("tcp", listen)
|
||||
log.Printf("Listening on %s", listen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pool := &SyncedPool{
|
||||
Discovery: d,
|
||||
Hostname: hostname,
|
||||
local: local,
|
||||
listener: l,
|
||||
remotes: make([]*RemoteHost, 0),
|
||||
remoteIndex: make(map[CartId]*RemoteGrainPool),
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
for range time.Tick(time.Second * 2) {
|
||||
for _, r := range pool.remotes {
|
||||
err := DoPing(r)
|
||||
if err != nil {
|
||||
r.MissedPings++
|
||||
log.Printf("Error pinging remote %s: %v\n, missed pings: %d", r.Host, err, r.MissedPings)
|
||||
if r.MissedPings > 3 {
|
||||
log.Printf("Removing remote %s\n", r.Host)
|
||||
go pool.RemoveHost(r)
|
||||
//pool.remotes = append(pool.remotes[:i], pool.remotes[i+1:]...)
|
||||
|
||||
}
|
||||
} else {
|
||||
r.MissedPings = 0
|
||||
}
|
||||
}
|
||||
connectedRemotes.Set(float64(len(pool.remotes)))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if d != nil {
|
||||
go func() {
|
||||
ch, err := d.Watch()
|
||||
if err != nil {
|
||||
log.Printf("Error discovering hosts: %v", err)
|
||||
return
|
||||
}
|
||||
for host := range ch {
|
||||
if pool.IsKnown(host) {
|
||||
continue
|
||||
}
|
||||
go func(h string) {
|
||||
log.Printf("Discovered host %s, waiting for startup", h)
|
||||
time.Sleep(time.Second)
|
||||
err := pool.AddRemote(h)
|
||||
if err != nil {
|
||||
log.Printf("Error adding remote %s: %v", h, err)
|
||||
}
|
||||
}(host)
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
log.Printf("No discovery, waiting for remotes to connect")
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
log.Printf("Error accepting connection: %v\n", err)
|
||||
continue
|
||||
}
|
||||
log.Printf("Got connection from %s", conn.RemoteAddr())
|
||||
|
||||
go pool.handleConnection(conn)
|
||||
}
|
||||
}()
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
func (p *SyncedPool) IsKnown(host string) bool {
|
||||
for _, r := range p.remotes {
|
||||
if r.Host == host {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return host != p.Hostname
|
||||
}
|
||||
|
||||
func (p *SyncedPool) ExcludeKnown(hosts []string) []string {
|
||||
ret := make([]string, 0, len(hosts))
|
||||
for _, h := range hosts {
|
||||
found := false
|
||||
for _, r := range p.remotes {
|
||||
if r.Host == h {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found && h != p.Hostname {
|
||||
ret = append(ret, h)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (p *SyncedPool) RemoveHost(host *RemoteHost) {
|
||||
|
||||
for i, r := range p.remotes {
|
||||
if r == host {
|
||||
p.RemoveHostMappedCarts(r)
|
||||
p.remotes = append(p.remotes[:i], p.remotes[i+1:]...)
|
||||
connectedRemotes.Set(float64(len(p.remotes)))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *SyncedPool) RemoveHostMappedCarts(host *RemoteHost) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
for id, r := range p.remoteIndex {
|
||||
if r == host.Pool {
|
||||
delete(p.remoteIndex, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
negotiationCount = promauto.NewCounter(prometheus.CounterOpts{
|
||||
Name: "cart_remote_negotiation_total",
|
||||
@@ -194,148 +64,207 @@ var (
|
||||
})
|
||||
)
|
||||
|
||||
func (p *SyncedPool) PongHandler(data []byte) (uint16, []byte, error) {
|
||||
return Pong, data, nil
|
||||
}
|
||||
|
||||
func (p *SyncedPool) GetCartIdHandler(data []byte) (uint16, []byte, error) {
|
||||
ids := make([]string, 0, len(p.local.grains))
|
||||
for id := range p.local.grains {
|
||||
ids = append(ids, id.String())
|
||||
}
|
||||
return CartIdsResponse, []byte(strings.Join(ids, ";")), nil
|
||||
}
|
||||
|
||||
func (p *SyncedPool) NegotiateHandler(data []byte) (uint16, []byte, error) {
|
||||
negotiationCount.Inc()
|
||||
log.Printf("Handling negotiation\n")
|
||||
for _, host := range p.ExcludeKnown(strings.Split(string(data), ";")) {
|
||||
err := p.AddRemote(host)
|
||||
if err != nil {
|
||||
log.Printf("Error adding remote %s: %v\n", host, err)
|
||||
}
|
||||
}
|
||||
|
||||
return RemoteNegotiateResponse, []byte("ok"), nil
|
||||
}
|
||||
|
||||
func (p *SyncedPool) GrainOwnerChangeHandler(data []byte) (uint16, []byte, error) {
|
||||
grainSyncCount.Inc()
|
||||
|
||||
idAndHostParts := strings.Split(string(data), ";")
|
||||
if len(idAndHostParts) != 2 {
|
||||
log.Printf("Invalid remote grain change message\n")
|
||||
return AckChange, []byte("incorrect"), nil
|
||||
}
|
||||
|
||||
for _, r := range p.remotes {
|
||||
if r.Host == string(idAndHostParts[1]) {
|
||||
|
||||
log.Printf("Remote grain %s changed to %s\n", idAndHostParts[0], idAndHostParts[1])
|
||||
p.mu.Lock()
|
||||
if p.local.grains[ToCartId(idAndHostParts[0])] != nil {
|
||||
log.Printf("Grain %s already exists locally, deleting\n", idAndHostParts[0])
|
||||
delete(p.local.grains, ToCartId(idAndHostParts[0]))
|
||||
}
|
||||
p.remoteIndex[ToCartId(idAndHostParts[0])] = r.Pool
|
||||
p.mu.Unlock()
|
||||
return AckChange, []byte("ok"), nil
|
||||
}
|
||||
}
|
||||
return AckChange, []byte("not found"), nil
|
||||
}
|
||||
|
||||
func NewSyncedPool(local *GrainLocalPool, hostname string, discovery Discovery) (*SyncedPool, error) {
|
||||
listen := fmt.Sprintf("%s:1338", hostname)
|
||||
|
||||
server, err := Listen(listen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Printf("Listening on %s", listen)
|
||||
|
||||
pool := &SyncedPool{
|
||||
Server: server,
|
||||
//Discovery: discovery,
|
||||
Hostname: hostname,
|
||||
local: local,
|
||||
|
||||
remotes: make([]*RemoteHost, 0),
|
||||
remoteIndex: make(map[CartId]*RemoteGrainPool),
|
||||
}
|
||||
|
||||
server.HandleCall(Ping, pool.PongHandler)
|
||||
server.HandleCall(GetCartIds, pool.GetCartIdHandler)
|
||||
server.HandleCall(RemoteNegotiate, pool.NegotiateHandler)
|
||||
server.HandleCall(RemoteGrainChanged, pool.GrainOwnerChangeHandler)
|
||||
|
||||
// // TODO FIX THIS, ONLY CLIENT OR SERVER SHOULD PING
|
||||
// go func() {
|
||||
// for {
|
||||
// for range time.Tick(time.Second * 2) {
|
||||
// for _, r := range pool.remotes {
|
||||
// err := DoPing(r)
|
||||
// if err != nil {
|
||||
// r.MissedPings++
|
||||
// log.Printf("Error pinging remote %s: %v\n, missed pings: %d", r.Host, err, r.MissedPings)
|
||||
// if r.MissedPings > 3 {
|
||||
// log.Printf("Removing remote %s\n", r.Host)
|
||||
// go pool.RemoveHost(r)
|
||||
// //pool.remotes = append(pool.remotes[:i], pool.remotes[i+1:]...)
|
||||
|
||||
// }
|
||||
// } else {
|
||||
// r.MissedPings = 0
|
||||
// }
|
||||
// }
|
||||
// connectedRemotes.Set(float64(len(pool.remotes)))
|
||||
// }
|
||||
// }
|
||||
// }()
|
||||
|
||||
if discovery != nil {
|
||||
go func() {
|
||||
ch, err := discovery.Watch()
|
||||
if err != nil {
|
||||
log.Printf("Error discovering hosts: %v", err)
|
||||
return
|
||||
}
|
||||
for host := range ch {
|
||||
if pool.IsKnown(host) {
|
||||
continue
|
||||
}
|
||||
go func(h string) {
|
||||
log.Printf("Discovered host %s, waiting for startup", h)
|
||||
time.Sleep(time.Second)
|
||||
err := pool.AddRemote(h)
|
||||
if err != nil {
|
||||
log.Printf("Error adding remote %s: %v", h, err)
|
||||
}
|
||||
}(host)
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
log.Printf("No discovery, waiting for remotes to connect")
|
||||
}
|
||||
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
func (p *SyncedPool) IsKnown(host string) bool {
|
||||
for _, r := range p.remotes {
|
||||
if r.Host == host {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return host != p.Hostname
|
||||
}
|
||||
|
||||
func (p *SyncedPool) ExcludeKnown(hosts []string) []string {
|
||||
ret := make([]string, 0, len(hosts))
|
||||
for _, h := range hosts {
|
||||
if !p.IsKnown(h) {
|
||||
ret = append(ret, h)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (p *SyncedPool) RemoveHost(host *RemoteHost) {
|
||||
for i, r := range p.remotes {
|
||||
if r == host {
|
||||
p.RemoveHostMappedCarts(r)
|
||||
p.remotes = append(p.remotes[:i], p.remotes[i+1:]...)
|
||||
connectedRemotes.Set(float64(len(p.remotes)))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *SyncedPool) RemoveHostMappedCarts(host *RemoteHost) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
for id, r := range p.remoteIndex {
|
||||
if r == host.Pool {
|
||||
delete(p.remoteIndex, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
RemoteNegotiate = uint16(3)
|
||||
RemoteGrainChanged = uint16(4)
|
||||
AckChange = uint16(5)
|
||||
//AckError = uint16(6)
|
||||
Ping = uint16(7)
|
||||
Pong = uint16(8)
|
||||
GetCartIds = uint16(9)
|
||||
CartIdsResponse = uint16(10)
|
||||
Ping = uint16(7)
|
||||
Pong = uint16(8)
|
||||
GetCartIds = uint16(9)
|
||||
CartIdsResponse = uint16(10)
|
||||
RemoteNegotiateResponse = uint16(11)
|
||||
)
|
||||
|
||||
func (p *SyncedPool) handleConnection(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
var packet Packet
|
||||
for {
|
||||
err := binary.Read(conn, binary.LittleEndian, &packet)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
log.Printf("Error in connection: %v\n", err)
|
||||
}
|
||||
// if packet.Version != 1 {
|
||||
// log.Printf("Invalid version %d\n", packet.Version)
|
||||
// return
|
||||
// }
|
||||
switch packet.MessageType {
|
||||
case Ping:
|
||||
err = SendPacket(conn, Pong, func(w io.Writer) error {
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Error sending pong: %v\n", err)
|
||||
}
|
||||
case RemoteNegotiate:
|
||||
negotiationCount.Inc()
|
||||
data := make([]byte, packet.DataLength)
|
||||
conn.Read(data)
|
||||
knownHosts := strings.Split(string(data), ";")
|
||||
log.Printf("Negotiated with remote, found %v hosts\n", knownHosts)
|
||||
|
||||
SendPacket(conn, RemoteNegotiate, func(w io.Writer) error {
|
||||
hostnames := make([]string, 0, len(p.remotes))
|
||||
for _, r := range p.remotes {
|
||||
hostnames = append(hostnames, r.Host)
|
||||
}
|
||||
w.Write([]byte(strings.Join(hostnames, ";")))
|
||||
return nil
|
||||
})
|
||||
for _, h := range knownHosts {
|
||||
err = p.AddRemote(h)
|
||||
if err != nil {
|
||||
log.Printf("Error adding remote %s: %v\n", h, err)
|
||||
}
|
||||
}
|
||||
case RemoteGrainChanged:
|
||||
// remote grain changed
|
||||
grainSyncCount.Inc()
|
||||
|
||||
idAndHost := make([]byte, packet.DataLength)
|
||||
_, err = conn.Read(idAndHost)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
idAndHostParts := strings.Split(string(idAndHost), ";")
|
||||
if len(idAndHostParts) != 2 {
|
||||
log.Printf("Invalid remote grain change message\n")
|
||||
break
|
||||
}
|
||||
found := false
|
||||
for _, r := range p.remotes {
|
||||
if r.Host == string(idAndHostParts[1]) {
|
||||
found = true
|
||||
log.Printf("Remote grain %s changed to %s\n", idAndHostParts[0], idAndHostParts[1])
|
||||
p.mu.Lock()
|
||||
if p.local.grains[ToCartId(idAndHostParts[0])] != nil {
|
||||
log.Printf("Grain %s already exists locally, deleting\n", idAndHostParts[0])
|
||||
delete(p.local.grains, ToCartId(idAndHostParts[0]))
|
||||
}
|
||||
p.remoteIndex[ToCartId(idAndHostParts[0])] = r.Pool
|
||||
p.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
log.Printf("Remote host %s not found\n", idAndHostParts[1])
|
||||
} else {
|
||||
SendPacket(conn, AckChange, func(w io.Writer) error {
|
||||
_, err := w.Write([]byte("ok"))
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
case GetCartIds:
|
||||
ids := make([]string, 0, len(p.local.grains))
|
||||
for id := range p.local.grains {
|
||||
ids = append(ids, id.String())
|
||||
}
|
||||
SendPacket(conn, CartIdsResponse, func(w io.Writer) error {
|
||||
_, err := w.Write([]byte(strings.Join(ids, ";")))
|
||||
return err
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *RemoteHost) Negotiate(knownHosts []string) ([]string, error) {
|
||||
err := SendPacket(h.connection, RemoteNegotiate, func(w io.Writer) error {
|
||||
w.Write([]byte(strings.Join(knownHosts, ";")))
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
packet, err := h.Expect(RemoteNegotiate, time.Second)
|
||||
data, err := h.Call(RemoteNegotiate, RemoteNegotiateResponse, []byte(strings.Join(knownHosts, ";")))
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return strings.Split(string(packet.Data), ";"), nil
|
||||
return strings.Split(string(data), ";"), nil
|
||||
}
|
||||
|
||||
func (g *RemoteHost) GetCartMappings() []CartId {
|
||||
err := SendPacket(g.connection, GetCartIds, func(w io.Writer) error {
|
||||
return nil
|
||||
})
|
||||
func (g *RemoteHost) GetCartMappings() ([]CartId, error) {
|
||||
data, err := g.Call(GetCartIds, CartIdsResponse, nil)
|
||||
if err != nil {
|
||||
log.Printf("Error getting mappings: %v\n", err)
|
||||
return nil
|
||||
return nil, err
|
||||
}
|
||||
packet, err := g.Expect(CartIdsResponse, time.Second*3)
|
||||
if err != nil {
|
||||
log.Printf("Error getting mappings: %v\n", err)
|
||||
return nil
|
||||
}
|
||||
parts := strings.Split(string(packet.Data), ";")
|
||||
parts := strings.Split(string(data), ";")
|
||||
ids := make([]CartId, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
ids = append(ids, ToCartId(p))
|
||||
}
|
||||
return ids
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (p *SyncedPool) Negotiate(knownHosts []string) ([]string, error) {
|
||||
@@ -357,18 +286,14 @@ func (p *SyncedPool) Negotiate(knownHosts []string) ([]string, error) {
|
||||
}
|
||||
|
||||
func (r *RemoteHost) ConfirmChange(id CartId, host string) error {
|
||||
err := SendPacket(r.connection, RemoteGrainChanged, func(w io.Writer) error {
|
||||
_, err := w.Write([]byte(fmt.Sprintf("%s;%s", id, host)))
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = r.Expect(AckChange, time.Second)
|
||||
data, err := r.Call(RemoteGrainChanged, AckChange, []byte(fmt.Sprintf("%s;%s", id, host)))
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if string(data) != "ok" {
|
||||
return fmt.Errorf("remote grain change failed %s", string(data))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -386,23 +311,6 @@ func (p *SyncedPool) RequestOwnership(id CartId) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func DoPing(host *RemoteHost) error {
|
||||
|
||||
err := SendPacket(host, Ping, func(w io.Writer) error {
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = host.Expect(Pong, time.Second)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SyncedPool) addRemoteHost(address string, remote *RemoteHost) error {
|
||||
known := make([]string, 0, len(p.remotes))
|
||||
for _, r := range p.remotes {
|
||||
@@ -413,18 +321,17 @@ func (p *SyncedPool) addRemoteHost(address string, remote *RemoteHost) error {
|
||||
}
|
||||
}
|
||||
|
||||
err := DoPing(remote)
|
||||
if err != nil {
|
||||
log.Printf("Error pinging remote %s: %v\n", address, err)
|
||||
}
|
||||
|
||||
p.remotes = append(p.remotes, remote)
|
||||
connectedRemotes.Set(float64(len(p.remotes)))
|
||||
log.Printf("Added remote %s\n", remote.Host)
|
||||
|
||||
go func() {
|
||||
p.Negotiate(known)
|
||||
ids := remote.GetCartMappings()
|
||||
ids, err := remote.GetCartMappings()
|
||||
if err != nil {
|
||||
log.Printf("Error getting remote mappings: %v\n", err)
|
||||
return
|
||||
}
|
||||
p.mu.Lock()
|
||||
for _, id := range ids {
|
||||
if p.local.grains[id] != nil {
|
||||
@@ -442,7 +349,8 @@ func (p *SyncedPool) AddRemote(address string) error {
|
||||
if address == "" || p.IsKnown(address) {
|
||||
return nil
|
||||
}
|
||||
connection, err := net.Dial("tcp", fmt.Sprintf("%s:1338", address))
|
||||
client, err := Dial(fmt.Sprintf("%s:1338", address))
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Error connecting to remote %s: %v\n", address, err)
|
||||
return err
|
||||
@@ -450,10 +358,9 @@ func (p *SyncedPool) AddRemote(address string) error {
|
||||
|
||||
pool := NewRemoteGrainPool(address)
|
||||
remote := RemoteHost{
|
||||
Conn: connection,
|
||||
PacketQueue: NewPacketQueue(connection),
|
||||
Pool: pool,
|
||||
Host: address,
|
||||
Client: client,
|
||||
Pool: pool,
|
||||
Host: address,
|
||||
}
|
||||
|
||||
return p.addRemoteHost(address, &remote)
|
||||
|
||||
@@ -32,4 +32,14 @@ func TestConnection(t *testing.T) {
|
||||
if len(allHosts) != 1 {
|
||||
t.Errorf("Expected 1 host, got %d", len(allHosts))
|
||||
}
|
||||
|
||||
data, err := pool.Get(ToCartId("kalle"))
|
||||
if err != nil {
|
||||
t.Errorf("Error getting data: %v", err)
|
||||
}
|
||||
if data == nil {
|
||||
t.Errorf("Expected data, got nil")
|
||||
}
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
|
||||
}
|
||||
|
||||
76
tcp-mux-client.go
Normal file
76
tcp-mux-client.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
*TCPClientMux
|
||||
}
|
||||
|
||||
func Dial(address string) (*Client, error) {
|
||||
conn, err := net.Dial("tcp", address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client := &Client{
|
||||
TCPClientMux: NewTCPClientMux(conn),
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *Client) Close() {
|
||||
c.Conn.Close()
|
||||
}
|
||||
|
||||
type TCPClientMux struct {
|
||||
net.Conn
|
||||
*PacketQueue
|
||||
}
|
||||
|
||||
func NewTCPClientMux(connection net.Conn) *TCPClientMux {
|
||||
return &TCPClientMux{
|
||||
Conn: connection,
|
||||
PacketQueue: NewPacketQueue(connection),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *TCPClientMux) Close() {
|
||||
m.Conn.Close()
|
||||
}
|
||||
|
||||
func (m *TCPClientMux) SendPacket(messageType uint16, data []byte) error {
|
||||
err := binary.Write(m.Conn, binary.LittleEndian, Packet{
|
||||
Version: 1,
|
||||
MessageType: messageType,
|
||||
DataLength: uint16(len(data)),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = m.Conn.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *TCPClientMux) SendPacketFn(messageType uint16, datafn func(w io.Writer) error) error {
|
||||
data, err := GetData(datafn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return m.SendPacket(messageType, data)
|
||||
}
|
||||
|
||||
func (m *TCPClientMux) Call(messageType uint16, responseType uint16, data []byte) ([]byte, error) {
|
||||
err := m.SendPacket(messageType, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
packet, err := m.Expect(responseType, time.Second)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return packet.Data, nil
|
||||
}
|
||||
133
tcp-mux-server.go
Normal file
133
tcp-mux-server.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
*TCPServerMux
|
||||
}
|
||||
|
||||
func Listen(address string) (*Server, error) {
|
||||
listener, err := net.Listen("tcp", address)
|
||||
server := &Server{
|
||||
NewTCPServerMux(100),
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Printf("Error accepting connection: %v\n", err)
|
||||
continue
|
||||
}
|
||||
go server.HandleConnection(conn)
|
||||
}
|
||||
}()
|
||||
return server, nil
|
||||
}
|
||||
|
||||
type TCPServerMux struct {
|
||||
mu sync.RWMutex
|
||||
listeners map[uint16]func(data []byte) error
|
||||
functions map[uint16]func(data []byte) (uint16, []byte, error)
|
||||
connections []net.Conn
|
||||
}
|
||||
|
||||
func NewTCPServerMux(maxClients int) *TCPServerMux {
|
||||
m := &TCPServerMux{
|
||||
connections: make([]net.Conn, 0, maxClients),
|
||||
mu: sync.RWMutex{},
|
||||
listeners: make(map[uint16]func(data []byte) error),
|
||||
functions: make(map[uint16]func(data []byte) (uint16, []byte, error)),
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *TCPServerMux) handleListener(messageType uint16, data []byte) (bool, error) {
|
||||
m.mu.RLock()
|
||||
handler, ok := m.listeners[messageType]
|
||||
m.mu.RUnlock()
|
||||
if ok {
|
||||
err := handler(data)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *TCPServerMux) handleFunction(connection net.Conn, messageType uint16, data []byte) (bool, error) {
|
||||
m.mu.RLock()
|
||||
function, ok := m.functions[messageType]
|
||||
m.mu.RUnlock()
|
||||
if ok {
|
||||
responseType, responseData, err := function(data)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
err = binary.Write(connection, binary.LittleEndian, Packet{
|
||||
Version: 1,
|
||||
MessageType: responseType,
|
||||
DataLength: uint16(len(responseData)),
|
||||
})
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
packetsSent.Inc()
|
||||
_, err = connection.Write(responseData)
|
||||
return true, err
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *TCPServerMux) HandleConnection(connection net.Conn) error {
|
||||
m.mu.Lock()
|
||||
m.connections = append(m.connections, connection)
|
||||
m.mu.Unlock()
|
||||
defer connection.Close()
|
||||
for {
|
||||
messageType, data, err := ReceivePacket(connection)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
log.Printf("Error receiving packet: %v\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
status, err := m.handleListener(messageType, data)
|
||||
if err != nil {
|
||||
log.Printf("Error handling listener: %v\n", err)
|
||||
}
|
||||
if !status {
|
||||
status, err = m.handleFunction(connection, messageType, data)
|
||||
if err != nil {
|
||||
log.Printf("Error handling function: %v\n", err)
|
||||
}
|
||||
if !status {
|
||||
log.Printf("Unknown message type: %d\n", messageType)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *TCPServerMux) ListenFor(messageType uint16, handler func(data []byte) error) {
|
||||
m.mu.Lock()
|
||||
m.listeners[messageType] = handler
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *TCPServerMux) HandleCall(messageType uint16, handler func(data []byte) (uint16, []byte, error)) {
|
||||
m.mu.Lock()
|
||||
m.functions[messageType] = handler
|
||||
m.mu.Unlock()
|
||||
}
|
||||
40
tcp-mux_test.go
Normal file
40
tcp-mux_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTcpHelpers(t *testing.T) {
|
||||
|
||||
server, err := Listen(":1337")
|
||||
if err != nil {
|
||||
t.Errorf("Error listening: %v\n", err)
|
||||
}
|
||||
client, err := Dial("localhost:1337")
|
||||
if err != nil {
|
||||
t.Errorf("Error dialing: %v\n", err)
|
||||
}
|
||||
var messageData string
|
||||
server.ListenFor(1, func(data []byte) error {
|
||||
log.Printf("Received message: %s\n", string(data))
|
||||
messageData = string(data)
|
||||
return nil
|
||||
})
|
||||
server.HandleCall(2, func(data []byte) (uint16, []byte, error) {
|
||||
log.Printf("Received call: %s\n", string(data))
|
||||
return 3, []byte("Hello, client!"), nil
|
||||
})
|
||||
|
||||
client.SendPacket(1, []byte("Hello, world!"))
|
||||
answer, err := client.Call(2, 3, []byte("Hello, server!"))
|
||||
if err != nil {
|
||||
t.Errorf("Error calling: %v\n", err)
|
||||
}
|
||||
if string(answer) != "Hello, client!" {
|
||||
t.Errorf("Expected answer 'Hello, client!', got %s\n", string(answer))
|
||||
}
|
||||
if messageData != "Hello, world!" {
|
||||
t.Errorf("Expected message 'Hello, world!', got %s\n", messageData)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user