major refactor
All checks were successful
Build and Publish / BuildAndDeployAmd64 (push) Successful in 28s
Build and Publish / BuildAndDeploy (push) Successful in 2m18s

This commit is contained in:
matst80
2024-11-13 21:56:40 +01:00
parent 9f7c8227c2
commit abf561c3fe
20 changed files with 310 additions and 1292 deletions

View File

@@ -54,8 +54,8 @@ type CartGrain struct {
type Grain interface { type Grain interface {
GetId() CartId GetId() CartId
HandleMessage(message *Message, isReplay bool) (*CallResult, error) HandleMessage(message *Message, isReplay bool) (*FrameWithPayload, error)
GetCurrentState() (*CallResult, error) GetCurrentState() (*FrameWithPayload, error)
} }
func (c *CartGrain) GetId() CartId { func (c *CartGrain) GetId() CartId {
@@ -69,12 +69,14 @@ func (c *CartGrain) GetLastChange() int64 {
return *c.storageMessages[len(c.storageMessages)-1].TimeStamp return *c.storageMessages[len(c.storageMessages)-1].TimeStamp
} }
func (c *CartGrain) GetCurrentState() (*CallResult, error) { func (c *CartGrain) GetCurrentState() (*FrameWithPayload, error) {
result, err := json.Marshal(c) result, err := json.Marshal(c)
return &CallResult{ if err != nil {
StatusCode: 200, ret := MakeFrameWithPayload(0, 400, []byte(err.Error()))
Data: result, return &ret, nil
}, err }
ret := MakeFrameWithPayload(0, 200, result)
return &ret, nil
} }
func getItemData(sku string, qty int) (*messages.AddItem, error) { func getItemData(sku string, qty int) (*messages.AddItem, error) {
@@ -108,7 +110,7 @@ func getItemData(sku string, qty int) (*messages.AddItem, error) {
}, nil }, nil
} }
func (c *CartGrain) AddItem(sku string, qty int) (*CallResult, error) { func (c *CartGrain) AddItem(sku string, qty int) (*FrameWithPayload, error) {
cartItem, err := getItemData(sku, qty) cartItem, err := getItemData(sku, qty)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -180,7 +182,7 @@ func (c *CartGrain) FindItemWithSku(sku string) (*CartItem, bool) {
return nil, false return nil, false
} }
func (c *CartGrain) HandleMessage(message *Message, isReplay bool) (*CallResult, error) { func (c *CartGrain) HandleMessage(message *Message, isReplay bool) (*FrameWithPayload, error) {
if message.TimeStamp == nil { if message.TimeStamp == nil {
now := time.Now().Unix() now := time.Now().Unix()
message.TimeStamp = &now message.TimeStamp = &now
@@ -305,8 +307,10 @@ func (c *CartGrain) HandleMessage(message *Message, isReplay bool) (*CallResult,
c.mu.Unlock() c.mu.Unlock()
} }
result, err := json.Marshal(c) result, err := json.Marshal(c)
return &CallResult{ return &FrameWithPayload{
Frame: Frame{
StatusCode: 200, StatusCode: 200,
Data: result, },
Payload: result,
}, err }, err
} }

View File

@@ -1,163 +0,0 @@
package main
import (
"bufio"
"fmt"
"log"
"sync"
"time"
)
type CartPacketQueue struct {
mu sync.RWMutex
expectedPackages map[CartMessage]*CartListener
}
const CurrentPacketVersion = 2
type CartListener map[CartId]Listener
func NewCartPacketQueue(connection *PersistentConnection) *CartPacketQueue {
queue := &CartPacketQueue{
expectedPackages: make(map[CartMessage]*CartListener),
}
go queue.HandleConnection(connection)
return queue
}
func (p *CartPacketQueue) RemoveListeners() {
p.mu.Lock()
defer p.mu.Unlock()
for _, l := range p.expectedPackages {
for _, l := range *l {
close(l.Chan)
}
}
p.expectedPackages = make(map[CartMessage]*CartListener)
}
func (p *CartPacketQueue) HandleConnection(connection *PersistentConnection) error {
defer p.RemoveListeners()
defer connection.Close()
var packet CartPacket
reader := bufio.NewReader(connection)
for {
err := ReadCartPacket(reader, &packet)
if err != nil {
log.Printf("Error receiving packet: %v\n", err)
return connection.HandleConnectionError(err)
}
if packet.Version != CurrentPacketVersion {
log.Printf("Incorrect version: %v\n", packet.Version)
return connection.HandleConnectionError(fmt.Errorf("incorrect version: %d", packet.Version))
}
if packet.DataLength == 0 {
go p.HandleData(packet.MessageType, packet.Id, CallResult{
StatusCode: packet.StatusCode,
Data: []byte{},
})
continue
}
data, err := GetPacketData(reader, packet.DataLength)
if err != nil {
log.Printf("Error receiving packet data: %v\n", err)
return connection.HandleConnectionError(err)
}
go p.HandleData(packet.MessageType, packet.Id, CallResult{
StatusCode: packet.StatusCode,
Data: data,
})
}
}
func (p *CartPacketQueue) HandleData(t CartMessage, id CartId, data CallResult) {
p.getListener(t, id, func(l *Listener) {
l.Chan <- data
l.Count--
})
// p.mu.Lock()
// defer p.mu.Unlock()
// pl, ok := p.expectedPackages[t]
// if ok {
// l, ok := (*pl)[id]
// if ok {
// l.Chan <- data
// l.Count--
// if l.Count == 0 {
// close(l.Chan)
// delete(*pl, id)
// }
// }
// }
}
func (p *CartPacketQueue) getListener(t CartMessage, id CartId, fn func(*Listener)) {
p.mu.Lock()
defer p.mu.Unlock()
pl, ok := p.expectedPackages[t]
if ok {
l, ok := (*pl)[id]
if ok {
fn(&l)
if l.Count == 0 {
close(l.Chan)
delete(*pl, id)
}
}
}
}
func CallResultWithTimeout(onTimeout func() CallResult) chan CallResult {
ch := make(chan CallResult, 1)
resultCh := make(chan CallResult, 1)
select {
case ret := <-resultCh:
ch <- ret
case <-time.After(300 * time.Millisecond):
ch <- onTimeout()
}
return ch
}
func (p *CartPacketQueue) MakeChannel(messageType CartMessage, id CartId) chan CallResult {
return CallResultWithTimeout(func() CallResult {
p.getListener(messageType, id, func(l *Listener) {
l.Count--
})
return CallResult{
StatusCode: 504,
Data: []byte("timeout cart call"),
}
})
}
func (p *CartPacketQueue) Expect(messageType CartMessage, id CartId) <-chan CallResult {
p.mu.Lock()
defer p.mu.Unlock()
l, ok := p.expectedPackages[messageType]
if ok {
if idl, idOk := (*l)[id]; idOk {
idl.Count++
return idl.Chan
}
ch := p.MakeChannel(messageType, id)
(*l)[id] = Listener{
Chan: ch,
Count: 1,
}
return ch
}
ch := p.MakeChannel(messageType, id)
p.expectedPackages[messageType] = &CartListener{
id: Listener{
Chan: ch,
Count: 1,
},
}
return ch
}

View File

@@ -27,8 +27,8 @@ var (
) )
type GrainPool interface { type GrainPool interface {
Process(id CartId, messages ...Message) (*CallResult, error) Process(id CartId, messages ...Message) (*FrameWithPayload, error)
Get(id CartId) (*CallResult, error) Get(id CartId) (*FrameWithPayload, error)
} }
type Ttl struct { type Ttl struct {
@@ -142,23 +142,29 @@ func (p *GrainLocalPool) GetGrain(id CartId) (*CartGrain, error) {
return grain, err return grain, err
} }
func (p *GrainLocalPool) Process(id CartId, messages ...Message) ([]byte, error) { func (p *GrainLocalPool) Process(id CartId, messages ...Message) (*FrameWithPayload, error) {
grain, err := p.GetGrain(id) grain, err := p.GetGrain(id)
var result *FrameWithPayload
if err == nil && grain != nil { if err == nil && grain != nil {
for _, message := range messages { for _, message := range messages {
_, err = grain.HandleMessage(&message, false) result, err = grain.HandleMessage(&message, false)
} }
} }
if err != nil { if err != nil {
return nil, err return result, err
} }
return json.Marshal(grain) return result, err
} }
func (p *GrainLocalPool) Get(id CartId) ([]byte, error) { func (p *GrainLocalPool) Get(id CartId) (*FrameWithPayload, error) {
grain, err := p.GetGrain(id) grain, err := p.GetGrain(id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return json.Marshal(grain) data, err := json.Marshal(grain)
if err != nil {
return nil, err
}
ret := MakeFrameWithPayload(0, 200, data)
return &ret, nil
} }

View File

@@ -1,122 +0,0 @@
package main
import (
"bufio"
"fmt"
"log"
"sync"
"time"
)
type PacketQueue struct {
mu sync.RWMutex
expectedPackages map[PoolMessage]*Listener
}
type CallResult struct {
StatusCode uint32
Data []byte
}
type Listener struct {
Count int
Chan chan CallResult
}
func NewPacketQueue(connection *PersistentConnection) *PacketQueue {
queue := &PacketQueue{
expectedPackages: make(map[PoolMessage]*Listener),
}
go queue.HandleConnection(connection)
return queue
}
func (p *PacketQueue) RemoveListeners() {
p.mu.Lock()
defer p.mu.Unlock()
for _, l := range p.expectedPackages {
close(l.Chan)
}
p.expectedPackages = make(map[PoolMessage]*Listener)
}
func (p *PacketQueue) HandleConnection(connection *PersistentConnection) error {
defer connection.Close()
defer p.RemoveListeners()
var packet Packet
reader := bufio.NewReader(connection)
for {
err := ReadPacket(reader, &packet)
if err != nil {
return connection.HandleConnectionError(err)
}
if packet.Version != CurrentPacketVersion {
log.Printf("Incorrect packet version: %v\n", packet.Version)
return connection.HandleConnectionError(fmt.Errorf("incorrect packet version: %d", packet.Version))
}
if packet.DataLength == 0 {
go p.HandleData(packet.MessageType, CallResult{
StatusCode: packet.StatusCode,
Data: []byte{},
})
continue
}
data, err := GetPacketData(reader, packet.DataLength)
if err != nil {
log.Printf("Error receiving packet data: %v\n", err)
return connection.HandleConnectionError(err)
} else {
go p.HandleData(packet.MessageType, CallResult{
StatusCode: packet.StatusCode,
Data: data,
})
}
}
}
func (p *PacketQueue) HandleData(t PoolMessage, data CallResult) {
p.mu.Lock()
defer p.mu.Unlock()
l, ok := p.expectedPackages[t]
if ok {
l.Chan <- data
l.Count--
if l.Count == 0 {
close(l.Chan)
delete(p.expectedPackages, t)
}
return
}
}
func (p *PacketQueue) Expect(messageType PoolMessage) <-chan CallResult {
p.mu.Lock()
defer p.mu.Unlock()
l, ok := p.expectedPackages[messageType]
if ok {
l.Count++
return l.Chan
}
ch := make(chan CallResult, 1)
go func() {
time.Sleep(time.Millisecond * 300)
p.mu.Lock()
defer p.mu.Unlock()
ch <- CallResult{
StatusCode: 504,
Data: []byte("timeout cart call"),
}
close(ch)
}()
p.expectedPackages[messageType] = &Listener{
Count: 1,
Chan: ch,
}
return ch
}

View File

@@ -1,28 +0,0 @@
package main
import (
"testing"
"time"
)
func TestQueue(t *testing.T) {
localPool := NewGrainLocalPool(100, time.Minute, func(id CartId) (*CartGrain, error) {
return &CartGrain{
Id: id,
storageMessages: []Message{},
Items: []*CartItem{},
TotalPrice: 0,
}, nil
})
pool, err := NewSyncedPool(localPool, "localhost", nil)
if err != nil {
t.Errorf("Error creating pool: %v", err)
}
err = pool.AddRemote("localhost")
if err != nil {
t.Errorf("Error adding remote: %v", err)
return
}
}

130
packet.go
View File

@@ -1,85 +1,77 @@
package main package main
import (
"encoding/binary"
"io"
)
type CartMessage uint32
type PackageVersion uint32
const ( const (
RemoteGetState = CartMessage(0x01) RemoteGetState = FrameType(0x01)
RemoteHandleMutation = CartMessage(0x02) RemoteHandleMutation = FrameType(0x02)
ResponseBody = CartMessage(0x03) ResponseBody = FrameType(0x03)
RemoteGetStateReply = CartMessage(0x04) RemoteGetStateReply = FrameType(0x04)
RemoteHandleMutationReply = CartMessage(0x05) RemoteHandleMutationReply = FrameType(0x05)
) )
type CartPacket struct { // type CartPacket struct {
Version PackageVersion // Version PackageVersion
MessageType CartMessage // MessageType CartMessage
DataLength uint32 // DataLength uint32
StatusCode uint32 // StatusCode uint32
Id CartId // Id CartId
} // }
type Packet struct { // type Packet struct {
Version PackageVersion // Version PackageVersion
MessageType PoolMessage // MessageType PoolMessage
DataLength uint32 // DataLength uint32
StatusCode uint32 // StatusCode uint32
} // }
var headerData = make([]byte, 4) // var headerData = make([]byte, 4)
func matchHeader(conn io.Reader) error { // func matchHeader(conn io.Reader) error {
pos := 0 // pos := 0
for pos < 4 { // for pos < 4 {
l, err := conn.Read(headerData) // l, err := conn.Read(headerData)
if err != nil { // if err != nil {
return err // return err
} // }
for i := 0; i < l; i++ { // for i := 0; i < l; i++ {
if headerData[i] == header[pos] { // if headerData[i] == header[pos] {
pos++ // pos++
if pos == 4 { // if pos == 4 {
return nil // return nil
} // }
} else { // } else {
pos = 0 // pos = 0
} // }
} // }
} // }
return nil // return nil
} // }
func ReadPacket(conn io.Reader, packet *Packet) error { // func ReadPacket(conn io.Reader, packet *Packet) error {
err := matchHeader(conn) // err := matchHeader(conn)
if err != nil { // if err != nil {
return err // return err
} // }
return binary.Read(conn, binary.LittleEndian, packet) // return binary.Read(conn, binary.LittleEndian, packet)
} // }
func ReadCartPacket(conn io.Reader, packet *CartPacket) error { // func ReadCartPacket(conn io.Reader, packet *CartPacket) error {
err := matchHeader(conn) // err := matchHeader(conn)
if err != nil { // if err != nil {
return err // return err
} // }
return binary.Read(conn, binary.LittleEndian, packet) // return binary.Read(conn, binary.LittleEndian, packet)
} // }
func GetPacketData(conn io.Reader, len uint32) ([]byte, error) { // func GetPacketData(conn io.Reader, len uint32) ([]byte, error) {
if len == 0 { // if len == 0 {
return []byte{}, nil // return []byte{}, nil
} // }
data := make([]byte, len) // data := make([]byte, len)
_, err := conn.Read(data) // _, err := conn.Read(data)
return data, err // return data, err
} // }
// func ReceivePacket(conn io.Reader) (uint32, []byte, error) { // func ReceivePacket(conn io.Reader) (uint32, []byte, error) {
// var packet Packet // var packet Packet

View File

@@ -55,7 +55,7 @@ func ErrorHandler(fn func(w http.ResponseWriter, r *http.Request) error) func(w
} }
} }
func (s *PoolServer) WriteResult(w http.ResponseWriter, result *CallResult) error { func (s *PoolServer) WriteResult(w http.ResponseWriter, result *FrameWithPayload) error {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Pod-Name", s.pod_name) w.Header().Set("X-Pod-Name", s.pod_name)
if result.StatusCode != 200 { if result.StatusCode != 200 {
@@ -65,11 +65,11 @@ func (s *PoolServer) WriteResult(w http.ResponseWriter, result *CallResult) erro
} else { } else {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
} }
w.Write([]byte(result.Data)) w.Write([]byte(result.Payload))
return nil return nil
} }
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
_, err := w.Write(result.Data) _, err := w.Write(result.Payload)
return err return err
} }

View File

@@ -46,8 +46,8 @@ func (p *RemoteGrainPool) Delete(id CartId) {
p.mu.Unlock() p.mu.Unlock()
} }
func (p *RemoteGrainPool) Process(id CartId, messages ...Message) (*CallResult, error) { func (p *RemoteGrainPool) Process(id CartId, messages ...Message) (*FrameWithPayload, error) {
var result *CallResult var result *FrameWithPayload
grain, err := p.findOrCreateGrain(id) grain, err := p.findOrCreateGrain(id)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -58,7 +58,7 @@ func (p *RemoteGrainPool) Process(id CartId, messages ...Message) (*CallResult,
return result, err return result, err
} }
func (p *RemoteGrainPool) Get(id CartId) (*CallResult, error) { func (p *RemoteGrainPool) Get(id CartId) (*FrameWithPayload, error) {
grain, err := p.findOrCreateGrain(id) grain, err := p.findOrCreateGrain(id)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -3,7 +3,6 @@ package main
import ( import (
"fmt" "fmt"
"strings" "strings"
"time"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promauto"
@@ -13,6 +12,25 @@ func (id CartId) String() string {
return strings.Trim(string(id[:]), "\x00") return strings.Trim(string(id[:]), "\x00")
} }
type CartIdPayload struct {
Id CartId
Data []byte
}
func MakeCartInnerFrame(id CartId, payload []byte) []byte {
return append(id[:], payload...)
}
func GetCartFrame(data []byte) (*CartIdPayload, error) {
if len(data) < 16 {
return nil, fmt.Errorf("data too short")
}
return &CartIdPayload{
Id: CartId(data[:16]),
Data: data[16:],
}, nil
}
func ToCartId(id string) CartId { func ToCartId(id string) CartId {
var result [16]byte var result [16]byte
copy(result[:], []byte(id)) copy(result[:], []byte(id))
@@ -20,21 +38,16 @@ func ToCartId(id string) CartId {
} }
type RemoteGrain struct { type RemoteGrain struct {
*CartClient *Connection
Id CartId Id CartId
Host string Host string
} }
func NewRemoteGrain(id CartId, host string) (*RemoteGrain, error) { func NewRemoteGrain(id CartId, host string) (*RemoteGrain, error) {
client, err := CartDial(fmt.Sprintf("%s:1337", host))
if err != nil {
return nil, err
}
return &RemoteGrain{ return &RemoteGrain{
Id: id, Id: id,
Host: host, Host: host,
CartClient: client, Connection: NewConnection(fmt.Sprintf("%s:1337", host)),
}, nil }, nil
} }
@@ -49,47 +62,35 @@ var (
}) })
) )
var start time.Time // var start time.Time
func MeasureLatency(fn func() (*CallResult, error)) (*CallResult, error) { // func MeasureLatency(fn func() (*CallResult, error)) (*CallResult, error) {
start = time.Now() // start = time.Now()
data, err := fn() // data, err := fn()
if err != nil { // if err != nil {
return data, err // return data, err
} // }
elapsed := time.Since(start).Milliseconds() // elapsed := time.Since(start).Milliseconds()
go func() { // go func() {
remoteCartLatency.Add(float64(elapsed)) // remoteCartLatency.Add(float64(elapsed))
remoteCartCallsTotal.Inc() // remoteCartCallsTotal.Inc()
}() // }()
return data, nil // return data, nil
} // }
func (g *RemoteGrain) HandleMessage(message *Message, isReplay bool) (*CallResult, error) { func (g *RemoteGrain) HandleMessage(message *Message, isReplay bool) (*FrameWithPayload, error) {
data, err := GetData(message.Write) data, err := GetData(message.Write)
if err != nil { if err != nil {
return nil, err return nil, err
} }
reply, err := MeasureLatency(func() (*CallResult, error) { return g.Call(RemoteHandleMutation, MakeCartInnerFrame(g.Id, data))
return g.Call(RemoteHandleMutation, g.Id, RemoteHandleMutationReply, data)
})
if err != nil {
return nil, err
}
return reply, err
}
func (g *RemoteGrain) Close() {
g.CartClient.PersistentConnection.Close()
} }
func (g *RemoteGrain) GetId() CartId { func (g *RemoteGrain) GetId() CartId {
return g.Id return g.Id
} }
func (g *RemoteGrain) GetCurrentState() (*CallResult, error) { func (g *RemoteGrain) GetCurrentState() (*FrameWithPayload, error) {
return MeasureLatency(func() (*CallResult, error) { return g.Call(RemoteGetState, g.Id, RemoteGetStateReply, []byte{}) }) return g.Call(RemoteGetState, MakeCartInnerFrame(g.Id, nil))
} }

View File

@@ -7,13 +7,13 @@ import (
) )
type RemoteHost struct { type RemoteHost struct {
*Client *Connection
Host string Host string
MissedPings int MissedPings int
} }
func (h *RemoteHost) IsHealthy() bool { func (h *RemoteHost) IsHealthy() bool {
return !h.PersistentConnection.Dead && h.MissedPings < 3 return h.MissedPings < 3
} }
func (h *RemoteHost) Initialize(p *SyncedPool) { func (h *RemoteHost) Initialize(p *SyncedPool) {
@@ -38,15 +38,11 @@ func (h *RemoteHost) Initialize(p *SyncedPool) {
} }
func (h *RemoteHost) Ping() error { func (h *RemoteHost) Ping() error {
_, err := h.Call(Ping, Pong, []byte{}) result, err := h.Call(Ping, nil)
if err != nil { if err != nil || result.StatusCode != 200 || result.Type != Pong {
h.MissedPings++ h.MissedPings++
log.Printf("Error pinging remote %s, missed pings: %d", h.Host, h.MissedPings) log.Printf("Error pinging remote %s, missed pings: %d", h.Host, h.MissedPings)
if !h.IsHealthy() {
h.Close()
return fmt.Errorf("remote %s is dead", h.Host)
}
} else { } else {
h.MissedPings = 0 h.MissedPings = 0
} }
@@ -54,28 +50,28 @@ func (h *RemoteHost) Ping() error {
} }
func (h *RemoteHost) Negotiate(knownHosts []string) ([]string, error) { func (h *RemoteHost) Negotiate(knownHosts []string) ([]string, error) {
reply, err := h.Call(RemoteNegotiate, RemoteNegotiateResponse, []byte(strings.Join(knownHosts, ";"))) reply, err := h.Call(RemoteNegotiate, []byte(strings.Join(knownHosts, ";")))
if err != nil { if err != nil {
return nil, err return nil, err
} }
if reply.StatusCode != 200 { if reply.StatusCode != 200 {
return nil, fmt.Errorf("remote returned error on negotiate: %s", string(reply.Data)) return nil, fmt.Errorf("remote returned error on negotiate: %s", string(reply.Payload))
} }
return strings.Split(string(reply.Data), ";"), nil return strings.Split(string(reply.Payload), ";"), nil
} }
func (g *RemoteHost) GetCartMappings() ([]CartId, error) { func (g *RemoteHost) GetCartMappings() ([]CartId, error) {
reply, err := g.Call(GetCartIds, CartIdsResponse, []byte{}) reply, err := g.Call(GetCartIds, []byte{})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if reply.StatusCode != 200 { if reply.StatusCode != 200 || reply.Type != CartIdsResponse {
log.Printf("Remote returned error on get cart mappings: %s", string(reply.Data)) log.Printf("Remote returned error on get cart mappings: %s", string(reply.Payload))
return nil, fmt.Errorf("remote returned error: %s", string(reply.Data)) return nil, fmt.Errorf("remote returned incorrect data")
} }
parts := strings.Split(string(reply.Data), ";") parts := strings.Split(string(reply.Payload), ";")
ids := make([]CartId, 0, len(parts)) ids := make([]CartId, 0, len(parts))
for _, p := range parts { for _, p := range parts {
ids = append(ids, ToCartId(p)) ids = append(ids, ToCartId(p))
@@ -84,14 +80,11 @@ func (g *RemoteHost) GetCartMappings() ([]CartId, error) {
} }
func (r *RemoteHost) ConfirmChange(id CartId, host string) error { func (r *RemoteHost) ConfirmChange(id CartId, host string) error {
reply, err := r.Call(RemoteGrainChanged, AckChange, []byte(fmt.Sprintf("%s;%s", id, host))) reply, err := r.Call(RemoteGrainChanged, []byte(fmt.Sprintf("%s;%s", id, host)))
if err != nil { if err != nil || reply.StatusCode != 200 || reply.Type != AckChange {
return err return err
} }
if string(reply.Data) != "ok" {
return fmt.Errorf("remote grain change failed %s", string(reply.Data))
}
return nil return nil
} }

View File

@@ -6,7 +6,7 @@ import (
) )
type GrainHandler struct { type GrainHandler struct {
*CartServer *GenericListener
pool *GrainLocalPool pool *GrainLocalPool
} }
@@ -20,13 +20,14 @@ func (h *GrainHandler) GetState(id CartId, reply *Grain) error {
} }
func NewGrainHandler(pool *GrainLocalPool, listen string) (*GrainHandler, error) { func NewGrainHandler(pool *GrainLocalPool, listen string) (*GrainHandler, error) {
server, err := CartListen(listen) conn := NewConnection(listen)
server, err := conn.Listen()
handler := &GrainHandler{ handler := &GrainHandler{
CartServer: server, GenericListener: server,
pool: pool, pool: pool,
} }
server.HandleCall(RemoteHandleMutation, handler.RemoteHandleMessageHandler) server.AddHandler(RemoteHandleMutation, handler.RemoteHandleMessageHandler)
server.HandleCall(RemoteGetState, handler.RemoteGetStateHandler) server.AddHandler(RemoteGetState, handler.RemoteGetStateHandler)
return handler, err return handler, err
} }
@@ -34,29 +35,36 @@ func (h *GrainHandler) IsHealthy() bool {
return len(h.pool.grains) < h.pool.PoolSize return len(h.pool.grains) < h.pool.PoolSize
} }
func (h *GrainHandler) RemoteHandleMessageHandler(id CartId, data []byte) (CartMessage, []byte, error) { func (h *GrainHandler) RemoteHandleMessageHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
cartData, err := GetCartFrame(data.Payload)
if err != nil {
return err
}
var msg Message var msg Message
err := ReadMessage(bytes.NewReader(data), &msg) err = ReadMessage(bytes.NewReader(cartData.Data), &msg)
if err != nil { if err != nil {
fmt.Println("Error reading message:", err) fmt.Println("Error reading message:", err)
return RemoteHandleMutationReply, nil, err return err
} }
replyData, err := h.pool.Process(id, msg) replyData, err := h.pool.Process(cartData.Id, msg)
if err != nil { if err != nil {
fmt.Println("Error handling message:", err) fmt.Println("Error handling message:", err)
} }
if err != nil { resultChan <- *replyData
return RemoteHandleMutationReply, nil, err return nil
}
return RemoteHandleMutationReply, replyData, nil
} }
func (h *GrainHandler) RemoteGetStateHandler(id CartId, data []byte) (CartMessage, []byte, error) { func (h *GrainHandler) RemoteGetStateHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
reply, err := h.pool.Get(id) cartData, err := GetCartFrame(data.Payload)
if err != nil { if err != nil {
return RemoteGetStateReply, nil, err return err
} }
return RemoteGetStateReply, reply, nil reply, err := h.pool.Get(cartData.Id)
if err != nil {
return err
}
resultChan <- *reply
return nil
} }

View File

@@ -22,7 +22,7 @@ type HealthHandler interface {
} }
type SyncedPool struct { type SyncedPool struct {
*Server Server *GenericListener
mu sync.RWMutex mu sync.RWMutex
Hostname string Hostname string
local *GrainLocalPool local *GrainLocalPool
@@ -61,11 +61,16 @@ var (
}) })
) )
func (p *SyncedPool) PongHandler(data []byte) (PoolMessage, []byte, error) { var (
return Pong, data, nil PongResponse = MakeFrameWithPayload(Pong, 200, nil)
)
func (p *SyncedPool) PongHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
resultChan <- PongResponse
return nil
} }
func (p *SyncedPool) GetCartIdHandler(data []byte) (PoolMessage, []byte, error) { func (p *SyncedPool) GetCartIdHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
ids := make([]string, 0, len(p.local.grains)) ids := make([]string, 0, len(p.local.grains))
for id := range p.local.grains { for id := range p.local.grains {
if p.local.grains[id] == nil { if p.local.grains[id] == nil {
@@ -78,45 +83,45 @@ func (p *SyncedPool) GetCartIdHandler(data []byte) (PoolMessage, []byte, error)
ids = append(ids, s) ids = append(ids, s)
} }
log.Printf("Returning %d cart ids\n", len(ids)) log.Printf("Returning %d cart ids\n", len(ids))
return CartIdsResponse, []byte(strings.Join(ids, ";")), nil resultChan <- MakeFrameWithPayload(CartIdsResponse, 200, []byte(strings.Join(ids, ";")))
return nil
} }
func (p *SyncedPool) NegotiateHandler(data []byte) (PoolMessage, []byte, error) { func (p *SyncedPool) NegotiateHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
negotiationCount.Inc() negotiationCount.Inc()
log.Printf("Handling negotiation\n") log.Printf("Handling negotiation\n")
for _, host := range p.ExcludeKnown(strings.Split(string(data), ";")) { for _, host := range p.ExcludeKnown(strings.Split(string(data.Payload), ";")) {
if host == "" { if host == "" {
continue continue
} }
go p.AddRemote(host) go p.AddRemote(host)
} }
resultChan <- MakeFrameWithPayload(RemoteNegotiateResponse, 200, []byte("ok"))
return RemoteNegotiateResponse, []byte("ok"), nil return nil
} }
func (p *SyncedPool) GrainOwnerChangeHandler(data []byte) (PoolMessage, []byte, error) { func (p *SyncedPool) GrainOwnerChangeHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
grainSyncCount.Inc() grainSyncCount.Inc()
idAndHostParts := strings.Split(string(data), ";") idAndHostParts := strings.Split(string(data.Payload), ";")
if len(idAndHostParts) != 2 { if len(idAndHostParts) != 2 {
log.Printf("Invalid remote grain change message\n") log.Printf("Invalid remote grain change message\n")
return AckChange, []byte("incorrect"), fmt.Errorf("invalid remote grain change message") resultChan <- MakeFrameWithPayload(AckError, 400, []byte("invalid"))
return nil
} }
id := ToCartId(idAndHostParts[0]) id := ToCartId(idAndHostParts[0])
host := idAndHostParts[1] host := idAndHostParts[1]
log.Printf("Handling remote grain owner change to %s for id %s\n", host, id) log.Printf("Handling remote grain owner change to %s for id %s\n", host, id)
for _, r := range p.remotes { for _, r := range p.remotes {
if r.Host == host && r.IsHealthy() { if r.Host == host && r.IsHealthy() {
// log.Printf("Remote grain %s changed to %s\n", id, host)
go p.SpawnRemoteGrain(id, host) go p.SpawnRemoteGrain(id, host)
break
return AckChange, []byte("ok"), nil
} }
} }
go p.AddRemote(host) go p.AddRemote(host)
return AckChange, []byte("ok"), nil resultChan <- MakeFrameWithPayload(AckChange, 200, []byte("ok"))
return nil
} }
func (p *SyncedPool) RemoveRemoteGrain(id CartId) { func (p *SyncedPool) RemoveRemoteGrain(id CartId) {
@@ -142,12 +147,12 @@ func (p *SyncedPool) SpawnRemoteGrain(id CartId, host string) {
log.Printf("Error creating remote grain %v\n", err) log.Printf("Error creating remote grain %v\n", err)
return return
} }
go func() { // go func() {
<-remote.PersistentConnection.Died // <-remote.Died
p.RemoveRemoteGrain(id) // p.RemoveRemoteGrain(id)
p.HandleHostError(host) // p.HandleHostError(host)
log.Printf("Remote grain %s died, host: %s\n", id.String(), host) // log.Printf("Remote grain %s died, host: %s\n", id.String(), host)
}() // }()
p.mu.Lock() p.mu.Lock()
p.remoteIndex[id] = remote p.remoteIndex[id] = remote
@@ -159,8 +164,6 @@ func (p *SyncedPool) HandleHostError(host string) {
if r.Host == host { if r.Host == host {
if !r.IsHealthy() { if !r.IsHealthy() {
p.RemoveHost(r) p.RemoveHost(r)
} else {
r.ErrorCount++
} }
return return
} }
@@ -169,8 +172,8 @@ func (p *SyncedPool) HandleHostError(host string) {
func NewSyncedPool(local *GrainLocalPool, hostname string, discovery Discovery) (*SyncedPool, error) { func NewSyncedPool(local *GrainLocalPool, hostname string, discovery Discovery) (*SyncedPool, error) {
listen := fmt.Sprintf("%s:1338", hostname) listen := fmt.Sprintf("%s:1338", hostname)
conn := NewConnection(listen)
server, err := Listen(listen) server, err := conn.Listen()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -186,10 +189,10 @@ func NewSyncedPool(local *GrainLocalPool, hostname string, discovery Discovery)
remoteIndex: make(map[CartId]*RemoteGrain), remoteIndex: make(map[CartId]*RemoteGrain),
} }
server.HandleCall(Ping, pool.PongHandler) server.AddHandler(Ping, pool.PongHandler)
server.HandleCall(GetCartIds, pool.GetCartIdHandler) server.AddHandler(GetCartIds, pool.GetCartIdHandler)
server.HandleCall(RemoteNegotiate, pool.NegotiateHandler) server.AddHandler(RemoteNegotiate, pool.NegotiateHandler)
server.HandleCall(RemoteGrainChanged, pool.GrainOwnerChangeHandler) server.AddHandler(RemoteGrainChanged, pool.GrainOwnerChangeHandler)
if discovery != nil { if discovery != nil {
go func() { go func() {
@@ -259,18 +262,11 @@ func (p *SyncedPool) ExcludeKnown(hosts []string) []string {
} }
func (p *SyncedPool) RemoveHost(host *RemoteHost) { func (p *SyncedPool) RemoveHost(host *RemoteHost) {
if p.remotes[host.Host] == nil {
return
}
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock()
h := p.remotes[host.Host]
if h != nil {
h.Close()
}
delete(p.remotes, host.Host) delete(p.remotes, host.Host)
p.mu.Unlock()
p.RemoveHostMappedCarts(host)
connectedRemotes.Set(float64(len(p.remotes))) connectedRemotes.Set(float64(len(p.remotes)))
} }
@@ -279,24 +275,21 @@ func (p *SyncedPool) RemoveHostMappedCarts(host *RemoteHost) {
defer p.mu.Unlock() defer p.mu.Unlock()
for id, r := range p.remoteIndex { for id, r := range p.remoteIndex {
if r.Host == host.Host { if r.Host == host.Host {
p.remoteIndex[id].Close()
delete(p.remoteIndex, id) delete(p.remoteIndex, id)
} }
} }
} }
type PoolMessage uint32
const ( const (
RemoteNegotiate = PoolMessage(3) RemoteNegotiate = FrameType(3)
RemoteGrainChanged = PoolMessage(4) RemoteGrainChanged = FrameType(4)
AckChange = PoolMessage(5) AckChange = FrameType(5)
//AckError = PoolMessage(6) AckError = FrameType(6)
Ping = PoolMessage(7) Ping = FrameType(7)
Pong = PoolMessage(8) Pong = FrameType(8)
GetCartIds = PoolMessage(9) GetCartIds = FrameType(9)
CartIdsResponse = PoolMessage(10) CartIdsResponse = FrameType(10)
RemoteNegotiateResponse = PoolMessage(11) RemoteNegotiateResponse = FrameType(11)
) )
func (p *SyncedPool) Negotiate() { func (p *SyncedPool) Negotiate() {
@@ -377,25 +370,22 @@ func (p *SyncedPool) AddRemote(host string) error {
if host == "" || p.IsKnown(host) || hasHost { if host == "" || p.IsKnown(host) || hasHost {
return nil return nil
} }
client, err := Dial(fmt.Sprintf("%s:1338", host)) client := NewConnection(fmt.Sprintf("%s:1338", host))
if err != nil { response, err := client.Call(Ping, nil)
if err != nil || response.StatusCode != 200 || response.Type != Pong {
log.Printf("Error connecting to remote %s: %v\n", host, err) log.Printf("Error connecting to remote %s: %v\n", host, err)
return err return err
} }
remote := RemoteHost{ remote := RemoteHost{
Client: client, Connection: client,
MissedPings: 0, MissedPings: 0,
Host: host, Host: host,
} }
p.mu.Lock() p.mu.Lock()
p.remotes[host] = &remote p.remotes[host] = &remote
p.mu.Unlock() p.mu.Unlock()
go func() {
<-remote.PersistentConnection.Died
log.Printf("Removing host, remote died %s", host)
p.RemoveHost(&remote)
}()
go func() { go func() {
for range time.Tick(time.Second * 3) { for range time.Tick(time.Second * 3) {
@@ -450,9 +440,9 @@ func (p *SyncedPool) getGrain(id CartId) (Grain, error) {
return localGrain, nil return localGrain, nil
} }
func (p *SyncedPool) Process(id CartId, messages ...Message) (*CallResult, error) { func (p *SyncedPool) Process(id CartId, messages ...Message) (*FrameWithPayload, error) {
pool, err := p.getGrain(id) pool, err := p.getGrain(id)
var res *CallResult var res *FrameWithPayload
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -465,7 +455,7 @@ func (p *SyncedPool) Process(id CartId, messages ...Message) (*CallResult, error
return res, nil return res, nil
} }
func (p *SyncedPool) Get(id CartId) (*CallResult, error) { func (p *SyncedPool) Get(id CartId) (*FrameWithPayload, error) {
grain, err := p.getGrain(id) grain, err := p.getGrain(id)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -1,96 +0,0 @@
package main
import (
"encoding/binary"
"log"
"sync"
)
type CartClient struct {
*CartTCPClient
}
func CartDial(address string) (*CartClient, error) {
mux, err := NewCartTCPClient(address)
if err != nil {
return nil, err
}
client := &CartClient{
CartTCPClient: mux,
}
return client, nil
}
func (c *Client) Close() {
log.Printf("Closing connection to %s\n", c.PersistentConnection.address)
c.PersistentConnection.Close()
}
type CartTCPClient struct {
PersistentConnection *PersistentConnection
sendMux sync.Mutex
ErrorCount int
address string
*CartPacketQueue
}
func NewCartTCPClient(address string) (*CartTCPClient, error) {
connection, err := NewPersistentConnection(address)
if err != nil {
return nil, err
}
return &CartTCPClient{
ErrorCount: 0,
PersistentConnection: connection,
address: address,
CartPacketQueue: NewCartPacketQueue(connection),
}, nil
}
func (m *CartTCPClient) SendPacket(messageType CartMessage, id CartId, data []byte) error {
m.sendMux.Lock()
defer m.sendMux.Unlock()
m.PersistentConnection.Conn.Write(header[:])
err := binary.Write(m.PersistentConnection, binary.LittleEndian, CartPacket{
Version: CurrentPacketVersion,
MessageType: messageType,
DataLength: uint32(len(data)),
Id: id,
})
if err != nil {
return m.PersistentConnection.HandleConnectionError(err)
}
_, err = m.PersistentConnection.Write(data)
return m.PersistentConnection.HandleConnectionError(err)
}
func (m *CartTCPClient) call(messageType CartMessage, id CartId, responseType CartMessage, data []byte) (*CallResult, error) {
packetChan := m.Expect(responseType, id)
err := m.SendPacket(messageType, id, data)
if err != nil {
return nil, m.PersistentConnection.HandleConnectionError(err)
}
ret := <-packetChan
return &ret, nil
}
func isRetirableError(err error) bool {
log.Printf("is retryable error: %v", err)
return false
}
func (m *CartTCPClient) Call(messageType CartMessage, id CartId, responseType CartMessage, data []byte) (*CallResult, error) {
retries := 0
result, err := m.call(messageType, id, responseType, data)
for err != nil && retries < 3 {
if !isRetirableError(err) {
break
}
retries++
log.Printf("Retrying call to %d\n", messageType)
result, err = m.call(messageType, id, responseType, data)
}
return result, err
}

View File

@@ -1,160 +0,0 @@
package main
import (
"bufio"
"encoding/binary"
"io"
"log"
"net"
"sync"
)
type CartServer struct {
*TCPCartServerMux
}
func CartListen(address string) (*CartServer, error) {
listener, err := net.Listen("tcp", address)
server := &CartServer{
NewCartTCPServerMux(),
}
if err != nil {
return nil, err
}
go func() {
for {
conn, err := listener.Accept()
if err != nil {
log.Printf("Error accepting connection: %v\n", err)
continue
}
go server.HandleConnection(conn)
}
}()
return server, nil
}
type TCPCartServerMux struct {
mu sync.RWMutex
sendMux sync.Mutex
listeners map[CartMessage]func(CartId, []byte) error
functions map[CartMessage]func(CartId, []byte) (CartMessage, []byte, error)
}
func NewCartTCPServerMux() *TCPCartServerMux {
m := &TCPCartServerMux{
mu: sync.RWMutex{},
listeners: make(map[CartMessage]func(CartId, []byte) error),
functions: make(map[CartMessage]func(CartId, []byte) (CartMessage, []byte, error)),
}
return m
}
func (m *TCPCartServerMux) handleListener(messageType CartMessage, id CartId, data []byte) (bool, error) {
m.mu.RLock()
handler, ok := m.listeners[messageType]
m.mu.RUnlock()
if ok {
err := handler(id, data)
if err != nil {
return true, err
}
}
return false, nil
}
func (m *TCPCartServerMux) handleFunction(connection net.Conn, messageType CartMessage, id CartId, data []byte) (bool, error) {
m.mu.RLock()
fn, ok := m.functions[messageType]
m.mu.RUnlock()
m.sendMux.Lock()
defer m.sendMux.Unlock()
if ok {
responseType, responseData, err := fn(id, data)
connection.Write(header[:])
if err != nil {
errData := []byte(err.Error())
err = binary.Write(connection, binary.LittleEndian, CartPacket{
Version: CurrentPacketVersion,
MessageType: responseType,
DataLength: uint32(len(errData)),
StatusCode: 500,
Id: id,
})
_, err = connection.Write(errData)
return true, err
}
err = binary.Write(connection, binary.LittleEndian, CartPacket{
Version: CurrentPacketVersion,
MessageType: responseType,
DataLength: uint32(len(responseData)),
StatusCode: 200,
Id: id,
})
if err != nil {
return true, err
}
packetsSent.Inc()
_, err = connection.Write(responseData)
return true, err
} else {
log.Printf("No cart handler for type: %d\n", messageType)
}
return false, nil
}
func (m *TCPCartServerMux) HandleConnection(connection net.Conn) error {
var packet CartPacket
var err error
defer connection.Close()
reader := bufio.NewReader(connection)
for {
err = ReadCartPacket(reader, &packet)
if err != nil {
if err == io.EOF {
return nil
}
log.Printf("Error receiving packet: %v\n", err)
return err
}
if packet.Version != CurrentPacketVersion {
log.Printf("Incorrect packet version: %d\n", packet.Version)
continue
}
data, err := GetPacketData(reader, packet.DataLength)
if err != nil {
log.Printf("Error getting packet data: %v\n", err)
}
go m.HandleData(connection, packet.MessageType, packet.Id, data)
}
}
func (m *TCPCartServerMux) HandleData(connection net.Conn, t CartMessage, id CartId, data []byte) {
status, err := m.handleListener(t, id, data)
if err != nil {
log.Printf("Error handling listener: %v\n", err)
}
if !status {
status, err = m.handleFunction(connection, t, id, data)
if err != nil {
log.Printf("Error handling function: %v\n", err)
}
if !status {
log.Printf("Unknown message type: %d\n", t)
}
}
}
func (m *TCPCartServerMux) ListenFor(messageType CartMessage, handler func(CartId, []byte) error) {
m.mu.Lock()
m.listeners[messageType] = handler
m.mu.Unlock()
}
func (m *TCPCartServerMux) HandleCall(messageType CartMessage, handler func(CartId, []byte) (CartMessage, []byte, error)) {
m.mu.Lock()
m.functions[messageType] = handler
m.mu.Unlock()
}

View File

@@ -1,57 +0,0 @@
package main
import (
"fmt"
"log"
"testing"
)
func TestCartTcpHelpers(t *testing.T) {
server, err := CartListen("localhost:51337")
if err != nil {
t.Errorf("Error listening: %v\n", err)
}
client, err := CartDial("localhost:51337")
if err != nil {
t.Errorf("Error dialing: %v\n", err)
}
var messageData string
server.ListenFor(1, func(id CartId, data []byte) error {
log.Printf("Received message: %s\n", string(data))
messageData = string(data)
return nil
})
server.HandleCall(666, func(id CartId, data []byte) (CartMessage, []byte, error) {
log.Printf("Received 666 call: %s\n", string(data))
return 3, []byte("Hello, client!"), fmt.Errorf("Det blev fel")
})
server.HandleCall(2, func(id CartId, data []byte) (CartMessage, []byte, error) {
log.Printf("Received 2 call: %s\n", string(data))
return 4, []byte("Hello, client!"), nil
})
// server.HandleCall(Ping, func(id CartId, data []byte) (CartMessage, []byte, error) {
// return Pong, nil, nil
// })
id := ToCartId("kalle")
client.SendPacket(1, id, []byte("Hello, world!"))
answer, err := client.Call(2, id, 4, []byte("Hello, server!"))
if err != nil {
t.Errorf("Error calling: %v\n", err)
}
s, err := client.Call(666, id, 3, []byte("Hello, server!"))
client.PersistentConnection.Close()
if err != nil {
t.Errorf("Error calling: %v\n", err)
}
if s.StatusCode != 500 {
t.Errorf("Expected 500, got %d\n", s.StatusCode)
}
if string(answer.Data) != "Hello, client!" {
t.Errorf("Expected answer 'Hello, client!', got %s\n", string(answer.Data))
}
if messageData != "Hello, world!" {
t.Errorf("Expected message 'Hello, world!', got %s\n", messageData)
}
}

View File

@@ -1,144 +0,0 @@
package main
import (
"encoding/binary"
"log"
"net"
"sync"
"time"
)
type Client struct {
*TCPClient
}
func Dial(address string) (*Client, error) {
mux, err := NewTCPClient(address)
if err != nil {
return nil, err
}
client := &Client{
TCPClient: mux,
}
return client, nil
}
type TCPClient struct {
PersistentConnection *PersistentConnection
sendMux sync.Mutex
ErrorCount int
address string
*PacketQueue
}
type PersistentConnection struct {
net.Conn
Died chan bool
Dead bool
address string
}
func NewPersistentConnection(address string) (*PersistentConnection, error) {
p := &PersistentConnection{
Died: make(chan bool, 1),
Dead: false,
address: address,
}
err := p.Connect()
if err != nil {
return nil, err
}
return p, nil
}
func (m *PersistentConnection) Connect() error {
fails := 0
for {
connection, err := net.Dial("tcp", m.address)
if err != nil {
log.Printf("Can't connect to %s: %v, count: %d", m.address, err, fails)
fails++
if fails > 15 {
log.Printf("Too many connection failures, closing connection to %s", m.address)
m.Died <- true
m.Dead = true
return err
}
} else {
m.Conn = connection
break
}
time.Sleep(time.Millisecond * 300)
}
return nil
}
func (m *PersistentConnection) Close() {
log.Printf("Closing connection to %s\n", m.address)
m.Conn.Close()
m.Died <- true
m.Dead = true
}
func (m *PersistentConnection) HandleConnectionError(err error) error {
if err != nil {
log.Printf("Error from to %s: %v\n", m.address, err)
m.Conn.Close()
m.Connect()
}
return err
}
func NewTCPClient(address string) (*TCPClient, error) {
connection, err := NewPersistentConnection(address)
if err != nil {
return nil, err
}
return &TCPClient{
ErrorCount: 0,
PersistentConnection: connection,
address: address,
PacketQueue: NewPacketQueue(connection),
}, nil
}
type PacketHeader [4]byte
var (
header = PacketHeader([4]byte{0x01, 0x02, 0x03, 0x04})
)
func (m *TCPClient) SendPacket(messageType PoolMessage, data []byte) error {
m.sendMux.Lock()
defer m.sendMux.Unlock()
m.PersistentConnection.Write(header[:])
err := binary.Write(m.PersistentConnection, binary.LittleEndian, Packet{
Version: CurrentPacketVersion,
MessageType: messageType,
StatusCode: 0,
DataLength: uint32(len(data)),
})
if err != nil {
return m.PersistentConnection.HandleConnectionError(err)
}
_, err = m.PersistentConnection.Write(data)
return m.PersistentConnection.HandleConnectionError(err)
}
func (m *TCPClient) Call(messageType PoolMessage, responseType PoolMessage, data []byte) (*CallResult, error) {
packetChan := m.Expect(responseType)
err := m.SendPacket(messageType, data)
if err != nil {
m.RemoveListeners()
return nil, m.PersistentConnection.HandleConnectionError(err)
}
ret := <-packetChan
return &ret, nil
}

View File

@@ -15,12 +15,23 @@ type Connection struct {
} }
type FrameType uint32 type FrameType uint32
type StatusCode uint32
type CheckSum uint32
type Frame struct { type Frame struct {
Id uint64
Type FrameType Type FrameType
StatusCode uint32 StatusCode StatusCode
Length uint32 Length uint32
Checksum CheckSum
}
func (f *Frame) IsValid() bool {
return f.Checksum == MakeChecksum(f.Type, f.StatusCode, f.Length)
}
func MakeChecksum(msg FrameType, statusCode StatusCode, length uint32) CheckSum {
sum := CheckSum((uint32(msg) + uint32(statusCode) + length) / 8)
return sum
} }
type FrameWithPayload struct { type FrameWithPayload struct {
@@ -28,6 +39,19 @@ type FrameWithPayload struct {
Payload []byte Payload []byte
} }
func MakeFrameWithPayload(msg FrameType, statusCode StatusCode, payload []byte) FrameWithPayload {
len := uint32(len(payload))
return FrameWithPayload{
Frame: Frame{
Type: msg,
StatusCode: 0,
Length: len,
Checksum: MakeChecksum(msg, 0, len),
},
Payload: payload,
}
}
type FrameData interface { type FrameData interface {
ToBytes() []byte ToBytes() []byte
FromBytes([]byte) error FromBytes([]byte) error
@@ -41,11 +65,7 @@ func NewConnection(address string) *Connection {
} }
func SendFrame(conn net.Conn, data *FrameWithPayload) error { func SendFrame(conn net.Conn, data *FrameWithPayload) error {
_, err := conn.Write(header[:]) err := binary.Write(conn, binary.LittleEndian, data.Frame)
if err != nil {
return err
}
err = binary.Write(conn, binary.LittleEndian, data.Frame)
if err != nil { if err != nil {
return err return err
} }
@@ -53,68 +73,67 @@ func SendFrame(conn net.Conn, data *FrameWithPayload) error {
return err return err
} }
func (c *Connection) CallAsync(msg FrameType, data FrameData, ch chan<- *FrameWithPayload) error { func (c *Connection) CallAsync(msg FrameType, payload []byte, ch chan<- FrameWithPayload) (net.Conn, error) {
conn, err := net.Dial("tcp", c.address) conn, err := net.Dial("tcp", c.address)
go WaitForFrame(conn, ch) go WaitForFrame(conn, ch)
if err != nil { if err != nil {
return err return conn, err
}
payload := data.ToBytes()
toSend := &FrameWithPayload{
Frame: Frame{
Id: c.count,
Type: msg,
StatusCode: 0,
Length: uint32(len(payload)),
},
Payload: payload,
} }
toSend := MakeFrameWithPayload(msg, 1, payload)
err = SendFrame(conn, toSend) err = SendFrame(conn, &toSend)
if err != nil { if err != nil {
conn.Close()
close(ch) close(ch)
return err return nil, err
} }
c.count++ c.count++
return nil return conn, nil
} }
func (c *Connection) Call(msg FrameType, data FrameData) (*FrameWithPayload, error) { func (c *Connection) Call(msg FrameType, data []byte) (*FrameWithPayload, error) {
ch := make(chan *FrameWithPayload, 1) ch := make(chan FrameWithPayload, 1)
c.CallAsync(msg, data, ch) conn, err := c.CallAsync(msg, data, ch)
if err != nil {
return nil, err
}
defer conn.Close()
select { select {
case ret := <-ch: case ret := <-ch:
return ret, nil return &ret, nil
case <-time.After(5 * time.Second): case <-time.After(MaxCallDuration):
return nil, fmt.Errorf("timeout") return nil, fmt.Errorf("timeout")
} }
} }
func WaitForFrame(conn net.Conn, resultChan chan<- *FrameWithPayload) error { func WaitForFrame(conn net.Conn, resultChan chan<- FrameWithPayload) error {
defer conn.Close()
var err error var err error
var frame Frame
r := bufio.NewReader(conn) r := bufio.NewReader(conn)
h := make([]byte, 4)
r.Read(h)
if h[0] == header[0] && h[1] == header[1] && h[2] == header[2] && h[3] == header[3] {
frame := Frame{}
err = binary.Read(r, binary.LittleEndian, &frame) err = binary.Read(r, binary.LittleEndian, &frame)
if err != nil {
return err
}
if frame.IsValid() {
payload := make([]byte, frame.Length) payload := make([]byte, frame.Length)
_, err = r.Read(payload) _, err = r.Read(payload)
resultChan <- &FrameWithPayload{ if err != nil {
return err
}
resultChan <- FrameWithPayload{
Frame: frame, Frame: frame,
Payload: payload, Payload: payload,
} }
return err return err
} }
resultChan <- nil return fmt.Errorf("checksum mismatch")
return err
} }
type GenericListener struct { type GenericListener struct {
Closed bool Closed bool
handlers map[FrameType]func(*FrameWithPayload, chan<- *FrameWithPayload) error handlers map[FrameType]func(*FrameWithPayload, chan<- FrameWithPayload) error
} }
func (c *Connection) Listen() (*GenericListener, error) { func (c *Connection) Listen() (*GenericListener, error) {
@@ -123,7 +142,7 @@ func (c *Connection) Listen() (*GenericListener, error) {
return nil, err return nil, err
} }
ret := &GenericListener{ ret := &GenericListener{
handlers: make(map[FrameType]func(*FrameWithPayload, chan<- *FrameWithPayload) error), handlers: make(map[FrameType]func(*FrameWithPayload, chan<- FrameWithPayload) error),
} }
go func() { go func() {
for !ret.Closed { for !ret.Closed {
@@ -137,36 +156,44 @@ func (c *Connection) Listen() (*GenericListener, error) {
return ret, nil return ret, nil
} }
const (
MaxCallDuration = 500 * time.Millisecond
)
func (l *GenericListener) HandleConnection(conn net.Conn) { func (l *GenericListener) HandleConnection(conn net.Conn) {
ch := make(chan *FrameWithPayload, 1) ch := make(chan FrameWithPayload, 1)
go WaitForFrame(conn, ch) go WaitForFrame(conn, ch)
select { select {
case frame := <-ch: case frame := <-ch:
go l.HandleFrame(conn, frame) go l.HandleFrame(conn, &frame)
case <-time.After(1 * time.Second): case <-time.After(MaxCallDuration):
close(ch) close(ch)
log.Printf("Timeout waiting for frame\n") log.Printf("Timeout waiting for frame\n")
} }
} }
func (l *GenericListener) AddHandler(msg FrameType, handler func(*FrameWithPayload, chan<- *FrameWithPayload) error) { func (l *GenericListener) AddHandler(msg FrameType, handler func(*FrameWithPayload, chan<- FrameWithPayload) error) {
l.handlers[msg] = handler l.handlers[msg] = handler
} }
func (l *GenericListener) HandleFrame(conn net.Conn, frame *FrameWithPayload) { func (l *GenericListener) HandleFrame(conn net.Conn, frame *FrameWithPayload) {
handler, ok := l.handlers[frame.Type] handler, ok := l.handlers[frame.Type]
defer conn.Close()
if ok { if ok {
go func() { go func() {
resultChan := make(chan *FrameWithPayload, 1) resultChan := make(chan FrameWithPayload, 1)
defer close(resultChan) defer close(resultChan)
err := handler(frame, resultChan) err := handler(frame, resultChan)
if err != nil { if err != nil {
log.Fatalf("Error handling frame: %v\n", err) log.Fatalf("Error handling frame: %v\n", err)
} }
SendFrame(conn, <-resultChan) result := <-resultChan
err = SendFrame(conn, &result)
if err != nil {
log.Fatalf("Error sending frame: %v\n", err)
}
}() }()
} else { } else {
conn.Close()
log.Fatalf("No handler for frame type %d\n", frame.Type) log.Fatalf("No handler for frame type %d\n", frame.Type)
} }
} }

View File

@@ -2,37 +2,19 @@ package main
import "testing" import "testing"
type StringData string
func (s StringData) ToBytes() []byte {
return []byte(s)
}
func (s StringData) FromBytes(data []byte) error {
s = StringData(data)
return nil
}
func TestGenericConnection(t *testing.T) { func TestGenericConnection(t *testing.T) {
conn := NewConnection("localhost:51337") conn := NewConnection("localhost:51337")
listener, err := conn.Listen() listener, err := conn.Listen()
if err != nil { if err != nil {
t.Errorf("Error listening: %v\n", err) t.Errorf("Error listening: %v\n", err)
} }
listener.AddHandler(1, func(input *FrameWithPayload, resultChan chan<- *FrameWithPayload) error { datta := []byte("Hello, world!")
payload := []byte("Hello, world!") listener.AddHandler(1, func(input *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
resultChan <- &FrameWithPayload{
Frame: Frame{ resultChan <- MakeFrameWithPayload(2, 200, datta)
Type: 2,
Id: input.Id,
StatusCode: 200,
Length: uint32(len("Hello, world!")),
},
Payload: payload,
}
return nil return nil
}) })
r, err := conn.Call(1, StringData("Hello, world!")) r, err := conn.Call(1, datta)
if err != nil { if err != nil {
t.Errorf("Error calling: %v\n", err) t.Errorf("Error calling: %v\n", err)
} }
@@ -40,9 +22,9 @@ func TestGenericConnection(t *testing.T) {
t.Errorf("Expected type 2, got %d\n", r.Type) t.Errorf("Expected type 2, got %d\n", r.Type)
} }
i := 100 i := 100
results := make(chan *FrameWithPayload, i) results := make(chan FrameWithPayload, i)
for i > 0 { for i > 0 {
conn.CallAsync(1, StringData("Hello, world!"), results) go conn.CallAsync(1, datta, results)
i-- i--
} }
for i < 100 { for i < 100 {

View File

@@ -1,161 +0,0 @@
package main
import (
"bufio"
"encoding/binary"
"io"
"log"
"net"
"sync"
)
type Server struct {
*TCPServerMux
}
func Listen(address string) (*Server, error) {
listener, err := net.Listen("tcp", address)
server := &Server{
NewTCPServerMux(),
}
if err != nil {
return nil, err
}
go func() {
for {
conn, err := listener.Accept()
if err != nil {
log.Printf("Error accepting connection: %v\n", err)
continue
}
go server.HandleConnection(conn)
}
}()
return server, nil
}
type TCPServerMux struct {
mu sync.RWMutex
sendMux sync.Mutex
listeners map[PoolMessage]func(data []byte) error
functions map[PoolMessage]func(data []byte) (PoolMessage, []byte, error)
}
func NewTCPServerMux() *TCPServerMux {
m := &TCPServerMux{
mu: sync.RWMutex{},
listeners: make(map[PoolMessage]func(data []byte) error),
functions: make(map[PoolMessage]func(data []byte) (PoolMessage, []byte, error)),
}
return m
}
func (m *TCPServerMux) handleListener(messageType PoolMessage, data []byte) (bool, error) {
m.mu.RLock()
handler, ok := m.listeners[messageType]
m.mu.RUnlock()
if ok {
err := handler(data)
if err != nil {
return true, err
}
}
return false, nil
}
func (m *TCPServerMux) handleFunction(connection net.Conn, messageType PoolMessage, data []byte) (bool, error) {
m.mu.RLock()
function, ok := m.functions[messageType]
m.mu.RUnlock()
m.sendMux.Lock()
defer m.sendMux.Unlock()
if ok {
connection.Write(header[:])
responseType, responseData, err := function(data)
if err != nil {
errData := []byte(err.Error())
err = binary.Write(connection, binary.LittleEndian, Packet{
Version: CurrentPacketVersion,
MessageType: responseType,
StatusCode: 500,
DataLength: uint32(len(errData)),
})
_, err = connection.Write(errData)
return true, err
}
err = binary.Write(connection, binary.LittleEndian, Packet{
Version: CurrentPacketVersion,
MessageType: responseType,
StatusCode: 200,
DataLength: uint32(len(responseData)),
})
if err != nil {
return true, err
}
packetsSent.Inc()
_, err = connection.Write(responseData)
return true, err
} else {
log.Printf("No pool handler for type: %d\n", messageType)
}
return false, nil
}
func (m *TCPServerMux) HandleConnection(connection net.Conn) error {
defer connection.Close()
var packet Packet
reader := bufio.NewReader(connection)
for {
err := ReadPacket(reader, &packet)
if err != nil {
if err == io.EOF {
return nil
}
log.Printf("Error receiving packet: %v\n", err)
return err
}
if packet.Version != CurrentPacketVersion {
log.Printf("Incorrect package version: %v\n", err)
continue
}
data, err := GetPacketData(reader, packet.DataLength)
if err != nil {
log.Printf("Error receiving packet data: %v\n", err)
return err
}
go m.HandleData(connection, packet.MessageType, data)
}
}
func (m *TCPServerMux) HandleData(connection net.Conn, t PoolMessage, data []byte) {
// listener := m.listeners[t]
// handler := m.functions[t]
status, err := m.handleListener(t, data)
if err != nil {
log.Printf("Error handling listener: %v\n", err)
}
if !status {
status, err = m.handleFunction(connection, t, data)
if err != nil {
log.Printf("Error handling function: %v\n", err)
}
if !status {
log.Printf("Unknown message type: %d\n", t)
}
}
}
func (m *TCPServerMux) ListenFor(messageType PoolMessage, handler func(data []byte) error) {
m.mu.Lock()
m.listeners[messageType] = handler
m.mu.Unlock()
}
func (m *TCPServerMux) HandleCall(messageType PoolMessage, handler func(data []byte) (PoolMessage, []byte, error)) {
m.mu.Lock()
m.functions[messageType] = handler
m.mu.Unlock()
}

View File

@@ -1,54 +0,0 @@
package main
import (
"log"
"testing"
)
func TestTcpHelpers(t *testing.T) {
server, err := Listen("localhost:51337")
if err != nil {
t.Errorf("Error listening: %v\n", err)
}
client, err := Dial("localhost:51337")
if err != nil {
t.Errorf("Error dialing: %v\n", err)
}
var messageData string
server.ListenFor(1, func(data []byte) error {
log.Printf("Received message: %s\n", string(data))
messageData = string(data)
return nil
})
server.HandleCall(2, func(data []byte) (PoolMessage, []byte, error) {
log.Printf("Received call: %s\n", string(data))
return 3, []byte("Hello, client!"), nil
})
server.HandleCall(Ping, func(data []byte) (PoolMessage, []byte, error) {
return Pong, nil, nil
})
client.SendPacket(1, []byte("Hello, world!"))
answer, err := client.Call(2, 3, []byte("Hello, server!"))
if err != nil {
t.Errorf("Error calling: %v\n", err)
}
for i := 0; i < 100; i++ {
_, err = client.Call(Ping, Pong, nil)
if err != nil {
t.Errorf("Error calling: %v\n", err)
}
}
_, err = client.Call(Ping, Pong, nil)
if err != nil {
t.Errorf("Error calling: %v\n", err)
}
client.Close()
if string(answer.Data) != "Hello, client!" {
t.Errorf("Expected answer 'Hello, client!', got %s\n", string(answer.Data))
}
if messageData != "Hello, world!" {
t.Errorf("Expected message 'Hello, world!', got %s\n", messageData)
}
}