major refactor
All checks were successful
Build and Publish / BuildAndDeployAmd64 (push) Successful in 28s
Build and Publish / BuildAndDeploy (push) Successful in 2m18s

This commit is contained in:
matst80
2024-11-13 21:56:40 +01:00
parent 9f7c8227c2
commit abf561c3fe
20 changed files with 310 additions and 1292 deletions

View File

@@ -54,8 +54,8 @@ type CartGrain struct {
type Grain interface {
GetId() CartId
HandleMessage(message *Message, isReplay bool) (*CallResult, error)
GetCurrentState() (*CallResult, error)
HandleMessage(message *Message, isReplay bool) (*FrameWithPayload, error)
GetCurrentState() (*FrameWithPayload, error)
}
func (c *CartGrain) GetId() CartId {
@@ -69,12 +69,14 @@ func (c *CartGrain) GetLastChange() int64 {
return *c.storageMessages[len(c.storageMessages)-1].TimeStamp
}
func (c *CartGrain) GetCurrentState() (*CallResult, error) {
func (c *CartGrain) GetCurrentState() (*FrameWithPayload, error) {
result, err := json.Marshal(c)
return &CallResult{
StatusCode: 200,
Data: result,
}, err
if err != nil {
ret := MakeFrameWithPayload(0, 400, []byte(err.Error()))
return &ret, nil
}
ret := MakeFrameWithPayload(0, 200, result)
return &ret, nil
}
func getItemData(sku string, qty int) (*messages.AddItem, error) {
@@ -108,7 +110,7 @@ func getItemData(sku string, qty int) (*messages.AddItem, error) {
}, nil
}
func (c *CartGrain) AddItem(sku string, qty int) (*CallResult, error) {
func (c *CartGrain) AddItem(sku string, qty int) (*FrameWithPayload, error) {
cartItem, err := getItemData(sku, qty)
if err != nil {
return nil, err
@@ -180,7 +182,7 @@ func (c *CartGrain) FindItemWithSku(sku string) (*CartItem, bool) {
return nil, false
}
func (c *CartGrain) HandleMessage(message *Message, isReplay bool) (*CallResult, error) {
func (c *CartGrain) HandleMessage(message *Message, isReplay bool) (*FrameWithPayload, error) {
if message.TimeStamp == nil {
now := time.Now().Unix()
message.TimeStamp = &now
@@ -305,8 +307,10 @@ func (c *CartGrain) HandleMessage(message *Message, isReplay bool) (*CallResult,
c.mu.Unlock()
}
result, err := json.Marshal(c)
return &CallResult{
StatusCode: 200,
Data: result,
return &FrameWithPayload{
Frame: Frame{
StatusCode: 200,
},
Payload: result,
}, err
}

View File

@@ -1,163 +0,0 @@
package main
import (
"bufio"
"fmt"
"log"
"sync"
"time"
)
type CartPacketQueue struct {
mu sync.RWMutex
expectedPackages map[CartMessage]*CartListener
}
const CurrentPacketVersion = 2
type CartListener map[CartId]Listener
func NewCartPacketQueue(connection *PersistentConnection) *CartPacketQueue {
queue := &CartPacketQueue{
expectedPackages: make(map[CartMessage]*CartListener),
}
go queue.HandleConnection(connection)
return queue
}
func (p *CartPacketQueue) RemoveListeners() {
p.mu.Lock()
defer p.mu.Unlock()
for _, l := range p.expectedPackages {
for _, l := range *l {
close(l.Chan)
}
}
p.expectedPackages = make(map[CartMessage]*CartListener)
}
func (p *CartPacketQueue) HandleConnection(connection *PersistentConnection) error {
defer p.RemoveListeners()
defer connection.Close()
var packet CartPacket
reader := bufio.NewReader(connection)
for {
err := ReadCartPacket(reader, &packet)
if err != nil {
log.Printf("Error receiving packet: %v\n", err)
return connection.HandleConnectionError(err)
}
if packet.Version != CurrentPacketVersion {
log.Printf("Incorrect version: %v\n", packet.Version)
return connection.HandleConnectionError(fmt.Errorf("incorrect version: %d", packet.Version))
}
if packet.DataLength == 0 {
go p.HandleData(packet.MessageType, packet.Id, CallResult{
StatusCode: packet.StatusCode,
Data: []byte{},
})
continue
}
data, err := GetPacketData(reader, packet.DataLength)
if err != nil {
log.Printf("Error receiving packet data: %v\n", err)
return connection.HandleConnectionError(err)
}
go p.HandleData(packet.MessageType, packet.Id, CallResult{
StatusCode: packet.StatusCode,
Data: data,
})
}
}
func (p *CartPacketQueue) HandleData(t CartMessage, id CartId, data CallResult) {
p.getListener(t, id, func(l *Listener) {
l.Chan <- data
l.Count--
})
// p.mu.Lock()
// defer p.mu.Unlock()
// pl, ok := p.expectedPackages[t]
// if ok {
// l, ok := (*pl)[id]
// if ok {
// l.Chan <- data
// l.Count--
// if l.Count == 0 {
// close(l.Chan)
// delete(*pl, id)
// }
// }
// }
}
func (p *CartPacketQueue) getListener(t CartMessage, id CartId, fn func(*Listener)) {
p.mu.Lock()
defer p.mu.Unlock()
pl, ok := p.expectedPackages[t]
if ok {
l, ok := (*pl)[id]
if ok {
fn(&l)
if l.Count == 0 {
close(l.Chan)
delete(*pl, id)
}
}
}
}
func CallResultWithTimeout(onTimeout func() CallResult) chan CallResult {
ch := make(chan CallResult, 1)
resultCh := make(chan CallResult, 1)
select {
case ret := <-resultCh:
ch <- ret
case <-time.After(300 * time.Millisecond):
ch <- onTimeout()
}
return ch
}
func (p *CartPacketQueue) MakeChannel(messageType CartMessage, id CartId) chan CallResult {
return CallResultWithTimeout(func() CallResult {
p.getListener(messageType, id, func(l *Listener) {
l.Count--
})
return CallResult{
StatusCode: 504,
Data: []byte("timeout cart call"),
}
})
}
func (p *CartPacketQueue) Expect(messageType CartMessage, id CartId) <-chan CallResult {
p.mu.Lock()
defer p.mu.Unlock()
l, ok := p.expectedPackages[messageType]
if ok {
if idl, idOk := (*l)[id]; idOk {
idl.Count++
return idl.Chan
}
ch := p.MakeChannel(messageType, id)
(*l)[id] = Listener{
Chan: ch,
Count: 1,
}
return ch
}
ch := p.MakeChannel(messageType, id)
p.expectedPackages[messageType] = &CartListener{
id: Listener{
Chan: ch,
Count: 1,
},
}
return ch
}

View File

@@ -27,8 +27,8 @@ var (
)
type GrainPool interface {
Process(id CartId, messages ...Message) (*CallResult, error)
Get(id CartId) (*CallResult, error)
Process(id CartId, messages ...Message) (*FrameWithPayload, error)
Get(id CartId) (*FrameWithPayload, error)
}
type Ttl struct {
@@ -142,23 +142,29 @@ func (p *GrainLocalPool) GetGrain(id CartId) (*CartGrain, error) {
return grain, err
}
func (p *GrainLocalPool) Process(id CartId, messages ...Message) ([]byte, error) {
func (p *GrainLocalPool) Process(id CartId, messages ...Message) (*FrameWithPayload, error) {
grain, err := p.GetGrain(id)
var result *FrameWithPayload
if err == nil && grain != nil {
for _, message := range messages {
_, err = grain.HandleMessage(&message, false)
result, err = grain.HandleMessage(&message, false)
}
}
if err != nil {
return nil, err
return result, err
}
return json.Marshal(grain)
return result, err
}
func (p *GrainLocalPool) Get(id CartId) ([]byte, error) {
func (p *GrainLocalPool) Get(id CartId) (*FrameWithPayload, error) {
grain, err := p.GetGrain(id)
if err != nil {
return nil, err
}
return json.Marshal(grain)
data, err := json.Marshal(grain)
if err != nil {
return nil, err
}
ret := MakeFrameWithPayload(0, 200, data)
return &ret, nil
}

View File

@@ -1,122 +0,0 @@
package main
import (
"bufio"
"fmt"
"log"
"sync"
"time"
)
type PacketQueue struct {
mu sync.RWMutex
expectedPackages map[PoolMessage]*Listener
}
type CallResult struct {
StatusCode uint32
Data []byte
}
type Listener struct {
Count int
Chan chan CallResult
}
func NewPacketQueue(connection *PersistentConnection) *PacketQueue {
queue := &PacketQueue{
expectedPackages: make(map[PoolMessage]*Listener),
}
go queue.HandleConnection(connection)
return queue
}
func (p *PacketQueue) RemoveListeners() {
p.mu.Lock()
defer p.mu.Unlock()
for _, l := range p.expectedPackages {
close(l.Chan)
}
p.expectedPackages = make(map[PoolMessage]*Listener)
}
func (p *PacketQueue) HandleConnection(connection *PersistentConnection) error {
defer connection.Close()
defer p.RemoveListeners()
var packet Packet
reader := bufio.NewReader(connection)
for {
err := ReadPacket(reader, &packet)
if err != nil {
return connection.HandleConnectionError(err)
}
if packet.Version != CurrentPacketVersion {
log.Printf("Incorrect packet version: %v\n", packet.Version)
return connection.HandleConnectionError(fmt.Errorf("incorrect packet version: %d", packet.Version))
}
if packet.DataLength == 0 {
go p.HandleData(packet.MessageType, CallResult{
StatusCode: packet.StatusCode,
Data: []byte{},
})
continue
}
data, err := GetPacketData(reader, packet.DataLength)
if err != nil {
log.Printf("Error receiving packet data: %v\n", err)
return connection.HandleConnectionError(err)
} else {
go p.HandleData(packet.MessageType, CallResult{
StatusCode: packet.StatusCode,
Data: data,
})
}
}
}
func (p *PacketQueue) HandleData(t PoolMessage, data CallResult) {
p.mu.Lock()
defer p.mu.Unlock()
l, ok := p.expectedPackages[t]
if ok {
l.Chan <- data
l.Count--
if l.Count == 0 {
close(l.Chan)
delete(p.expectedPackages, t)
}
return
}
}
func (p *PacketQueue) Expect(messageType PoolMessage) <-chan CallResult {
p.mu.Lock()
defer p.mu.Unlock()
l, ok := p.expectedPackages[messageType]
if ok {
l.Count++
return l.Chan
}
ch := make(chan CallResult, 1)
go func() {
time.Sleep(time.Millisecond * 300)
p.mu.Lock()
defer p.mu.Unlock()
ch <- CallResult{
StatusCode: 504,
Data: []byte("timeout cart call"),
}
close(ch)
}()
p.expectedPackages[messageType] = &Listener{
Count: 1,
Chan: ch,
}
return ch
}

View File

@@ -1,28 +0,0 @@
package main
import (
"testing"
"time"
)
func TestQueue(t *testing.T) {
localPool := NewGrainLocalPool(100, time.Minute, func(id CartId) (*CartGrain, error) {
return &CartGrain{
Id: id,
storageMessages: []Message{},
Items: []*CartItem{},
TotalPrice: 0,
}, nil
})
pool, err := NewSyncedPool(localPool, "localhost", nil)
if err != nil {
t.Errorf("Error creating pool: %v", err)
}
err = pool.AddRemote("localhost")
if err != nil {
t.Errorf("Error adding remote: %v", err)
return
}
}

130
packet.go
View File

@@ -1,85 +1,77 @@
package main
import (
"encoding/binary"
"io"
)
type CartMessage uint32
type PackageVersion uint32
const (
RemoteGetState = CartMessage(0x01)
RemoteHandleMutation = CartMessage(0x02)
ResponseBody = CartMessage(0x03)
RemoteGetStateReply = CartMessage(0x04)
RemoteHandleMutationReply = CartMessage(0x05)
RemoteGetState = FrameType(0x01)
RemoteHandleMutation = FrameType(0x02)
ResponseBody = FrameType(0x03)
RemoteGetStateReply = FrameType(0x04)
RemoteHandleMutationReply = FrameType(0x05)
)
type CartPacket struct {
Version PackageVersion
MessageType CartMessage
DataLength uint32
StatusCode uint32
Id CartId
}
// type CartPacket struct {
// Version PackageVersion
// MessageType CartMessage
// DataLength uint32
// StatusCode uint32
// Id CartId
// }
type Packet struct {
Version PackageVersion
MessageType PoolMessage
DataLength uint32
StatusCode uint32
}
// type Packet struct {
// Version PackageVersion
// MessageType PoolMessage
// DataLength uint32
// StatusCode uint32
// }
var headerData = make([]byte, 4)
// var headerData = make([]byte, 4)
func matchHeader(conn io.Reader) error {
// func matchHeader(conn io.Reader) error {
pos := 0
for pos < 4 {
// pos := 0
// for pos < 4 {
l, err := conn.Read(headerData)
if err != nil {
return err
}
for i := 0; i < l; i++ {
if headerData[i] == header[pos] {
pos++
if pos == 4 {
return nil
}
} else {
pos = 0
}
}
}
return nil
}
// l, err := conn.Read(headerData)
// if err != nil {
// return err
// }
// for i := 0; i < l; i++ {
// if headerData[i] == header[pos] {
// pos++
// if pos == 4 {
// return nil
// }
// } else {
// pos = 0
// }
// }
// }
// return nil
// }
func ReadPacket(conn io.Reader, packet *Packet) error {
err := matchHeader(conn)
if err != nil {
return err
}
return binary.Read(conn, binary.LittleEndian, packet)
}
// func ReadPacket(conn io.Reader, packet *Packet) error {
// err := matchHeader(conn)
// if err != nil {
// return err
// }
// return binary.Read(conn, binary.LittleEndian, packet)
// }
func ReadCartPacket(conn io.Reader, packet *CartPacket) error {
err := matchHeader(conn)
if err != nil {
return err
}
return binary.Read(conn, binary.LittleEndian, packet)
}
// func ReadCartPacket(conn io.Reader, packet *CartPacket) error {
// err := matchHeader(conn)
// if err != nil {
// return err
// }
// return binary.Read(conn, binary.LittleEndian, packet)
// }
func GetPacketData(conn io.Reader, len uint32) ([]byte, error) {
if len == 0 {
return []byte{}, nil
}
data := make([]byte, len)
_, err := conn.Read(data)
return data, err
}
// func GetPacketData(conn io.Reader, len uint32) ([]byte, error) {
// if len == 0 {
// return []byte{}, nil
// }
// data := make([]byte, len)
// _, err := conn.Read(data)
// return data, err
// }
// func ReceivePacket(conn io.Reader) (uint32, []byte, error) {
// var packet Packet

View File

@@ -55,7 +55,7 @@ func ErrorHandler(fn func(w http.ResponseWriter, r *http.Request) error) func(w
}
}
func (s *PoolServer) WriteResult(w http.ResponseWriter, result *CallResult) error {
func (s *PoolServer) WriteResult(w http.ResponseWriter, result *FrameWithPayload) error {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Pod-Name", s.pod_name)
if result.StatusCode != 200 {
@@ -65,11 +65,11 @@ func (s *PoolServer) WriteResult(w http.ResponseWriter, result *CallResult) erro
} else {
w.WriteHeader(http.StatusInternalServerError)
}
w.Write([]byte(result.Data))
w.Write([]byte(result.Payload))
return nil
}
w.WriteHeader(http.StatusOK)
_, err := w.Write(result.Data)
_, err := w.Write(result.Payload)
return err
}

View File

@@ -46,8 +46,8 @@ func (p *RemoteGrainPool) Delete(id CartId) {
p.mu.Unlock()
}
func (p *RemoteGrainPool) Process(id CartId, messages ...Message) (*CallResult, error) {
var result *CallResult
func (p *RemoteGrainPool) Process(id CartId, messages ...Message) (*FrameWithPayload, error) {
var result *FrameWithPayload
grain, err := p.findOrCreateGrain(id)
if err != nil {
return nil, err
@@ -58,7 +58,7 @@ func (p *RemoteGrainPool) Process(id CartId, messages ...Message) (*CallResult,
return result, err
}
func (p *RemoteGrainPool) Get(id CartId) (*CallResult, error) {
func (p *RemoteGrainPool) Get(id CartId) (*FrameWithPayload, error) {
grain, err := p.findOrCreateGrain(id)
if err != nil {
return nil, err

View File

@@ -3,7 +3,6 @@ package main
import (
"fmt"
"strings"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
@@ -13,6 +12,25 @@ func (id CartId) String() string {
return strings.Trim(string(id[:]), "\x00")
}
type CartIdPayload struct {
Id CartId
Data []byte
}
func MakeCartInnerFrame(id CartId, payload []byte) []byte {
return append(id[:], payload...)
}
func GetCartFrame(data []byte) (*CartIdPayload, error) {
if len(data) < 16 {
return nil, fmt.Errorf("data too short")
}
return &CartIdPayload{
Id: CartId(data[:16]),
Data: data[16:],
}, nil
}
func ToCartId(id string) CartId {
var result [16]byte
copy(result[:], []byte(id))
@@ -20,21 +38,16 @@ func ToCartId(id string) CartId {
}
type RemoteGrain struct {
*CartClient
*Connection
Id CartId
Host string
}
func NewRemoteGrain(id CartId, host string) (*RemoteGrain, error) {
client, err := CartDial(fmt.Sprintf("%s:1337", host))
if err != nil {
return nil, err
}
return &RemoteGrain{
Id: id,
Host: host,
CartClient: client,
Connection: NewConnection(fmt.Sprintf("%s:1337", host)),
}, nil
}
@@ -49,47 +62,35 @@ var (
})
)
var start time.Time
// var start time.Time
func MeasureLatency(fn func() (*CallResult, error)) (*CallResult, error) {
start = time.Now()
data, err := fn()
if err != nil {
return data, err
}
elapsed := time.Since(start).Milliseconds()
go func() {
remoteCartLatency.Add(float64(elapsed))
remoteCartCallsTotal.Inc()
}()
return data, nil
}
// func MeasureLatency(fn func() (*CallResult, error)) (*CallResult, error) {
// start = time.Now()
// data, err := fn()
// if err != nil {
// return data, err
// }
// elapsed := time.Since(start).Milliseconds()
// go func() {
// remoteCartLatency.Add(float64(elapsed))
// remoteCartCallsTotal.Inc()
// }()
// return data, nil
// }
func (g *RemoteGrain) HandleMessage(message *Message, isReplay bool) (*CallResult, error) {
func (g *RemoteGrain) HandleMessage(message *Message, isReplay bool) (*FrameWithPayload, error) {
data, err := GetData(message.Write)
if err != nil {
return nil, err
}
reply, err := MeasureLatency(func() (*CallResult, error) {
return g.Call(RemoteHandleMutation, g.Id, RemoteHandleMutationReply, data)
})
if err != nil {
return nil, err
}
return reply, err
}
func (g *RemoteGrain) Close() {
g.CartClient.PersistentConnection.Close()
return g.Call(RemoteHandleMutation, MakeCartInnerFrame(g.Id, data))
}
func (g *RemoteGrain) GetId() CartId {
return g.Id
}
func (g *RemoteGrain) GetCurrentState() (*CallResult, error) {
return MeasureLatency(func() (*CallResult, error) { return g.Call(RemoteGetState, g.Id, RemoteGetStateReply, []byte{}) })
func (g *RemoteGrain) GetCurrentState() (*FrameWithPayload, error) {
return g.Call(RemoteGetState, MakeCartInnerFrame(g.Id, nil))
}

View File

@@ -7,13 +7,13 @@ import (
)
type RemoteHost struct {
*Client
*Connection
Host string
MissedPings int
}
func (h *RemoteHost) IsHealthy() bool {
return !h.PersistentConnection.Dead && h.MissedPings < 3
return h.MissedPings < 3
}
func (h *RemoteHost) Initialize(p *SyncedPool) {
@@ -38,15 +38,11 @@ func (h *RemoteHost) Initialize(p *SyncedPool) {
}
func (h *RemoteHost) Ping() error {
_, err := h.Call(Ping, Pong, []byte{})
result, err := h.Call(Ping, nil)
if err != nil {
if err != nil || result.StatusCode != 200 || result.Type != Pong {
h.MissedPings++
log.Printf("Error pinging remote %s, missed pings: %d", h.Host, h.MissedPings)
if !h.IsHealthy() {
h.Close()
return fmt.Errorf("remote %s is dead", h.Host)
}
} else {
h.MissedPings = 0
}
@@ -54,28 +50,28 @@ func (h *RemoteHost) Ping() error {
}
func (h *RemoteHost) Negotiate(knownHosts []string) ([]string, error) {
reply, err := h.Call(RemoteNegotiate, RemoteNegotiateResponse, []byte(strings.Join(knownHosts, ";")))
reply, err := h.Call(RemoteNegotiate, []byte(strings.Join(knownHosts, ";")))
if err != nil {
return nil, err
}
if reply.StatusCode != 200 {
return nil, fmt.Errorf("remote returned error on negotiate: %s", string(reply.Data))
return nil, fmt.Errorf("remote returned error on negotiate: %s", string(reply.Payload))
}
return strings.Split(string(reply.Data), ";"), nil
return strings.Split(string(reply.Payload), ";"), nil
}
func (g *RemoteHost) GetCartMappings() ([]CartId, error) {
reply, err := g.Call(GetCartIds, CartIdsResponse, []byte{})
reply, err := g.Call(GetCartIds, []byte{})
if err != nil {
return nil, err
}
if reply.StatusCode != 200 {
log.Printf("Remote returned error on get cart mappings: %s", string(reply.Data))
return nil, fmt.Errorf("remote returned error: %s", string(reply.Data))
if reply.StatusCode != 200 || reply.Type != CartIdsResponse {
log.Printf("Remote returned error on get cart mappings: %s", string(reply.Payload))
return nil, fmt.Errorf("remote returned incorrect data")
}
parts := strings.Split(string(reply.Data), ";")
parts := strings.Split(string(reply.Payload), ";")
ids := make([]CartId, 0, len(parts))
for _, p := range parts {
ids = append(ids, ToCartId(p))
@@ -84,14 +80,11 @@ func (g *RemoteHost) GetCartMappings() ([]CartId, error) {
}
func (r *RemoteHost) ConfirmChange(id CartId, host string) error {
reply, err := r.Call(RemoteGrainChanged, AckChange, []byte(fmt.Sprintf("%s;%s", id, host)))
reply, err := r.Call(RemoteGrainChanged, []byte(fmt.Sprintf("%s;%s", id, host)))
if err != nil {
if err != nil || reply.StatusCode != 200 || reply.Type != AckChange {
return err
}
if string(reply.Data) != "ok" {
return fmt.Errorf("remote grain change failed %s", string(reply.Data))
}
return nil
}

View File

@@ -6,7 +6,7 @@ import (
)
type GrainHandler struct {
*CartServer
*GenericListener
pool *GrainLocalPool
}
@@ -20,13 +20,14 @@ func (h *GrainHandler) GetState(id CartId, reply *Grain) error {
}
func NewGrainHandler(pool *GrainLocalPool, listen string) (*GrainHandler, error) {
server, err := CartListen(listen)
conn := NewConnection(listen)
server, err := conn.Listen()
handler := &GrainHandler{
CartServer: server,
pool: pool,
GenericListener: server,
pool: pool,
}
server.HandleCall(RemoteHandleMutation, handler.RemoteHandleMessageHandler)
server.HandleCall(RemoteGetState, handler.RemoteGetStateHandler)
server.AddHandler(RemoteHandleMutation, handler.RemoteHandleMessageHandler)
server.AddHandler(RemoteGetState, handler.RemoteGetStateHandler)
return handler, err
}
@@ -34,29 +35,36 @@ func (h *GrainHandler) IsHealthy() bool {
return len(h.pool.grains) < h.pool.PoolSize
}
func (h *GrainHandler) RemoteHandleMessageHandler(id CartId, data []byte) (CartMessage, []byte, error) {
func (h *GrainHandler) RemoteHandleMessageHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
cartData, err := GetCartFrame(data.Payload)
if err != nil {
return err
}
var msg Message
err := ReadMessage(bytes.NewReader(data), &msg)
err = ReadMessage(bytes.NewReader(cartData.Data), &msg)
if err != nil {
fmt.Println("Error reading message:", err)
return RemoteHandleMutationReply, nil, err
return err
}
replyData, err := h.pool.Process(id, msg)
replyData, err := h.pool.Process(cartData.Id, msg)
if err != nil {
fmt.Println("Error handling message:", err)
}
if err != nil {
return RemoteHandleMutationReply, nil, err
}
return RemoteHandleMutationReply, replyData, nil
resultChan <- *replyData
return nil
}
func (h *GrainHandler) RemoteGetStateHandler(id CartId, data []byte) (CartMessage, []byte, error) {
reply, err := h.pool.Get(id)
func (h *GrainHandler) RemoteGetStateHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
cartData, err := GetCartFrame(data.Payload)
if err != nil {
return RemoteGetStateReply, nil, err
return err
}
return RemoteGetStateReply, reply, nil
reply, err := h.pool.Get(cartData.Id)
if err != nil {
return err
}
resultChan <- *reply
return nil
}

View File

@@ -22,7 +22,7 @@ type HealthHandler interface {
}
type SyncedPool struct {
*Server
Server *GenericListener
mu sync.RWMutex
Hostname string
local *GrainLocalPool
@@ -61,11 +61,16 @@ var (
})
)
func (p *SyncedPool) PongHandler(data []byte) (PoolMessage, []byte, error) {
return Pong, data, nil
var (
PongResponse = MakeFrameWithPayload(Pong, 200, nil)
)
func (p *SyncedPool) PongHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
resultChan <- PongResponse
return nil
}
func (p *SyncedPool) GetCartIdHandler(data []byte) (PoolMessage, []byte, error) {
func (p *SyncedPool) GetCartIdHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
ids := make([]string, 0, len(p.local.grains))
for id := range p.local.grains {
if p.local.grains[id] == nil {
@@ -78,45 +83,45 @@ func (p *SyncedPool) GetCartIdHandler(data []byte) (PoolMessage, []byte, error)
ids = append(ids, s)
}
log.Printf("Returning %d cart ids\n", len(ids))
return CartIdsResponse, []byte(strings.Join(ids, ";")), nil
resultChan <- MakeFrameWithPayload(CartIdsResponse, 200, []byte(strings.Join(ids, ";")))
return nil
}
func (p *SyncedPool) NegotiateHandler(data []byte) (PoolMessage, []byte, error) {
func (p *SyncedPool) NegotiateHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
negotiationCount.Inc()
log.Printf("Handling negotiation\n")
for _, host := range p.ExcludeKnown(strings.Split(string(data), ";")) {
for _, host := range p.ExcludeKnown(strings.Split(string(data.Payload), ";")) {
if host == "" {
continue
}
go p.AddRemote(host)
}
return RemoteNegotiateResponse, []byte("ok"), nil
resultChan <- MakeFrameWithPayload(RemoteNegotiateResponse, 200, []byte("ok"))
return nil
}
func (p *SyncedPool) GrainOwnerChangeHandler(data []byte) (PoolMessage, []byte, error) {
func (p *SyncedPool) GrainOwnerChangeHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
grainSyncCount.Inc()
idAndHostParts := strings.Split(string(data), ";")
idAndHostParts := strings.Split(string(data.Payload), ";")
if len(idAndHostParts) != 2 {
log.Printf("Invalid remote grain change message\n")
return AckChange, []byte("incorrect"), fmt.Errorf("invalid remote grain change message")
resultChan <- MakeFrameWithPayload(AckError, 400, []byte("invalid"))
return nil
}
id := ToCartId(idAndHostParts[0])
host := idAndHostParts[1]
log.Printf("Handling remote grain owner change to %s for id %s\n", host, id)
for _, r := range p.remotes {
if r.Host == host && r.IsHealthy() {
// log.Printf("Remote grain %s changed to %s\n", id, host)
go p.SpawnRemoteGrain(id, host)
return AckChange, []byte("ok"), nil
break
}
}
go p.AddRemote(host)
return AckChange, []byte("ok"), nil
resultChan <- MakeFrameWithPayload(AckChange, 200, []byte("ok"))
return nil
}
func (p *SyncedPool) RemoveRemoteGrain(id CartId) {
@@ -142,12 +147,12 @@ func (p *SyncedPool) SpawnRemoteGrain(id CartId, host string) {
log.Printf("Error creating remote grain %v\n", err)
return
}
go func() {
<-remote.PersistentConnection.Died
p.RemoveRemoteGrain(id)
p.HandleHostError(host)
log.Printf("Remote grain %s died, host: %s\n", id.String(), host)
}()
// go func() {
// <-remote.Died
// p.RemoveRemoteGrain(id)
// p.HandleHostError(host)
// log.Printf("Remote grain %s died, host: %s\n", id.String(), host)
// }()
p.mu.Lock()
p.remoteIndex[id] = remote
@@ -159,8 +164,6 @@ func (p *SyncedPool) HandleHostError(host string) {
if r.Host == host {
if !r.IsHealthy() {
p.RemoveHost(r)
} else {
r.ErrorCount++
}
return
}
@@ -169,8 +172,8 @@ func (p *SyncedPool) HandleHostError(host string) {
func NewSyncedPool(local *GrainLocalPool, hostname string, discovery Discovery) (*SyncedPool, error) {
listen := fmt.Sprintf("%s:1338", hostname)
server, err := Listen(listen)
conn := NewConnection(listen)
server, err := conn.Listen()
if err != nil {
return nil, err
}
@@ -186,10 +189,10 @@ func NewSyncedPool(local *GrainLocalPool, hostname string, discovery Discovery)
remoteIndex: make(map[CartId]*RemoteGrain),
}
server.HandleCall(Ping, pool.PongHandler)
server.HandleCall(GetCartIds, pool.GetCartIdHandler)
server.HandleCall(RemoteNegotiate, pool.NegotiateHandler)
server.HandleCall(RemoteGrainChanged, pool.GrainOwnerChangeHandler)
server.AddHandler(Ping, pool.PongHandler)
server.AddHandler(GetCartIds, pool.GetCartIdHandler)
server.AddHandler(RemoteNegotiate, pool.NegotiateHandler)
server.AddHandler(RemoteGrainChanged, pool.GrainOwnerChangeHandler)
if discovery != nil {
go func() {
@@ -259,18 +262,11 @@ func (p *SyncedPool) ExcludeKnown(hosts []string) []string {
}
func (p *SyncedPool) RemoveHost(host *RemoteHost) {
if p.remotes[host.Host] == nil {
return
}
p.mu.Lock()
defer p.mu.Unlock()
h := p.remotes[host.Host]
if h != nil {
h.Close()
}
delete(p.remotes, host.Host)
p.mu.Unlock()
p.RemoveHostMappedCarts(host)
connectedRemotes.Set(float64(len(p.remotes)))
}
@@ -279,24 +275,21 @@ func (p *SyncedPool) RemoveHostMappedCarts(host *RemoteHost) {
defer p.mu.Unlock()
for id, r := range p.remoteIndex {
if r.Host == host.Host {
p.remoteIndex[id].Close()
delete(p.remoteIndex, id)
}
}
}
type PoolMessage uint32
const (
RemoteNegotiate = PoolMessage(3)
RemoteGrainChanged = PoolMessage(4)
AckChange = PoolMessage(5)
//AckError = PoolMessage(6)
Ping = PoolMessage(7)
Pong = PoolMessage(8)
GetCartIds = PoolMessage(9)
CartIdsResponse = PoolMessage(10)
RemoteNegotiateResponse = PoolMessage(11)
RemoteNegotiate = FrameType(3)
RemoteGrainChanged = FrameType(4)
AckChange = FrameType(5)
AckError = FrameType(6)
Ping = FrameType(7)
Pong = FrameType(8)
GetCartIds = FrameType(9)
CartIdsResponse = FrameType(10)
RemoteNegotiateResponse = FrameType(11)
)
func (p *SyncedPool) Negotiate() {
@@ -377,25 +370,22 @@ func (p *SyncedPool) AddRemote(host string) error {
if host == "" || p.IsKnown(host) || hasHost {
return nil
}
client, err := Dial(fmt.Sprintf("%s:1338", host))
if err != nil {
client := NewConnection(fmt.Sprintf("%s:1338", host))
response, err := client.Call(Ping, nil)
if err != nil || response.StatusCode != 200 || response.Type != Pong {
log.Printf("Error connecting to remote %s: %v\n", host, err)
return err
}
remote := RemoteHost{
Client: client,
Connection: client,
MissedPings: 0,
Host: host,
}
p.mu.Lock()
p.remotes[host] = &remote
p.mu.Unlock()
go func() {
<-remote.PersistentConnection.Died
log.Printf("Removing host, remote died %s", host)
p.RemoveHost(&remote)
}()
go func() {
for range time.Tick(time.Second * 3) {
@@ -450,9 +440,9 @@ func (p *SyncedPool) getGrain(id CartId) (Grain, error) {
return localGrain, nil
}
func (p *SyncedPool) Process(id CartId, messages ...Message) (*CallResult, error) {
func (p *SyncedPool) Process(id CartId, messages ...Message) (*FrameWithPayload, error) {
pool, err := p.getGrain(id)
var res *CallResult
var res *FrameWithPayload
if err != nil {
return nil, err
}
@@ -465,7 +455,7 @@ func (p *SyncedPool) Process(id CartId, messages ...Message) (*CallResult, error
return res, nil
}
func (p *SyncedPool) Get(id CartId) (*CallResult, error) {
func (p *SyncedPool) Get(id CartId) (*FrameWithPayload, error) {
grain, err := p.getGrain(id)
if err != nil {
return nil, err

View File

@@ -1,96 +0,0 @@
package main
import (
"encoding/binary"
"log"
"sync"
)
type CartClient struct {
*CartTCPClient
}
func CartDial(address string) (*CartClient, error) {
mux, err := NewCartTCPClient(address)
if err != nil {
return nil, err
}
client := &CartClient{
CartTCPClient: mux,
}
return client, nil
}
func (c *Client) Close() {
log.Printf("Closing connection to %s\n", c.PersistentConnection.address)
c.PersistentConnection.Close()
}
type CartTCPClient struct {
PersistentConnection *PersistentConnection
sendMux sync.Mutex
ErrorCount int
address string
*CartPacketQueue
}
func NewCartTCPClient(address string) (*CartTCPClient, error) {
connection, err := NewPersistentConnection(address)
if err != nil {
return nil, err
}
return &CartTCPClient{
ErrorCount: 0,
PersistentConnection: connection,
address: address,
CartPacketQueue: NewCartPacketQueue(connection),
}, nil
}
func (m *CartTCPClient) SendPacket(messageType CartMessage, id CartId, data []byte) error {
m.sendMux.Lock()
defer m.sendMux.Unlock()
m.PersistentConnection.Conn.Write(header[:])
err := binary.Write(m.PersistentConnection, binary.LittleEndian, CartPacket{
Version: CurrentPacketVersion,
MessageType: messageType,
DataLength: uint32(len(data)),
Id: id,
})
if err != nil {
return m.PersistentConnection.HandleConnectionError(err)
}
_, err = m.PersistentConnection.Write(data)
return m.PersistentConnection.HandleConnectionError(err)
}
func (m *CartTCPClient) call(messageType CartMessage, id CartId, responseType CartMessage, data []byte) (*CallResult, error) {
packetChan := m.Expect(responseType, id)
err := m.SendPacket(messageType, id, data)
if err != nil {
return nil, m.PersistentConnection.HandleConnectionError(err)
}
ret := <-packetChan
return &ret, nil
}
func isRetirableError(err error) bool {
log.Printf("is retryable error: %v", err)
return false
}
func (m *CartTCPClient) Call(messageType CartMessage, id CartId, responseType CartMessage, data []byte) (*CallResult, error) {
retries := 0
result, err := m.call(messageType, id, responseType, data)
for err != nil && retries < 3 {
if !isRetirableError(err) {
break
}
retries++
log.Printf("Retrying call to %d\n", messageType)
result, err = m.call(messageType, id, responseType, data)
}
return result, err
}

View File

@@ -1,160 +0,0 @@
package main
import (
"bufio"
"encoding/binary"
"io"
"log"
"net"
"sync"
)
type CartServer struct {
*TCPCartServerMux
}
func CartListen(address string) (*CartServer, error) {
listener, err := net.Listen("tcp", address)
server := &CartServer{
NewCartTCPServerMux(),
}
if err != nil {
return nil, err
}
go func() {
for {
conn, err := listener.Accept()
if err != nil {
log.Printf("Error accepting connection: %v\n", err)
continue
}
go server.HandleConnection(conn)
}
}()
return server, nil
}
type TCPCartServerMux struct {
mu sync.RWMutex
sendMux sync.Mutex
listeners map[CartMessage]func(CartId, []byte) error
functions map[CartMessage]func(CartId, []byte) (CartMessage, []byte, error)
}
func NewCartTCPServerMux() *TCPCartServerMux {
m := &TCPCartServerMux{
mu: sync.RWMutex{},
listeners: make(map[CartMessage]func(CartId, []byte) error),
functions: make(map[CartMessage]func(CartId, []byte) (CartMessage, []byte, error)),
}
return m
}
func (m *TCPCartServerMux) handleListener(messageType CartMessage, id CartId, data []byte) (bool, error) {
m.mu.RLock()
handler, ok := m.listeners[messageType]
m.mu.RUnlock()
if ok {
err := handler(id, data)
if err != nil {
return true, err
}
}
return false, nil
}
func (m *TCPCartServerMux) handleFunction(connection net.Conn, messageType CartMessage, id CartId, data []byte) (bool, error) {
m.mu.RLock()
fn, ok := m.functions[messageType]
m.mu.RUnlock()
m.sendMux.Lock()
defer m.sendMux.Unlock()
if ok {
responseType, responseData, err := fn(id, data)
connection.Write(header[:])
if err != nil {
errData := []byte(err.Error())
err = binary.Write(connection, binary.LittleEndian, CartPacket{
Version: CurrentPacketVersion,
MessageType: responseType,
DataLength: uint32(len(errData)),
StatusCode: 500,
Id: id,
})
_, err = connection.Write(errData)
return true, err
}
err = binary.Write(connection, binary.LittleEndian, CartPacket{
Version: CurrentPacketVersion,
MessageType: responseType,
DataLength: uint32(len(responseData)),
StatusCode: 200,
Id: id,
})
if err != nil {
return true, err
}
packetsSent.Inc()
_, err = connection.Write(responseData)
return true, err
} else {
log.Printf("No cart handler for type: %d\n", messageType)
}
return false, nil
}
func (m *TCPCartServerMux) HandleConnection(connection net.Conn) error {
var packet CartPacket
var err error
defer connection.Close()
reader := bufio.NewReader(connection)
for {
err = ReadCartPacket(reader, &packet)
if err != nil {
if err == io.EOF {
return nil
}
log.Printf("Error receiving packet: %v\n", err)
return err
}
if packet.Version != CurrentPacketVersion {
log.Printf("Incorrect packet version: %d\n", packet.Version)
continue
}
data, err := GetPacketData(reader, packet.DataLength)
if err != nil {
log.Printf("Error getting packet data: %v\n", err)
}
go m.HandleData(connection, packet.MessageType, packet.Id, data)
}
}
func (m *TCPCartServerMux) HandleData(connection net.Conn, t CartMessage, id CartId, data []byte) {
status, err := m.handleListener(t, id, data)
if err != nil {
log.Printf("Error handling listener: %v\n", err)
}
if !status {
status, err = m.handleFunction(connection, t, id, data)
if err != nil {
log.Printf("Error handling function: %v\n", err)
}
if !status {
log.Printf("Unknown message type: %d\n", t)
}
}
}
func (m *TCPCartServerMux) ListenFor(messageType CartMessage, handler func(CartId, []byte) error) {
m.mu.Lock()
m.listeners[messageType] = handler
m.mu.Unlock()
}
func (m *TCPCartServerMux) HandleCall(messageType CartMessage, handler func(CartId, []byte) (CartMessage, []byte, error)) {
m.mu.Lock()
m.functions[messageType] = handler
m.mu.Unlock()
}

View File

@@ -1,57 +0,0 @@
package main
import (
"fmt"
"log"
"testing"
)
func TestCartTcpHelpers(t *testing.T) {
server, err := CartListen("localhost:51337")
if err != nil {
t.Errorf("Error listening: %v\n", err)
}
client, err := CartDial("localhost:51337")
if err != nil {
t.Errorf("Error dialing: %v\n", err)
}
var messageData string
server.ListenFor(1, func(id CartId, data []byte) error {
log.Printf("Received message: %s\n", string(data))
messageData = string(data)
return nil
})
server.HandleCall(666, func(id CartId, data []byte) (CartMessage, []byte, error) {
log.Printf("Received 666 call: %s\n", string(data))
return 3, []byte("Hello, client!"), fmt.Errorf("Det blev fel")
})
server.HandleCall(2, func(id CartId, data []byte) (CartMessage, []byte, error) {
log.Printf("Received 2 call: %s\n", string(data))
return 4, []byte("Hello, client!"), nil
})
// server.HandleCall(Ping, func(id CartId, data []byte) (CartMessage, []byte, error) {
// return Pong, nil, nil
// })
id := ToCartId("kalle")
client.SendPacket(1, id, []byte("Hello, world!"))
answer, err := client.Call(2, id, 4, []byte("Hello, server!"))
if err != nil {
t.Errorf("Error calling: %v\n", err)
}
s, err := client.Call(666, id, 3, []byte("Hello, server!"))
client.PersistentConnection.Close()
if err != nil {
t.Errorf("Error calling: %v\n", err)
}
if s.StatusCode != 500 {
t.Errorf("Expected 500, got %d\n", s.StatusCode)
}
if string(answer.Data) != "Hello, client!" {
t.Errorf("Expected answer 'Hello, client!', got %s\n", string(answer.Data))
}
if messageData != "Hello, world!" {
t.Errorf("Expected message 'Hello, world!', got %s\n", messageData)
}
}

View File

@@ -1,144 +0,0 @@
package main
import (
"encoding/binary"
"log"
"net"
"sync"
"time"
)
type Client struct {
*TCPClient
}
func Dial(address string) (*Client, error) {
mux, err := NewTCPClient(address)
if err != nil {
return nil, err
}
client := &Client{
TCPClient: mux,
}
return client, nil
}
type TCPClient struct {
PersistentConnection *PersistentConnection
sendMux sync.Mutex
ErrorCount int
address string
*PacketQueue
}
type PersistentConnection struct {
net.Conn
Died chan bool
Dead bool
address string
}
func NewPersistentConnection(address string) (*PersistentConnection, error) {
p := &PersistentConnection{
Died: make(chan bool, 1),
Dead: false,
address: address,
}
err := p.Connect()
if err != nil {
return nil, err
}
return p, nil
}
func (m *PersistentConnection) Connect() error {
fails := 0
for {
connection, err := net.Dial("tcp", m.address)
if err != nil {
log.Printf("Can't connect to %s: %v, count: %d", m.address, err, fails)
fails++
if fails > 15 {
log.Printf("Too many connection failures, closing connection to %s", m.address)
m.Died <- true
m.Dead = true
return err
}
} else {
m.Conn = connection
break
}
time.Sleep(time.Millisecond * 300)
}
return nil
}
func (m *PersistentConnection) Close() {
log.Printf("Closing connection to %s\n", m.address)
m.Conn.Close()
m.Died <- true
m.Dead = true
}
func (m *PersistentConnection) HandleConnectionError(err error) error {
if err != nil {
log.Printf("Error from to %s: %v\n", m.address, err)
m.Conn.Close()
m.Connect()
}
return err
}
func NewTCPClient(address string) (*TCPClient, error) {
connection, err := NewPersistentConnection(address)
if err != nil {
return nil, err
}
return &TCPClient{
ErrorCount: 0,
PersistentConnection: connection,
address: address,
PacketQueue: NewPacketQueue(connection),
}, nil
}
type PacketHeader [4]byte
var (
header = PacketHeader([4]byte{0x01, 0x02, 0x03, 0x04})
)
func (m *TCPClient) SendPacket(messageType PoolMessage, data []byte) error {
m.sendMux.Lock()
defer m.sendMux.Unlock()
m.PersistentConnection.Write(header[:])
err := binary.Write(m.PersistentConnection, binary.LittleEndian, Packet{
Version: CurrentPacketVersion,
MessageType: messageType,
StatusCode: 0,
DataLength: uint32(len(data)),
})
if err != nil {
return m.PersistentConnection.HandleConnectionError(err)
}
_, err = m.PersistentConnection.Write(data)
return m.PersistentConnection.HandleConnectionError(err)
}
func (m *TCPClient) Call(messageType PoolMessage, responseType PoolMessage, data []byte) (*CallResult, error) {
packetChan := m.Expect(responseType)
err := m.SendPacket(messageType, data)
if err != nil {
m.RemoveListeners()
return nil, m.PersistentConnection.HandleConnectionError(err)
}
ret := <-packetChan
return &ret, nil
}

View File

@@ -15,12 +15,23 @@ type Connection struct {
}
type FrameType uint32
type StatusCode uint32
type CheckSum uint32
type Frame struct {
Id uint64
Type FrameType
StatusCode uint32
StatusCode StatusCode
Length uint32
Checksum CheckSum
}
func (f *Frame) IsValid() bool {
return f.Checksum == MakeChecksum(f.Type, f.StatusCode, f.Length)
}
func MakeChecksum(msg FrameType, statusCode StatusCode, length uint32) CheckSum {
sum := CheckSum((uint32(msg) + uint32(statusCode) + length) / 8)
return sum
}
type FrameWithPayload struct {
@@ -28,6 +39,19 @@ type FrameWithPayload struct {
Payload []byte
}
func MakeFrameWithPayload(msg FrameType, statusCode StatusCode, payload []byte) FrameWithPayload {
len := uint32(len(payload))
return FrameWithPayload{
Frame: Frame{
Type: msg,
StatusCode: 0,
Length: len,
Checksum: MakeChecksum(msg, 0, len),
},
Payload: payload,
}
}
type FrameData interface {
ToBytes() []byte
FromBytes([]byte) error
@@ -41,11 +65,7 @@ func NewConnection(address string) *Connection {
}
func SendFrame(conn net.Conn, data *FrameWithPayload) error {
_, err := conn.Write(header[:])
if err != nil {
return err
}
err = binary.Write(conn, binary.LittleEndian, data.Frame)
err := binary.Write(conn, binary.LittleEndian, data.Frame)
if err != nil {
return err
}
@@ -53,68 +73,67 @@ func SendFrame(conn net.Conn, data *FrameWithPayload) error {
return err
}
func (c *Connection) CallAsync(msg FrameType, data FrameData, ch chan<- *FrameWithPayload) error {
func (c *Connection) CallAsync(msg FrameType, payload []byte, ch chan<- FrameWithPayload) (net.Conn, error) {
conn, err := net.Dial("tcp", c.address)
go WaitForFrame(conn, ch)
if err != nil {
return err
}
payload := data.ToBytes()
toSend := &FrameWithPayload{
Frame: Frame{
Id: c.count,
Type: msg,
StatusCode: 0,
Length: uint32(len(payload)),
},
Payload: payload,
return conn, err
}
toSend := MakeFrameWithPayload(msg, 1, payload)
err = SendFrame(conn, toSend)
err = SendFrame(conn, &toSend)
if err != nil {
conn.Close()
close(ch)
return err
return nil, err
}
c.count++
return nil
return conn, nil
}
func (c *Connection) Call(msg FrameType, data FrameData) (*FrameWithPayload, error) {
ch := make(chan *FrameWithPayload, 1)
c.CallAsync(msg, data, ch)
func (c *Connection) Call(msg FrameType, data []byte) (*FrameWithPayload, error) {
ch := make(chan FrameWithPayload, 1)
conn, err := c.CallAsync(msg, data, ch)
if err != nil {
return nil, err
}
defer conn.Close()
select {
case ret := <-ch:
return ret, nil
case <-time.After(5 * time.Second):
return &ret, nil
case <-time.After(MaxCallDuration):
return nil, fmt.Errorf("timeout")
}
}
func WaitForFrame(conn net.Conn, resultChan chan<- *FrameWithPayload) error {
defer conn.Close()
func WaitForFrame(conn net.Conn, resultChan chan<- FrameWithPayload) error {
var err error
var frame Frame
r := bufio.NewReader(conn)
h := make([]byte, 4)
r.Read(h)
if h[0] == header[0] && h[1] == header[1] && h[2] == header[2] && h[3] == header[3] {
frame := Frame{}
err = binary.Read(r, binary.LittleEndian, &frame)
err = binary.Read(r, binary.LittleEndian, &frame)
if err != nil {
return err
}
if frame.IsValid() {
payload := make([]byte, frame.Length)
_, err = r.Read(payload)
resultChan <- &FrameWithPayload{
if err != nil {
return err
}
resultChan <- FrameWithPayload{
Frame: frame,
Payload: payload,
}
return err
}
resultChan <- nil
return err
return fmt.Errorf("checksum mismatch")
}
type GenericListener struct {
Closed bool
handlers map[FrameType]func(*FrameWithPayload, chan<- *FrameWithPayload) error
handlers map[FrameType]func(*FrameWithPayload, chan<- FrameWithPayload) error
}
func (c *Connection) Listen() (*GenericListener, error) {
@@ -123,7 +142,7 @@ func (c *Connection) Listen() (*GenericListener, error) {
return nil, err
}
ret := &GenericListener{
handlers: make(map[FrameType]func(*FrameWithPayload, chan<- *FrameWithPayload) error),
handlers: make(map[FrameType]func(*FrameWithPayload, chan<- FrameWithPayload) error),
}
go func() {
for !ret.Closed {
@@ -137,36 +156,44 @@ func (c *Connection) Listen() (*GenericListener, error) {
return ret, nil
}
const (
MaxCallDuration = 500 * time.Millisecond
)
func (l *GenericListener) HandleConnection(conn net.Conn) {
ch := make(chan *FrameWithPayload, 1)
ch := make(chan FrameWithPayload, 1)
go WaitForFrame(conn, ch)
select {
case frame := <-ch:
go l.HandleFrame(conn, frame)
case <-time.After(1 * time.Second):
go l.HandleFrame(conn, &frame)
case <-time.After(MaxCallDuration):
close(ch)
log.Printf("Timeout waiting for frame\n")
}
}
func (l *GenericListener) AddHandler(msg FrameType, handler func(*FrameWithPayload, chan<- *FrameWithPayload) error) {
func (l *GenericListener) AddHandler(msg FrameType, handler func(*FrameWithPayload, chan<- FrameWithPayload) error) {
l.handlers[msg] = handler
}
func (l *GenericListener) HandleFrame(conn net.Conn, frame *FrameWithPayload) {
handler, ok := l.handlers[frame.Type]
defer conn.Close()
if ok {
go func() {
resultChan := make(chan *FrameWithPayload, 1)
resultChan := make(chan FrameWithPayload, 1)
defer close(resultChan)
err := handler(frame, resultChan)
if err != nil {
log.Fatalf("Error handling frame: %v\n", err)
}
SendFrame(conn, <-resultChan)
result := <-resultChan
err = SendFrame(conn, &result)
if err != nil {
log.Fatalf("Error sending frame: %v\n", err)
}
}()
} else {
conn.Close()
log.Fatalf("No handler for frame type %d\n", frame.Type)
}
}

View File

@@ -2,37 +2,19 @@ package main
import "testing"
type StringData string
func (s StringData) ToBytes() []byte {
return []byte(s)
}
func (s StringData) FromBytes(data []byte) error {
s = StringData(data)
return nil
}
func TestGenericConnection(t *testing.T) {
conn := NewConnection("localhost:51337")
listener, err := conn.Listen()
if err != nil {
t.Errorf("Error listening: %v\n", err)
}
listener.AddHandler(1, func(input *FrameWithPayload, resultChan chan<- *FrameWithPayload) error {
payload := []byte("Hello, world!")
resultChan <- &FrameWithPayload{
Frame: Frame{
Type: 2,
Id: input.Id,
StatusCode: 200,
Length: uint32(len("Hello, world!")),
},
Payload: payload,
}
datta := []byte("Hello, world!")
listener.AddHandler(1, func(input *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
resultChan <- MakeFrameWithPayload(2, 200, datta)
return nil
})
r, err := conn.Call(1, StringData("Hello, world!"))
r, err := conn.Call(1, datta)
if err != nil {
t.Errorf("Error calling: %v\n", err)
}
@@ -40,9 +22,9 @@ func TestGenericConnection(t *testing.T) {
t.Errorf("Expected type 2, got %d\n", r.Type)
}
i := 100
results := make(chan *FrameWithPayload, i)
results := make(chan FrameWithPayload, i)
for i > 0 {
conn.CallAsync(1, StringData("Hello, world!"), results)
go conn.CallAsync(1, datta, results)
i--
}
for i < 100 {

View File

@@ -1,161 +0,0 @@
package main
import (
"bufio"
"encoding/binary"
"io"
"log"
"net"
"sync"
)
type Server struct {
*TCPServerMux
}
func Listen(address string) (*Server, error) {
listener, err := net.Listen("tcp", address)
server := &Server{
NewTCPServerMux(),
}
if err != nil {
return nil, err
}
go func() {
for {
conn, err := listener.Accept()
if err != nil {
log.Printf("Error accepting connection: %v\n", err)
continue
}
go server.HandleConnection(conn)
}
}()
return server, nil
}
type TCPServerMux struct {
mu sync.RWMutex
sendMux sync.Mutex
listeners map[PoolMessage]func(data []byte) error
functions map[PoolMessage]func(data []byte) (PoolMessage, []byte, error)
}
func NewTCPServerMux() *TCPServerMux {
m := &TCPServerMux{
mu: sync.RWMutex{},
listeners: make(map[PoolMessage]func(data []byte) error),
functions: make(map[PoolMessage]func(data []byte) (PoolMessage, []byte, error)),
}
return m
}
func (m *TCPServerMux) handleListener(messageType PoolMessage, data []byte) (bool, error) {
m.mu.RLock()
handler, ok := m.listeners[messageType]
m.mu.RUnlock()
if ok {
err := handler(data)
if err != nil {
return true, err
}
}
return false, nil
}
func (m *TCPServerMux) handleFunction(connection net.Conn, messageType PoolMessage, data []byte) (bool, error) {
m.mu.RLock()
function, ok := m.functions[messageType]
m.mu.RUnlock()
m.sendMux.Lock()
defer m.sendMux.Unlock()
if ok {
connection.Write(header[:])
responseType, responseData, err := function(data)
if err != nil {
errData := []byte(err.Error())
err = binary.Write(connection, binary.LittleEndian, Packet{
Version: CurrentPacketVersion,
MessageType: responseType,
StatusCode: 500,
DataLength: uint32(len(errData)),
})
_, err = connection.Write(errData)
return true, err
}
err = binary.Write(connection, binary.LittleEndian, Packet{
Version: CurrentPacketVersion,
MessageType: responseType,
StatusCode: 200,
DataLength: uint32(len(responseData)),
})
if err != nil {
return true, err
}
packetsSent.Inc()
_, err = connection.Write(responseData)
return true, err
} else {
log.Printf("No pool handler for type: %d\n", messageType)
}
return false, nil
}
func (m *TCPServerMux) HandleConnection(connection net.Conn) error {
defer connection.Close()
var packet Packet
reader := bufio.NewReader(connection)
for {
err := ReadPacket(reader, &packet)
if err != nil {
if err == io.EOF {
return nil
}
log.Printf("Error receiving packet: %v\n", err)
return err
}
if packet.Version != CurrentPacketVersion {
log.Printf("Incorrect package version: %v\n", err)
continue
}
data, err := GetPacketData(reader, packet.DataLength)
if err != nil {
log.Printf("Error receiving packet data: %v\n", err)
return err
}
go m.HandleData(connection, packet.MessageType, data)
}
}
func (m *TCPServerMux) HandleData(connection net.Conn, t PoolMessage, data []byte) {
// listener := m.listeners[t]
// handler := m.functions[t]
status, err := m.handleListener(t, data)
if err != nil {
log.Printf("Error handling listener: %v\n", err)
}
if !status {
status, err = m.handleFunction(connection, t, data)
if err != nil {
log.Printf("Error handling function: %v\n", err)
}
if !status {
log.Printf("Unknown message type: %d\n", t)
}
}
}
func (m *TCPServerMux) ListenFor(messageType PoolMessage, handler func(data []byte) error) {
m.mu.Lock()
m.listeners[messageType] = handler
m.mu.Unlock()
}
func (m *TCPServerMux) HandleCall(messageType PoolMessage, handler func(data []byte) (PoolMessage, []byte, error)) {
m.mu.Lock()
m.functions[messageType] = handler
m.mu.Unlock()
}

View File

@@ -1,54 +0,0 @@
package main
import (
"log"
"testing"
)
func TestTcpHelpers(t *testing.T) {
server, err := Listen("localhost:51337")
if err != nil {
t.Errorf("Error listening: %v\n", err)
}
client, err := Dial("localhost:51337")
if err != nil {
t.Errorf("Error dialing: %v\n", err)
}
var messageData string
server.ListenFor(1, func(data []byte) error {
log.Printf("Received message: %s\n", string(data))
messageData = string(data)
return nil
})
server.HandleCall(2, func(data []byte) (PoolMessage, []byte, error) {
log.Printf("Received call: %s\n", string(data))
return 3, []byte("Hello, client!"), nil
})
server.HandleCall(Ping, func(data []byte) (PoolMessage, []byte, error) {
return Pong, nil, nil
})
client.SendPacket(1, []byte("Hello, world!"))
answer, err := client.Call(2, 3, []byte("Hello, server!"))
if err != nil {
t.Errorf("Error calling: %v\n", err)
}
for i := 0; i < 100; i++ {
_, err = client.Call(Ping, Pong, nil)
if err != nil {
t.Errorf("Error calling: %v\n", err)
}
}
_, err = client.Call(Ping, Pong, nil)
if err != nil {
t.Errorf("Error calling: %v\n", err)
}
client.Close()
if string(answer.Data) != "Hello, client!" {
t.Errorf("Expected answer 'Hello, client!', got %s\n", string(answer.Data))
}
if messageData != "Hello, world!" {
t.Errorf("Expected message 'Hello, world!', got %s\n", messageData)
}
}