major refactor
This commit is contained in:
@@ -54,8 +54,8 @@ type CartGrain struct {
|
|||||||
|
|
||||||
type Grain interface {
|
type Grain interface {
|
||||||
GetId() CartId
|
GetId() CartId
|
||||||
HandleMessage(message *Message, isReplay bool) (*CallResult, error)
|
HandleMessage(message *Message, isReplay bool) (*FrameWithPayload, error)
|
||||||
GetCurrentState() (*CallResult, error)
|
GetCurrentState() (*FrameWithPayload, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *CartGrain) GetId() CartId {
|
func (c *CartGrain) GetId() CartId {
|
||||||
@@ -69,12 +69,14 @@ func (c *CartGrain) GetLastChange() int64 {
|
|||||||
return *c.storageMessages[len(c.storageMessages)-1].TimeStamp
|
return *c.storageMessages[len(c.storageMessages)-1].TimeStamp
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *CartGrain) GetCurrentState() (*CallResult, error) {
|
func (c *CartGrain) GetCurrentState() (*FrameWithPayload, error) {
|
||||||
result, err := json.Marshal(c)
|
result, err := json.Marshal(c)
|
||||||
return &CallResult{
|
if err != nil {
|
||||||
StatusCode: 200,
|
ret := MakeFrameWithPayload(0, 400, []byte(err.Error()))
|
||||||
Data: result,
|
return &ret, nil
|
||||||
}, err
|
}
|
||||||
|
ret := MakeFrameWithPayload(0, 200, result)
|
||||||
|
return &ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getItemData(sku string, qty int) (*messages.AddItem, error) {
|
func getItemData(sku string, qty int) (*messages.AddItem, error) {
|
||||||
@@ -108,7 +110,7 @@ func getItemData(sku string, qty int) (*messages.AddItem, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *CartGrain) AddItem(sku string, qty int) (*CallResult, error) {
|
func (c *CartGrain) AddItem(sku string, qty int) (*FrameWithPayload, error) {
|
||||||
cartItem, err := getItemData(sku, qty)
|
cartItem, err := getItemData(sku, qty)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -180,7 +182,7 @@ func (c *CartGrain) FindItemWithSku(sku string) (*CartItem, bool) {
|
|||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *CartGrain) HandleMessage(message *Message, isReplay bool) (*CallResult, error) {
|
func (c *CartGrain) HandleMessage(message *Message, isReplay bool) (*FrameWithPayload, error) {
|
||||||
if message.TimeStamp == nil {
|
if message.TimeStamp == nil {
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
message.TimeStamp = &now
|
message.TimeStamp = &now
|
||||||
@@ -305,8 +307,10 @@ func (c *CartGrain) HandleMessage(message *Message, isReplay bool) (*CallResult,
|
|||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
}
|
}
|
||||||
result, err := json.Marshal(c)
|
result, err := json.Marshal(c)
|
||||||
return &CallResult{
|
return &FrameWithPayload{
|
||||||
StatusCode: 200,
|
Frame: Frame{
|
||||||
Data: result,
|
StatusCode: 200,
|
||||||
|
},
|
||||||
|
Payload: result,
|
||||||
}, err
|
}, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,163 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type CartPacketQueue struct {
|
|
||||||
mu sync.RWMutex
|
|
||||||
expectedPackages map[CartMessage]*CartListener
|
|
||||||
}
|
|
||||||
|
|
||||||
const CurrentPacketVersion = 2
|
|
||||||
|
|
||||||
type CartListener map[CartId]Listener
|
|
||||||
|
|
||||||
func NewCartPacketQueue(connection *PersistentConnection) *CartPacketQueue {
|
|
||||||
queue := &CartPacketQueue{
|
|
||||||
expectedPackages: make(map[CartMessage]*CartListener),
|
|
||||||
}
|
|
||||||
go queue.HandleConnection(connection)
|
|
||||||
return queue
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *CartPacketQueue) RemoveListeners() {
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
for _, l := range p.expectedPackages {
|
|
||||||
for _, l := range *l {
|
|
||||||
close(l.Chan)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
p.expectedPackages = make(map[CartMessage]*CartListener)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *CartPacketQueue) HandleConnection(connection *PersistentConnection) error {
|
|
||||||
defer p.RemoveListeners()
|
|
||||||
defer connection.Close()
|
|
||||||
var packet CartPacket
|
|
||||||
reader := bufio.NewReader(connection)
|
|
||||||
for {
|
|
||||||
err := ReadCartPacket(reader, &packet)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error receiving packet: %v\n", err)
|
|
||||||
return connection.HandleConnectionError(err)
|
|
||||||
}
|
|
||||||
if packet.Version != CurrentPacketVersion {
|
|
||||||
log.Printf("Incorrect version: %v\n", packet.Version)
|
|
||||||
return connection.HandleConnectionError(fmt.Errorf("incorrect version: %d", packet.Version))
|
|
||||||
}
|
|
||||||
if packet.DataLength == 0 {
|
|
||||||
go p.HandleData(packet.MessageType, packet.Id, CallResult{
|
|
||||||
StatusCode: packet.StatusCode,
|
|
||||||
Data: []byte{},
|
|
||||||
})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data, err := GetPacketData(reader, packet.DataLength)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error receiving packet data: %v\n", err)
|
|
||||||
return connection.HandleConnectionError(err)
|
|
||||||
}
|
|
||||||
go p.HandleData(packet.MessageType, packet.Id, CallResult{
|
|
||||||
StatusCode: packet.StatusCode,
|
|
||||||
Data: data,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *CartPacketQueue) HandleData(t CartMessage, id CartId, data CallResult) {
|
|
||||||
p.getListener(t, id, func(l *Listener) {
|
|
||||||
l.Chan <- data
|
|
||||||
l.Count--
|
|
||||||
})
|
|
||||||
// p.mu.Lock()
|
|
||||||
// defer p.mu.Unlock()
|
|
||||||
// pl, ok := p.expectedPackages[t]
|
|
||||||
// if ok {
|
|
||||||
// l, ok := (*pl)[id]
|
|
||||||
// if ok {
|
|
||||||
// l.Chan <- data
|
|
||||||
// l.Count--
|
|
||||||
// if l.Count == 0 {
|
|
||||||
// close(l.Chan)
|
|
||||||
// delete(*pl, id)
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *CartPacketQueue) getListener(t CartMessage, id CartId, fn func(*Listener)) {
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
pl, ok := p.expectedPackages[t]
|
|
||||||
if ok {
|
|
||||||
l, ok := (*pl)[id]
|
|
||||||
if ok {
|
|
||||||
fn(&l)
|
|
||||||
if l.Count == 0 {
|
|
||||||
close(l.Chan)
|
|
||||||
delete(*pl, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func CallResultWithTimeout(onTimeout func() CallResult) chan CallResult {
|
|
||||||
ch := make(chan CallResult, 1)
|
|
||||||
resultCh := make(chan CallResult, 1)
|
|
||||||
select {
|
|
||||||
case ret := <-resultCh:
|
|
||||||
ch <- ret
|
|
||||||
case <-time.After(300 * time.Millisecond):
|
|
||||||
ch <- onTimeout()
|
|
||||||
}
|
|
||||||
return ch
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *CartPacketQueue) MakeChannel(messageType CartMessage, id CartId) chan CallResult {
|
|
||||||
return CallResultWithTimeout(func() CallResult {
|
|
||||||
p.getListener(messageType, id, func(l *Listener) {
|
|
||||||
l.Count--
|
|
||||||
})
|
|
||||||
return CallResult{
|
|
||||||
StatusCode: 504,
|
|
||||||
Data: []byte("timeout cart call"),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *CartPacketQueue) Expect(messageType CartMessage, id CartId) <-chan CallResult {
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
l, ok := p.expectedPackages[messageType]
|
|
||||||
if ok {
|
|
||||||
if idl, idOk := (*l)[id]; idOk {
|
|
||||||
idl.Count++
|
|
||||||
return idl.Chan
|
|
||||||
}
|
|
||||||
ch := p.MakeChannel(messageType, id)
|
|
||||||
|
|
||||||
(*l)[id] = Listener{
|
|
||||||
Chan: ch,
|
|
||||||
Count: 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
return ch
|
|
||||||
}
|
|
||||||
|
|
||||||
ch := p.MakeChannel(messageType, id)
|
|
||||||
p.expectedPackages[messageType] = &CartListener{
|
|
||||||
id: Listener{
|
|
||||||
Chan: ch,
|
|
||||||
Count: 1,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
return ch
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -27,8 +27,8 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type GrainPool interface {
|
type GrainPool interface {
|
||||||
Process(id CartId, messages ...Message) (*CallResult, error)
|
Process(id CartId, messages ...Message) (*FrameWithPayload, error)
|
||||||
Get(id CartId) (*CallResult, error)
|
Get(id CartId) (*FrameWithPayload, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Ttl struct {
|
type Ttl struct {
|
||||||
@@ -142,23 +142,29 @@ func (p *GrainLocalPool) GetGrain(id CartId) (*CartGrain, error) {
|
|||||||
return grain, err
|
return grain, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GrainLocalPool) Process(id CartId, messages ...Message) ([]byte, error) {
|
func (p *GrainLocalPool) Process(id CartId, messages ...Message) (*FrameWithPayload, error) {
|
||||||
grain, err := p.GetGrain(id)
|
grain, err := p.GetGrain(id)
|
||||||
|
var result *FrameWithPayload
|
||||||
if err == nil && grain != nil {
|
if err == nil && grain != nil {
|
||||||
for _, message := range messages {
|
for _, message := range messages {
|
||||||
_, err = grain.HandleMessage(&message, false)
|
result, err = grain.HandleMessage(&message, false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return result, err
|
||||||
}
|
}
|
||||||
return json.Marshal(grain)
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GrainLocalPool) Get(id CartId) ([]byte, error) {
|
func (p *GrainLocalPool) Get(id CartId) (*FrameWithPayload, error) {
|
||||||
grain, err := p.GetGrain(id)
|
grain, err := p.GetGrain(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return json.Marshal(grain)
|
data, err := json.Marshal(grain)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ret := MakeFrameWithPayload(0, 200, data)
|
||||||
|
return &ret, nil
|
||||||
}
|
}
|
||||||
|
|||||||
122
packet-queue.go
122
packet-queue.go
@@ -1,122 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type PacketQueue struct {
|
|
||||||
mu sync.RWMutex
|
|
||||||
expectedPackages map[PoolMessage]*Listener
|
|
||||||
}
|
|
||||||
|
|
||||||
type CallResult struct {
|
|
||||||
StatusCode uint32
|
|
||||||
Data []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
type Listener struct {
|
|
||||||
Count int
|
|
||||||
Chan chan CallResult
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPacketQueue(connection *PersistentConnection) *PacketQueue {
|
|
||||||
queue := &PacketQueue{
|
|
||||||
expectedPackages: make(map[PoolMessage]*Listener),
|
|
||||||
}
|
|
||||||
go queue.HandleConnection(connection)
|
|
||||||
return queue
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PacketQueue) RemoveListeners() {
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
for _, l := range p.expectedPackages {
|
|
||||||
close(l.Chan)
|
|
||||||
}
|
|
||||||
p.expectedPackages = make(map[PoolMessage]*Listener)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PacketQueue) HandleConnection(connection *PersistentConnection) error {
|
|
||||||
defer connection.Close()
|
|
||||||
defer p.RemoveListeners()
|
|
||||||
var packet Packet
|
|
||||||
reader := bufio.NewReader(connection)
|
|
||||||
|
|
||||||
for {
|
|
||||||
err := ReadPacket(reader, &packet)
|
|
||||||
if err != nil {
|
|
||||||
return connection.HandleConnectionError(err)
|
|
||||||
}
|
|
||||||
if packet.Version != CurrentPacketVersion {
|
|
||||||
log.Printf("Incorrect packet version: %v\n", packet.Version)
|
|
||||||
return connection.HandleConnectionError(fmt.Errorf("incorrect packet version: %d", packet.Version))
|
|
||||||
}
|
|
||||||
if packet.DataLength == 0 {
|
|
||||||
go p.HandleData(packet.MessageType, CallResult{
|
|
||||||
StatusCode: packet.StatusCode,
|
|
||||||
Data: []byte{},
|
|
||||||
})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data, err := GetPacketData(reader, packet.DataLength)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error receiving packet data: %v\n", err)
|
|
||||||
return connection.HandleConnectionError(err)
|
|
||||||
} else {
|
|
||||||
go p.HandleData(packet.MessageType, CallResult{
|
|
||||||
StatusCode: packet.StatusCode,
|
|
||||||
Data: data,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PacketQueue) HandleData(t PoolMessage, data CallResult) {
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
l, ok := p.expectedPackages[t]
|
|
||||||
if ok {
|
|
||||||
l.Chan <- data
|
|
||||||
l.Count--
|
|
||||||
if l.Count == 0 {
|
|
||||||
close(l.Chan)
|
|
||||||
delete(p.expectedPackages, t)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PacketQueue) Expect(messageType PoolMessage) <-chan CallResult {
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
l, ok := p.expectedPackages[messageType]
|
|
||||||
if ok {
|
|
||||||
l.Count++
|
|
||||||
return l.Chan
|
|
||||||
}
|
|
||||||
|
|
||||||
ch := make(chan CallResult, 1)
|
|
||||||
go func() {
|
|
||||||
time.Sleep(time.Millisecond * 300)
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
|
|
||||||
ch <- CallResult{
|
|
||||||
StatusCode: 504,
|
|
||||||
Data: []byte("timeout cart call"),
|
|
||||||
}
|
|
||||||
|
|
||||||
close(ch)
|
|
||||||
}()
|
|
||||||
p.expectedPackages[messageType] = &Listener{
|
|
||||||
Count: 1,
|
|
||||||
Chan: ch,
|
|
||||||
}
|
|
||||||
|
|
||||||
return ch
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestQueue(t *testing.T) {
|
|
||||||
localPool := NewGrainLocalPool(100, time.Minute, func(id CartId) (*CartGrain, error) {
|
|
||||||
return &CartGrain{
|
|
||||||
Id: id,
|
|
||||||
storageMessages: []Message{},
|
|
||||||
Items: []*CartItem{},
|
|
||||||
TotalPrice: 0,
|
|
||||||
}, nil
|
|
||||||
})
|
|
||||||
pool, err := NewSyncedPool(localPool, "localhost", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Error creating pool: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = pool.AddRemote("localhost")
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Error adding remote: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
130
packet.go
130
packet.go
@@ -1,85 +1,77 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"io"
|
|
||||||
)
|
|
||||||
|
|
||||||
type CartMessage uint32
|
|
||||||
type PackageVersion uint32
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RemoteGetState = CartMessage(0x01)
|
RemoteGetState = FrameType(0x01)
|
||||||
RemoteHandleMutation = CartMessage(0x02)
|
RemoteHandleMutation = FrameType(0x02)
|
||||||
ResponseBody = CartMessage(0x03)
|
ResponseBody = FrameType(0x03)
|
||||||
RemoteGetStateReply = CartMessage(0x04)
|
RemoteGetStateReply = FrameType(0x04)
|
||||||
RemoteHandleMutationReply = CartMessage(0x05)
|
RemoteHandleMutationReply = FrameType(0x05)
|
||||||
)
|
)
|
||||||
|
|
||||||
type CartPacket struct {
|
// type CartPacket struct {
|
||||||
Version PackageVersion
|
// Version PackageVersion
|
||||||
MessageType CartMessage
|
// MessageType CartMessage
|
||||||
DataLength uint32
|
// DataLength uint32
|
||||||
StatusCode uint32
|
// StatusCode uint32
|
||||||
Id CartId
|
// Id CartId
|
||||||
}
|
// }
|
||||||
|
|
||||||
type Packet struct {
|
// type Packet struct {
|
||||||
Version PackageVersion
|
// Version PackageVersion
|
||||||
MessageType PoolMessage
|
// MessageType PoolMessage
|
||||||
DataLength uint32
|
// DataLength uint32
|
||||||
StatusCode uint32
|
// StatusCode uint32
|
||||||
}
|
// }
|
||||||
|
|
||||||
var headerData = make([]byte, 4)
|
// var headerData = make([]byte, 4)
|
||||||
|
|
||||||
func matchHeader(conn io.Reader) error {
|
// func matchHeader(conn io.Reader) error {
|
||||||
|
|
||||||
pos := 0
|
// pos := 0
|
||||||
for pos < 4 {
|
// for pos < 4 {
|
||||||
|
|
||||||
l, err := conn.Read(headerData)
|
// l, err := conn.Read(headerData)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
return err
|
// return err
|
||||||
}
|
// }
|
||||||
for i := 0; i < l; i++ {
|
// for i := 0; i < l; i++ {
|
||||||
if headerData[i] == header[pos] {
|
// if headerData[i] == header[pos] {
|
||||||
pos++
|
// pos++
|
||||||
if pos == 4 {
|
// if pos == 4 {
|
||||||
return nil
|
// return nil
|
||||||
}
|
// }
|
||||||
} else {
|
// } else {
|
||||||
pos = 0
|
// pos = 0
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
return nil
|
// return nil
|
||||||
}
|
// }
|
||||||
|
|
||||||
func ReadPacket(conn io.Reader, packet *Packet) error {
|
// func ReadPacket(conn io.Reader, packet *Packet) error {
|
||||||
err := matchHeader(conn)
|
// err := matchHeader(conn)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
return err
|
// return err
|
||||||
}
|
// }
|
||||||
return binary.Read(conn, binary.LittleEndian, packet)
|
// return binary.Read(conn, binary.LittleEndian, packet)
|
||||||
}
|
// }
|
||||||
|
|
||||||
func ReadCartPacket(conn io.Reader, packet *CartPacket) error {
|
// func ReadCartPacket(conn io.Reader, packet *CartPacket) error {
|
||||||
err := matchHeader(conn)
|
// err := matchHeader(conn)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
return err
|
// return err
|
||||||
}
|
// }
|
||||||
return binary.Read(conn, binary.LittleEndian, packet)
|
// return binary.Read(conn, binary.LittleEndian, packet)
|
||||||
}
|
// }
|
||||||
|
|
||||||
func GetPacketData(conn io.Reader, len uint32) ([]byte, error) {
|
// func GetPacketData(conn io.Reader, len uint32) ([]byte, error) {
|
||||||
if len == 0 {
|
// if len == 0 {
|
||||||
return []byte{}, nil
|
// return []byte{}, nil
|
||||||
}
|
// }
|
||||||
data := make([]byte, len)
|
// data := make([]byte, len)
|
||||||
_, err := conn.Read(data)
|
// _, err := conn.Read(data)
|
||||||
return data, err
|
// return data, err
|
||||||
}
|
// }
|
||||||
|
|
||||||
// func ReceivePacket(conn io.Reader) (uint32, []byte, error) {
|
// func ReceivePacket(conn io.Reader) (uint32, []byte, error) {
|
||||||
// var packet Packet
|
// var packet Packet
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ func ErrorHandler(fn func(w http.ResponseWriter, r *http.Request) error) func(w
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PoolServer) WriteResult(w http.ResponseWriter, result *CallResult) error {
|
func (s *PoolServer) WriteResult(w http.ResponseWriter, result *FrameWithPayload) error {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.Header().Set("X-Pod-Name", s.pod_name)
|
w.Header().Set("X-Pod-Name", s.pod_name)
|
||||||
if result.StatusCode != 200 {
|
if result.StatusCode != 200 {
|
||||||
@@ -65,11 +65,11 @@ func (s *PoolServer) WriteResult(w http.ResponseWriter, result *CallResult) erro
|
|||||||
} else {
|
} else {
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
w.Write([]byte(result.Data))
|
w.Write([]byte(result.Payload))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
_, err := w.Write(result.Data)
|
_, err := w.Write(result.Payload)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -46,8 +46,8 @@ func (p *RemoteGrainPool) Delete(id CartId) {
|
|||||||
p.mu.Unlock()
|
p.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *RemoteGrainPool) Process(id CartId, messages ...Message) (*CallResult, error) {
|
func (p *RemoteGrainPool) Process(id CartId, messages ...Message) (*FrameWithPayload, error) {
|
||||||
var result *CallResult
|
var result *FrameWithPayload
|
||||||
grain, err := p.findOrCreateGrain(id)
|
grain, err := p.findOrCreateGrain(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -58,7 +58,7 @@ func (p *RemoteGrainPool) Process(id CartId, messages ...Message) (*CallResult,
|
|||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *RemoteGrainPool) Get(id CartId) (*CallResult, error) {
|
func (p *RemoteGrainPool) Get(id CartId) (*FrameWithPayload, error) {
|
||||||
grain, err := p.findOrCreateGrain(id)
|
grain, err := p.findOrCreateGrain(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package main
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||||
@@ -13,6 +12,25 @@ func (id CartId) String() string {
|
|||||||
return strings.Trim(string(id[:]), "\x00")
|
return strings.Trim(string(id[:]), "\x00")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CartIdPayload struct {
|
||||||
|
Id CartId
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func MakeCartInnerFrame(id CartId, payload []byte) []byte {
|
||||||
|
return append(id[:], payload...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetCartFrame(data []byte) (*CartIdPayload, error) {
|
||||||
|
if len(data) < 16 {
|
||||||
|
return nil, fmt.Errorf("data too short")
|
||||||
|
}
|
||||||
|
return &CartIdPayload{
|
||||||
|
Id: CartId(data[:16]),
|
||||||
|
Data: data[16:],
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func ToCartId(id string) CartId {
|
func ToCartId(id string) CartId {
|
||||||
var result [16]byte
|
var result [16]byte
|
||||||
copy(result[:], []byte(id))
|
copy(result[:], []byte(id))
|
||||||
@@ -20,21 +38,16 @@ func ToCartId(id string) CartId {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type RemoteGrain struct {
|
type RemoteGrain struct {
|
||||||
*CartClient
|
*Connection
|
||||||
Id CartId
|
Id CartId
|
||||||
Host string
|
Host string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRemoteGrain(id CartId, host string) (*RemoteGrain, error) {
|
func NewRemoteGrain(id CartId, host string) (*RemoteGrain, error) {
|
||||||
client, err := CartDial(fmt.Sprintf("%s:1337", host))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &RemoteGrain{
|
return &RemoteGrain{
|
||||||
Id: id,
|
Id: id,
|
||||||
Host: host,
|
Host: host,
|
||||||
CartClient: client,
|
Connection: NewConnection(fmt.Sprintf("%s:1337", host)),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -49,47 +62,35 @@ var (
|
|||||||
})
|
})
|
||||||
)
|
)
|
||||||
|
|
||||||
var start time.Time
|
// var start time.Time
|
||||||
|
|
||||||
func MeasureLatency(fn func() (*CallResult, error)) (*CallResult, error) {
|
// func MeasureLatency(fn func() (*CallResult, error)) (*CallResult, error) {
|
||||||
start = time.Now()
|
// start = time.Now()
|
||||||
data, err := fn()
|
// data, err := fn()
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
return data, err
|
// return data, err
|
||||||
}
|
// }
|
||||||
elapsed := time.Since(start).Milliseconds()
|
// elapsed := time.Since(start).Milliseconds()
|
||||||
go func() {
|
// go func() {
|
||||||
remoteCartLatency.Add(float64(elapsed))
|
// remoteCartLatency.Add(float64(elapsed))
|
||||||
remoteCartCallsTotal.Inc()
|
// remoteCartCallsTotal.Inc()
|
||||||
}()
|
// }()
|
||||||
return data, nil
|
// return data, nil
|
||||||
}
|
// }
|
||||||
|
|
||||||
func (g *RemoteGrain) HandleMessage(message *Message, isReplay bool) (*CallResult, error) {
|
func (g *RemoteGrain) HandleMessage(message *Message, isReplay bool) (*FrameWithPayload, error) {
|
||||||
|
|
||||||
data, err := GetData(message.Write)
|
data, err := GetData(message.Write)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
reply, err := MeasureLatency(func() (*CallResult, error) {
|
return g.Call(RemoteHandleMutation, MakeCartInnerFrame(g.Id, data))
|
||||||
return g.Call(RemoteHandleMutation, g.Id, RemoteHandleMutationReply, data)
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return reply, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *RemoteGrain) Close() {
|
|
||||||
g.CartClient.PersistentConnection.Close()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *RemoteGrain) GetId() CartId {
|
func (g *RemoteGrain) GetId() CartId {
|
||||||
return g.Id
|
return g.Id
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *RemoteGrain) GetCurrentState() (*CallResult, error) {
|
func (g *RemoteGrain) GetCurrentState() (*FrameWithPayload, error) {
|
||||||
return MeasureLatency(func() (*CallResult, error) { return g.Call(RemoteGetState, g.Id, RemoteGetStateReply, []byte{}) })
|
return g.Call(RemoteGetState, MakeCartInnerFrame(g.Id, nil))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,13 +7,13 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type RemoteHost struct {
|
type RemoteHost struct {
|
||||||
*Client
|
*Connection
|
||||||
Host string
|
Host string
|
||||||
MissedPings int
|
MissedPings int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *RemoteHost) IsHealthy() bool {
|
func (h *RemoteHost) IsHealthy() bool {
|
||||||
return !h.PersistentConnection.Dead && h.MissedPings < 3
|
return h.MissedPings < 3
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *RemoteHost) Initialize(p *SyncedPool) {
|
func (h *RemoteHost) Initialize(p *SyncedPool) {
|
||||||
@@ -38,15 +38,11 @@ func (h *RemoteHost) Initialize(p *SyncedPool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *RemoteHost) Ping() error {
|
func (h *RemoteHost) Ping() error {
|
||||||
_, err := h.Call(Ping, Pong, []byte{})
|
result, err := h.Call(Ping, nil)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil || result.StatusCode != 200 || result.Type != Pong {
|
||||||
h.MissedPings++
|
h.MissedPings++
|
||||||
log.Printf("Error pinging remote %s, missed pings: %d", h.Host, h.MissedPings)
|
log.Printf("Error pinging remote %s, missed pings: %d", h.Host, h.MissedPings)
|
||||||
if !h.IsHealthy() {
|
|
||||||
h.Close()
|
|
||||||
return fmt.Errorf("remote %s is dead", h.Host)
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
h.MissedPings = 0
|
h.MissedPings = 0
|
||||||
}
|
}
|
||||||
@@ -54,28 +50,28 @@ func (h *RemoteHost) Ping() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *RemoteHost) Negotiate(knownHosts []string) ([]string, error) {
|
func (h *RemoteHost) Negotiate(knownHosts []string) ([]string, error) {
|
||||||
reply, err := h.Call(RemoteNegotiate, RemoteNegotiateResponse, []byte(strings.Join(knownHosts, ";")))
|
reply, err := h.Call(RemoteNegotiate, []byte(strings.Join(knownHosts, ";")))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if reply.StatusCode != 200 {
|
if reply.StatusCode != 200 {
|
||||||
return nil, fmt.Errorf("remote returned error on negotiate: %s", string(reply.Data))
|
return nil, fmt.Errorf("remote returned error on negotiate: %s", string(reply.Payload))
|
||||||
}
|
}
|
||||||
|
|
||||||
return strings.Split(string(reply.Data), ";"), nil
|
return strings.Split(string(reply.Payload), ";"), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *RemoteHost) GetCartMappings() ([]CartId, error) {
|
func (g *RemoteHost) GetCartMappings() ([]CartId, error) {
|
||||||
reply, err := g.Call(GetCartIds, CartIdsResponse, []byte{})
|
reply, err := g.Call(GetCartIds, []byte{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if reply.StatusCode != 200 {
|
if reply.StatusCode != 200 || reply.Type != CartIdsResponse {
|
||||||
log.Printf("Remote returned error on get cart mappings: %s", string(reply.Data))
|
log.Printf("Remote returned error on get cart mappings: %s", string(reply.Payload))
|
||||||
return nil, fmt.Errorf("remote returned error: %s", string(reply.Data))
|
return nil, fmt.Errorf("remote returned incorrect data")
|
||||||
}
|
}
|
||||||
parts := strings.Split(string(reply.Data), ";")
|
parts := strings.Split(string(reply.Payload), ";")
|
||||||
ids := make([]CartId, 0, len(parts))
|
ids := make([]CartId, 0, len(parts))
|
||||||
for _, p := range parts {
|
for _, p := range parts {
|
||||||
ids = append(ids, ToCartId(p))
|
ids = append(ids, ToCartId(p))
|
||||||
@@ -84,14 +80,11 @@ func (g *RemoteHost) GetCartMappings() ([]CartId, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *RemoteHost) ConfirmChange(id CartId, host string) error {
|
func (r *RemoteHost) ConfirmChange(id CartId, host string) error {
|
||||||
reply, err := r.Call(RemoteGrainChanged, AckChange, []byte(fmt.Sprintf("%s;%s", id, host)))
|
reply, err := r.Call(RemoteGrainChanged, []byte(fmt.Sprintf("%s;%s", id, host)))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil || reply.StatusCode != 200 || reply.Type != AckChange {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if string(reply.Data) != "ok" {
|
|
||||||
return fmt.Errorf("remote grain change failed %s", string(reply.Data))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type GrainHandler struct {
|
type GrainHandler struct {
|
||||||
*CartServer
|
*GenericListener
|
||||||
pool *GrainLocalPool
|
pool *GrainLocalPool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -20,13 +20,14 @@ func (h *GrainHandler) GetState(id CartId, reply *Grain) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewGrainHandler(pool *GrainLocalPool, listen string) (*GrainHandler, error) {
|
func NewGrainHandler(pool *GrainLocalPool, listen string) (*GrainHandler, error) {
|
||||||
server, err := CartListen(listen)
|
conn := NewConnection(listen)
|
||||||
|
server, err := conn.Listen()
|
||||||
handler := &GrainHandler{
|
handler := &GrainHandler{
|
||||||
CartServer: server,
|
GenericListener: server,
|
||||||
pool: pool,
|
pool: pool,
|
||||||
}
|
}
|
||||||
server.HandleCall(RemoteHandleMutation, handler.RemoteHandleMessageHandler)
|
server.AddHandler(RemoteHandleMutation, handler.RemoteHandleMessageHandler)
|
||||||
server.HandleCall(RemoteGetState, handler.RemoteGetStateHandler)
|
server.AddHandler(RemoteGetState, handler.RemoteGetStateHandler)
|
||||||
return handler, err
|
return handler, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -34,29 +35,36 @@ func (h *GrainHandler) IsHealthy() bool {
|
|||||||
return len(h.pool.grains) < h.pool.PoolSize
|
return len(h.pool.grains) < h.pool.PoolSize
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *GrainHandler) RemoteHandleMessageHandler(id CartId, data []byte) (CartMessage, []byte, error) {
|
func (h *GrainHandler) RemoteHandleMessageHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
|
||||||
|
cartData, err := GetCartFrame(data.Payload)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
var msg Message
|
var msg Message
|
||||||
err := ReadMessage(bytes.NewReader(data), &msg)
|
err = ReadMessage(bytes.NewReader(cartData.Data), &msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("Error reading message:", err)
|
fmt.Println("Error reading message:", err)
|
||||||
return RemoteHandleMutationReply, nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
replyData, err := h.pool.Process(id, msg)
|
replyData, err := h.pool.Process(cartData.Id, msg)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("Error handling message:", err)
|
fmt.Println("Error handling message:", err)
|
||||||
}
|
}
|
||||||
if err != nil {
|
resultChan <- *replyData
|
||||||
return RemoteHandleMutationReply, nil, err
|
return nil
|
||||||
}
|
|
||||||
return RemoteHandleMutationReply, replyData, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *GrainHandler) RemoteGetStateHandler(id CartId, data []byte) (CartMessage, []byte, error) {
|
func (h *GrainHandler) RemoteGetStateHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
|
||||||
reply, err := h.pool.Get(id)
|
cartData, err := GetCartFrame(data.Payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return RemoteGetStateReply, nil, err
|
return err
|
||||||
}
|
}
|
||||||
return RemoteGetStateReply, reply, nil
|
reply, err := h.pool.Get(cartData.Id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
resultChan <- *reply
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
118
synced-pool.go
118
synced-pool.go
@@ -22,7 +22,7 @@ type HealthHandler interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type SyncedPool struct {
|
type SyncedPool struct {
|
||||||
*Server
|
Server *GenericListener
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
Hostname string
|
Hostname string
|
||||||
local *GrainLocalPool
|
local *GrainLocalPool
|
||||||
@@ -61,11 +61,16 @@ var (
|
|||||||
})
|
})
|
||||||
)
|
)
|
||||||
|
|
||||||
func (p *SyncedPool) PongHandler(data []byte) (PoolMessage, []byte, error) {
|
var (
|
||||||
return Pong, data, nil
|
PongResponse = MakeFrameWithPayload(Pong, 200, nil)
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p *SyncedPool) PongHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
|
||||||
|
resultChan <- PongResponse
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *SyncedPool) GetCartIdHandler(data []byte) (PoolMessage, []byte, error) {
|
func (p *SyncedPool) GetCartIdHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
|
||||||
ids := make([]string, 0, len(p.local.grains))
|
ids := make([]string, 0, len(p.local.grains))
|
||||||
for id := range p.local.grains {
|
for id := range p.local.grains {
|
||||||
if p.local.grains[id] == nil {
|
if p.local.grains[id] == nil {
|
||||||
@@ -78,45 +83,45 @@ func (p *SyncedPool) GetCartIdHandler(data []byte) (PoolMessage, []byte, error)
|
|||||||
ids = append(ids, s)
|
ids = append(ids, s)
|
||||||
}
|
}
|
||||||
log.Printf("Returning %d cart ids\n", len(ids))
|
log.Printf("Returning %d cart ids\n", len(ids))
|
||||||
return CartIdsResponse, []byte(strings.Join(ids, ";")), nil
|
resultChan <- MakeFrameWithPayload(CartIdsResponse, 200, []byte(strings.Join(ids, ";")))
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *SyncedPool) NegotiateHandler(data []byte) (PoolMessage, []byte, error) {
|
func (p *SyncedPool) NegotiateHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
|
||||||
negotiationCount.Inc()
|
negotiationCount.Inc()
|
||||||
log.Printf("Handling negotiation\n")
|
log.Printf("Handling negotiation\n")
|
||||||
for _, host := range p.ExcludeKnown(strings.Split(string(data), ";")) {
|
for _, host := range p.ExcludeKnown(strings.Split(string(data.Payload), ";")) {
|
||||||
if host == "" {
|
if host == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
go p.AddRemote(host)
|
go p.AddRemote(host)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
resultChan <- MakeFrameWithPayload(RemoteNegotiateResponse, 200, []byte("ok"))
|
||||||
return RemoteNegotiateResponse, []byte("ok"), nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *SyncedPool) GrainOwnerChangeHandler(data []byte) (PoolMessage, []byte, error) {
|
func (p *SyncedPool) GrainOwnerChangeHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
|
||||||
grainSyncCount.Inc()
|
grainSyncCount.Inc()
|
||||||
|
|
||||||
idAndHostParts := strings.Split(string(data), ";")
|
idAndHostParts := strings.Split(string(data.Payload), ";")
|
||||||
if len(idAndHostParts) != 2 {
|
if len(idAndHostParts) != 2 {
|
||||||
log.Printf("Invalid remote grain change message\n")
|
log.Printf("Invalid remote grain change message\n")
|
||||||
return AckChange, []byte("incorrect"), fmt.Errorf("invalid remote grain change message")
|
resultChan <- MakeFrameWithPayload(AckError, 400, []byte("invalid"))
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
id := ToCartId(idAndHostParts[0])
|
id := ToCartId(idAndHostParts[0])
|
||||||
host := idAndHostParts[1]
|
host := idAndHostParts[1]
|
||||||
log.Printf("Handling remote grain owner change to %s for id %s\n", host, id)
|
log.Printf("Handling remote grain owner change to %s for id %s\n", host, id)
|
||||||
for _, r := range p.remotes {
|
for _, r := range p.remotes {
|
||||||
if r.Host == host && r.IsHealthy() {
|
if r.Host == host && r.IsHealthy() {
|
||||||
// log.Printf("Remote grain %s changed to %s\n", id, host)
|
|
||||||
|
|
||||||
go p.SpawnRemoteGrain(id, host)
|
go p.SpawnRemoteGrain(id, host)
|
||||||
|
break
|
||||||
return AckChange, []byte("ok"), nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
go p.AddRemote(host)
|
go p.AddRemote(host)
|
||||||
return AckChange, []byte("ok"), nil
|
resultChan <- MakeFrameWithPayload(AckChange, 200, []byte("ok"))
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *SyncedPool) RemoveRemoteGrain(id CartId) {
|
func (p *SyncedPool) RemoveRemoteGrain(id CartId) {
|
||||||
@@ -142,12 +147,12 @@ func (p *SyncedPool) SpawnRemoteGrain(id CartId, host string) {
|
|||||||
log.Printf("Error creating remote grain %v\n", err)
|
log.Printf("Error creating remote grain %v\n", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
go func() {
|
// go func() {
|
||||||
<-remote.PersistentConnection.Died
|
// <-remote.Died
|
||||||
p.RemoveRemoteGrain(id)
|
// p.RemoveRemoteGrain(id)
|
||||||
p.HandleHostError(host)
|
// p.HandleHostError(host)
|
||||||
log.Printf("Remote grain %s died, host: %s\n", id.String(), host)
|
// log.Printf("Remote grain %s died, host: %s\n", id.String(), host)
|
||||||
}()
|
// }()
|
||||||
|
|
||||||
p.mu.Lock()
|
p.mu.Lock()
|
||||||
p.remoteIndex[id] = remote
|
p.remoteIndex[id] = remote
|
||||||
@@ -159,8 +164,6 @@ func (p *SyncedPool) HandleHostError(host string) {
|
|||||||
if r.Host == host {
|
if r.Host == host {
|
||||||
if !r.IsHealthy() {
|
if !r.IsHealthy() {
|
||||||
p.RemoveHost(r)
|
p.RemoveHost(r)
|
||||||
} else {
|
|
||||||
r.ErrorCount++
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -169,8 +172,8 @@ func (p *SyncedPool) HandleHostError(host string) {
|
|||||||
|
|
||||||
func NewSyncedPool(local *GrainLocalPool, hostname string, discovery Discovery) (*SyncedPool, error) {
|
func NewSyncedPool(local *GrainLocalPool, hostname string, discovery Discovery) (*SyncedPool, error) {
|
||||||
listen := fmt.Sprintf("%s:1338", hostname)
|
listen := fmt.Sprintf("%s:1338", hostname)
|
||||||
|
conn := NewConnection(listen)
|
||||||
server, err := Listen(listen)
|
server, err := conn.Listen()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -186,10 +189,10 @@ func NewSyncedPool(local *GrainLocalPool, hostname string, discovery Discovery)
|
|||||||
remoteIndex: make(map[CartId]*RemoteGrain),
|
remoteIndex: make(map[CartId]*RemoteGrain),
|
||||||
}
|
}
|
||||||
|
|
||||||
server.HandleCall(Ping, pool.PongHandler)
|
server.AddHandler(Ping, pool.PongHandler)
|
||||||
server.HandleCall(GetCartIds, pool.GetCartIdHandler)
|
server.AddHandler(GetCartIds, pool.GetCartIdHandler)
|
||||||
server.HandleCall(RemoteNegotiate, pool.NegotiateHandler)
|
server.AddHandler(RemoteNegotiate, pool.NegotiateHandler)
|
||||||
server.HandleCall(RemoteGrainChanged, pool.GrainOwnerChangeHandler)
|
server.AddHandler(RemoteGrainChanged, pool.GrainOwnerChangeHandler)
|
||||||
|
|
||||||
if discovery != nil {
|
if discovery != nil {
|
||||||
go func() {
|
go func() {
|
||||||
@@ -259,18 +262,11 @@ func (p *SyncedPool) ExcludeKnown(hosts []string) []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *SyncedPool) RemoveHost(host *RemoteHost) {
|
func (p *SyncedPool) RemoveHost(host *RemoteHost) {
|
||||||
if p.remotes[host.Host] == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
p.mu.Lock()
|
p.mu.Lock()
|
||||||
defer p.mu.Unlock()
|
|
||||||
|
|
||||||
h := p.remotes[host.Host]
|
|
||||||
if h != nil {
|
|
||||||
h.Close()
|
|
||||||
}
|
|
||||||
delete(p.remotes, host.Host)
|
delete(p.remotes, host.Host)
|
||||||
|
p.mu.Unlock()
|
||||||
|
p.RemoveHostMappedCarts(host)
|
||||||
connectedRemotes.Set(float64(len(p.remotes)))
|
connectedRemotes.Set(float64(len(p.remotes)))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -279,24 +275,21 @@ func (p *SyncedPool) RemoveHostMappedCarts(host *RemoteHost) {
|
|||||||
defer p.mu.Unlock()
|
defer p.mu.Unlock()
|
||||||
for id, r := range p.remoteIndex {
|
for id, r := range p.remoteIndex {
|
||||||
if r.Host == host.Host {
|
if r.Host == host.Host {
|
||||||
p.remoteIndex[id].Close()
|
|
||||||
delete(p.remoteIndex, id)
|
delete(p.remoteIndex, id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type PoolMessage uint32
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RemoteNegotiate = PoolMessage(3)
|
RemoteNegotiate = FrameType(3)
|
||||||
RemoteGrainChanged = PoolMessage(4)
|
RemoteGrainChanged = FrameType(4)
|
||||||
AckChange = PoolMessage(5)
|
AckChange = FrameType(5)
|
||||||
//AckError = PoolMessage(6)
|
AckError = FrameType(6)
|
||||||
Ping = PoolMessage(7)
|
Ping = FrameType(7)
|
||||||
Pong = PoolMessage(8)
|
Pong = FrameType(8)
|
||||||
GetCartIds = PoolMessage(9)
|
GetCartIds = FrameType(9)
|
||||||
CartIdsResponse = PoolMessage(10)
|
CartIdsResponse = FrameType(10)
|
||||||
RemoteNegotiateResponse = PoolMessage(11)
|
RemoteNegotiateResponse = FrameType(11)
|
||||||
)
|
)
|
||||||
|
|
||||||
func (p *SyncedPool) Negotiate() {
|
func (p *SyncedPool) Negotiate() {
|
||||||
@@ -377,25 +370,22 @@ func (p *SyncedPool) AddRemote(host string) error {
|
|||||||
if host == "" || p.IsKnown(host) || hasHost {
|
if host == "" || p.IsKnown(host) || hasHost {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
client, err := Dial(fmt.Sprintf("%s:1338", host))
|
client := NewConnection(fmt.Sprintf("%s:1338", host))
|
||||||
if err != nil {
|
response, err := client.Call(Ping, nil)
|
||||||
|
if err != nil || response.StatusCode != 200 || response.Type != Pong {
|
||||||
log.Printf("Error connecting to remote %s: %v\n", host, err)
|
log.Printf("Error connecting to remote %s: %v\n", host, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
remote := RemoteHost{
|
remote := RemoteHost{
|
||||||
Client: client,
|
Connection: client,
|
||||||
MissedPings: 0,
|
MissedPings: 0,
|
||||||
Host: host,
|
Host: host,
|
||||||
}
|
}
|
||||||
p.mu.Lock()
|
p.mu.Lock()
|
||||||
p.remotes[host] = &remote
|
p.remotes[host] = &remote
|
||||||
p.mu.Unlock()
|
p.mu.Unlock()
|
||||||
go func() {
|
|
||||||
<-remote.PersistentConnection.Died
|
|
||||||
log.Printf("Removing host, remote died %s", host)
|
|
||||||
p.RemoveHost(&remote)
|
|
||||||
}()
|
|
||||||
go func() {
|
go func() {
|
||||||
for range time.Tick(time.Second * 3) {
|
for range time.Tick(time.Second * 3) {
|
||||||
|
|
||||||
@@ -450,9 +440,9 @@ func (p *SyncedPool) getGrain(id CartId) (Grain, error) {
|
|||||||
return localGrain, nil
|
return localGrain, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *SyncedPool) Process(id CartId, messages ...Message) (*CallResult, error) {
|
func (p *SyncedPool) Process(id CartId, messages ...Message) (*FrameWithPayload, error) {
|
||||||
pool, err := p.getGrain(id)
|
pool, err := p.getGrain(id)
|
||||||
var res *CallResult
|
var res *FrameWithPayload
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -465,7 +455,7 @@ func (p *SyncedPool) Process(id CartId, messages ...Message) (*CallResult, error
|
|||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *SyncedPool) Get(id CartId) (*CallResult, error) {
|
func (p *SyncedPool) Get(id CartId) (*FrameWithPayload, error) {
|
||||||
grain, err := p.getGrain(id)
|
grain, err := p.getGrain(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -1,96 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"log"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
type CartClient struct {
|
|
||||||
*CartTCPClient
|
|
||||||
}
|
|
||||||
|
|
||||||
func CartDial(address string) (*CartClient, error) {
|
|
||||||
|
|
||||||
mux, err := NewCartTCPClient(address)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
client := &CartClient{
|
|
||||||
CartTCPClient: mux,
|
|
||||||
}
|
|
||||||
return client, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) Close() {
|
|
||||||
log.Printf("Closing connection to %s\n", c.PersistentConnection.address)
|
|
||||||
c.PersistentConnection.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
type CartTCPClient struct {
|
|
||||||
PersistentConnection *PersistentConnection
|
|
||||||
sendMux sync.Mutex
|
|
||||||
ErrorCount int
|
|
||||||
address string
|
|
||||||
*CartPacketQueue
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewCartTCPClient(address string) (*CartTCPClient, error) {
|
|
||||||
connection, err := NewPersistentConnection(address)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &CartTCPClient{
|
|
||||||
ErrorCount: 0,
|
|
||||||
PersistentConnection: connection,
|
|
||||||
address: address,
|
|
||||||
CartPacketQueue: NewCartPacketQueue(connection),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *CartTCPClient) SendPacket(messageType CartMessage, id CartId, data []byte) error {
|
|
||||||
m.sendMux.Lock()
|
|
||||||
defer m.sendMux.Unlock()
|
|
||||||
m.PersistentConnection.Conn.Write(header[:])
|
|
||||||
err := binary.Write(m.PersistentConnection, binary.LittleEndian, CartPacket{
|
|
||||||
Version: CurrentPacketVersion,
|
|
||||||
MessageType: messageType,
|
|
||||||
DataLength: uint32(len(data)),
|
|
||||||
Id: id,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return m.PersistentConnection.HandleConnectionError(err)
|
|
||||||
}
|
|
||||||
_, err = m.PersistentConnection.Write(data)
|
|
||||||
return m.PersistentConnection.HandleConnectionError(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *CartTCPClient) call(messageType CartMessage, id CartId, responseType CartMessage, data []byte) (*CallResult, error) {
|
|
||||||
packetChan := m.Expect(responseType, id)
|
|
||||||
err := m.SendPacket(messageType, id, data)
|
|
||||||
if err != nil {
|
|
||||||
return nil, m.PersistentConnection.HandleConnectionError(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ret := <-packetChan
|
|
||||||
return &ret, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isRetirableError(err error) bool {
|
|
||||||
log.Printf("is retryable error: %v", err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *CartTCPClient) Call(messageType CartMessage, id CartId, responseType CartMessage, data []byte) (*CallResult, error) {
|
|
||||||
retries := 0
|
|
||||||
result, err := m.call(messageType, id, responseType, data)
|
|
||||||
for err != nil && retries < 3 {
|
|
||||||
if !isRetirableError(err) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
retries++
|
|
||||||
log.Printf("Retrying call to %d\n", messageType)
|
|
||||||
result, err = m.call(messageType, id, responseType, data)
|
|
||||||
}
|
|
||||||
return result, err
|
|
||||||
}
|
|
||||||
@@ -1,160 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"encoding/binary"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
type CartServer struct {
|
|
||||||
*TCPCartServerMux
|
|
||||||
}
|
|
||||||
|
|
||||||
func CartListen(address string) (*CartServer, error) {
|
|
||||||
listener, err := net.Listen("tcp", address)
|
|
||||||
server := &CartServer{
|
|
||||||
NewCartTCPServerMux(),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
conn, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error accepting connection: %v\n", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
go server.HandleConnection(conn)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return server, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type TCPCartServerMux struct {
|
|
||||||
mu sync.RWMutex
|
|
||||||
sendMux sync.Mutex
|
|
||||||
listeners map[CartMessage]func(CartId, []byte) error
|
|
||||||
functions map[CartMessage]func(CartId, []byte) (CartMessage, []byte, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewCartTCPServerMux() *TCPCartServerMux {
|
|
||||||
m := &TCPCartServerMux{
|
|
||||||
mu: sync.RWMutex{},
|
|
||||||
listeners: make(map[CartMessage]func(CartId, []byte) error),
|
|
||||||
functions: make(map[CartMessage]func(CartId, []byte) (CartMessage, []byte, error)),
|
|
||||||
}
|
|
||||||
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *TCPCartServerMux) handleListener(messageType CartMessage, id CartId, data []byte) (bool, error) {
|
|
||||||
m.mu.RLock()
|
|
||||||
handler, ok := m.listeners[messageType]
|
|
||||||
m.mu.RUnlock()
|
|
||||||
if ok {
|
|
||||||
err := handler(id, data)
|
|
||||||
if err != nil {
|
|
||||||
return true, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *TCPCartServerMux) handleFunction(connection net.Conn, messageType CartMessage, id CartId, data []byte) (bool, error) {
|
|
||||||
m.mu.RLock()
|
|
||||||
fn, ok := m.functions[messageType]
|
|
||||||
m.mu.RUnlock()
|
|
||||||
m.sendMux.Lock()
|
|
||||||
defer m.sendMux.Unlock()
|
|
||||||
if ok {
|
|
||||||
responseType, responseData, err := fn(id, data)
|
|
||||||
connection.Write(header[:])
|
|
||||||
if err != nil {
|
|
||||||
errData := []byte(err.Error())
|
|
||||||
err = binary.Write(connection, binary.LittleEndian, CartPacket{
|
|
||||||
Version: CurrentPacketVersion,
|
|
||||||
MessageType: responseType,
|
|
||||||
DataLength: uint32(len(errData)),
|
|
||||||
StatusCode: 500,
|
|
||||||
Id: id,
|
|
||||||
})
|
|
||||||
_, err = connection.Write(errData)
|
|
||||||
return true, err
|
|
||||||
}
|
|
||||||
err = binary.Write(connection, binary.LittleEndian, CartPacket{
|
|
||||||
Version: CurrentPacketVersion,
|
|
||||||
MessageType: responseType,
|
|
||||||
DataLength: uint32(len(responseData)),
|
|
||||||
StatusCode: 200,
|
|
||||||
Id: id,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return true, err
|
|
||||||
}
|
|
||||||
packetsSent.Inc()
|
|
||||||
_, err = connection.Write(responseData)
|
|
||||||
return true, err
|
|
||||||
} else {
|
|
||||||
log.Printf("No cart handler for type: %d\n", messageType)
|
|
||||||
}
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *TCPCartServerMux) HandleConnection(connection net.Conn) error {
|
|
||||||
var packet CartPacket
|
|
||||||
var err error
|
|
||||||
defer connection.Close()
|
|
||||||
reader := bufio.NewReader(connection)
|
|
||||||
for {
|
|
||||||
err = ReadCartPacket(reader, &packet)
|
|
||||||
if err != nil {
|
|
||||||
if err == io.EOF {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
log.Printf("Error receiving packet: %v\n", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if packet.Version != CurrentPacketVersion {
|
|
||||||
log.Printf("Incorrect packet version: %d\n", packet.Version)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data, err := GetPacketData(reader, packet.DataLength)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error getting packet data: %v\n", err)
|
|
||||||
}
|
|
||||||
go m.HandleData(connection, packet.MessageType, packet.Id, data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *TCPCartServerMux) HandleData(connection net.Conn, t CartMessage, id CartId, data []byte) {
|
|
||||||
status, err := m.handleListener(t, id, data)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error handling listener: %v\n", err)
|
|
||||||
}
|
|
||||||
if !status {
|
|
||||||
status, err = m.handleFunction(connection, t, id, data)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error handling function: %v\n", err)
|
|
||||||
}
|
|
||||||
if !status {
|
|
||||||
log.Printf("Unknown message type: %d\n", t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *TCPCartServerMux) ListenFor(messageType CartMessage, handler func(CartId, []byte) error) {
|
|
||||||
m.mu.Lock()
|
|
||||||
m.listeners[messageType] = handler
|
|
||||||
m.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *TCPCartServerMux) HandleCall(messageType CartMessage, handler func(CartId, []byte) (CartMessage, []byte, error)) {
|
|
||||||
m.mu.Lock()
|
|
||||||
m.functions[messageType] = handler
|
|
||||||
m.mu.Unlock()
|
|
||||||
}
|
|
||||||
@@ -1,57 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCartTcpHelpers(t *testing.T) {
|
|
||||||
|
|
||||||
server, err := CartListen("localhost:51337")
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Error listening: %v\n", err)
|
|
||||||
}
|
|
||||||
client, err := CartDial("localhost:51337")
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Error dialing: %v\n", err)
|
|
||||||
}
|
|
||||||
var messageData string
|
|
||||||
server.ListenFor(1, func(id CartId, data []byte) error {
|
|
||||||
log.Printf("Received message: %s\n", string(data))
|
|
||||||
messageData = string(data)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
server.HandleCall(666, func(id CartId, data []byte) (CartMessage, []byte, error) {
|
|
||||||
log.Printf("Received 666 call: %s\n", string(data))
|
|
||||||
return 3, []byte("Hello, client!"), fmt.Errorf("Det blev fel")
|
|
||||||
})
|
|
||||||
server.HandleCall(2, func(id CartId, data []byte) (CartMessage, []byte, error) {
|
|
||||||
log.Printf("Received 2 call: %s\n", string(data))
|
|
||||||
return 4, []byte("Hello, client!"), nil
|
|
||||||
})
|
|
||||||
// server.HandleCall(Ping, func(id CartId, data []byte) (CartMessage, []byte, error) {
|
|
||||||
// return Pong, nil, nil
|
|
||||||
// })
|
|
||||||
id := ToCartId("kalle")
|
|
||||||
client.SendPacket(1, id, []byte("Hello, world!"))
|
|
||||||
answer, err := client.Call(2, id, 4, []byte("Hello, server!"))
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Error calling: %v\n", err)
|
|
||||||
}
|
|
||||||
s, err := client.Call(666, id, 3, []byte("Hello, server!"))
|
|
||||||
client.PersistentConnection.Close()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Error calling: %v\n", err)
|
|
||||||
}
|
|
||||||
if s.StatusCode != 500 {
|
|
||||||
t.Errorf("Expected 500, got %d\n", s.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
if string(answer.Data) != "Hello, client!" {
|
|
||||||
t.Errorf("Expected answer 'Hello, client!', got %s\n", string(answer.Data))
|
|
||||||
}
|
|
||||||
if messageData != "Hello, world!" {
|
|
||||||
t.Errorf("Expected message 'Hello, world!', got %s\n", messageData)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
144
tcp-client.go
144
tcp-client.go
@@ -1,144 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Client struct {
|
|
||||||
*TCPClient
|
|
||||||
}
|
|
||||||
|
|
||||||
func Dial(address string) (*Client, error) {
|
|
||||||
|
|
||||||
mux, err := NewTCPClient(address)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
client := &Client{
|
|
||||||
TCPClient: mux,
|
|
||||||
}
|
|
||||||
return client, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type TCPClient struct {
|
|
||||||
PersistentConnection *PersistentConnection
|
|
||||||
sendMux sync.Mutex
|
|
||||||
ErrorCount int
|
|
||||||
address string
|
|
||||||
*PacketQueue
|
|
||||||
}
|
|
||||||
|
|
||||||
type PersistentConnection struct {
|
|
||||||
net.Conn
|
|
||||||
Died chan bool
|
|
||||||
Dead bool
|
|
||||||
address string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPersistentConnection(address string) (*PersistentConnection, error) {
|
|
||||||
|
|
||||||
p := &PersistentConnection{
|
|
||||||
Died: make(chan bool, 1),
|
|
||||||
Dead: false,
|
|
||||||
address: address,
|
|
||||||
}
|
|
||||||
err := p.Connect()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return p, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *PersistentConnection) Connect() error {
|
|
||||||
fails := 0
|
|
||||||
for {
|
|
||||||
connection, err := net.Dial("tcp", m.address)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Can't connect to %s: %v, count: %d", m.address, err, fails)
|
|
||||||
fails++
|
|
||||||
if fails > 15 {
|
|
||||||
log.Printf("Too many connection failures, closing connection to %s", m.address)
|
|
||||||
m.Died <- true
|
|
||||||
m.Dead = true
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
m.Conn = connection
|
|
||||||
break
|
|
||||||
}
|
|
||||||
time.Sleep(time.Millisecond * 300)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *PersistentConnection) Close() {
|
|
||||||
log.Printf("Closing connection to %s\n", m.address)
|
|
||||||
m.Conn.Close()
|
|
||||||
m.Died <- true
|
|
||||||
m.Dead = true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *PersistentConnection) HandleConnectionError(err error) error {
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error from to %s: %v\n", m.address, err)
|
|
||||||
m.Conn.Close()
|
|
||||||
m.Connect()
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewTCPClient(address string) (*TCPClient, error) {
|
|
||||||
|
|
||||||
connection, err := NewPersistentConnection(address)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &TCPClient{
|
|
||||||
ErrorCount: 0,
|
|
||||||
PersistentConnection: connection,
|
|
||||||
address: address,
|
|
||||||
PacketQueue: NewPacketQueue(connection),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type PacketHeader [4]byte
|
|
||||||
|
|
||||||
var (
|
|
||||||
header = PacketHeader([4]byte{0x01, 0x02, 0x03, 0x04})
|
|
||||||
)
|
|
||||||
|
|
||||||
func (m *TCPClient) SendPacket(messageType PoolMessage, data []byte) error {
|
|
||||||
m.sendMux.Lock()
|
|
||||||
defer m.sendMux.Unlock()
|
|
||||||
m.PersistentConnection.Write(header[:])
|
|
||||||
err := binary.Write(m.PersistentConnection, binary.LittleEndian, Packet{
|
|
||||||
Version: CurrentPacketVersion,
|
|
||||||
MessageType: messageType,
|
|
||||||
StatusCode: 0,
|
|
||||||
DataLength: uint32(len(data)),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return m.PersistentConnection.HandleConnectionError(err)
|
|
||||||
}
|
|
||||||
_, err = m.PersistentConnection.Write(data)
|
|
||||||
return m.PersistentConnection.HandleConnectionError(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *TCPClient) Call(messageType PoolMessage, responseType PoolMessage, data []byte) (*CallResult, error) {
|
|
||||||
packetChan := m.Expect(responseType)
|
|
||||||
err := m.SendPacket(messageType, data)
|
|
||||||
if err != nil {
|
|
||||||
m.RemoveListeners()
|
|
||||||
return nil, m.PersistentConnection.HandleConnectionError(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ret := <-packetChan
|
|
||||||
return &ret, nil
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -15,12 +15,23 @@ type Connection struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type FrameType uint32
|
type FrameType uint32
|
||||||
|
type StatusCode uint32
|
||||||
|
type CheckSum uint32
|
||||||
|
|
||||||
type Frame struct {
|
type Frame struct {
|
||||||
Id uint64
|
|
||||||
Type FrameType
|
Type FrameType
|
||||||
StatusCode uint32
|
StatusCode StatusCode
|
||||||
Length uint32
|
Length uint32
|
||||||
|
Checksum CheckSum
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Frame) IsValid() bool {
|
||||||
|
return f.Checksum == MakeChecksum(f.Type, f.StatusCode, f.Length)
|
||||||
|
}
|
||||||
|
|
||||||
|
func MakeChecksum(msg FrameType, statusCode StatusCode, length uint32) CheckSum {
|
||||||
|
sum := CheckSum((uint32(msg) + uint32(statusCode) + length) / 8)
|
||||||
|
return sum
|
||||||
}
|
}
|
||||||
|
|
||||||
type FrameWithPayload struct {
|
type FrameWithPayload struct {
|
||||||
@@ -28,6 +39,19 @@ type FrameWithPayload struct {
|
|||||||
Payload []byte
|
Payload []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func MakeFrameWithPayload(msg FrameType, statusCode StatusCode, payload []byte) FrameWithPayload {
|
||||||
|
len := uint32(len(payload))
|
||||||
|
return FrameWithPayload{
|
||||||
|
Frame: Frame{
|
||||||
|
Type: msg,
|
||||||
|
StatusCode: 0,
|
||||||
|
Length: len,
|
||||||
|
Checksum: MakeChecksum(msg, 0, len),
|
||||||
|
},
|
||||||
|
Payload: payload,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type FrameData interface {
|
type FrameData interface {
|
||||||
ToBytes() []byte
|
ToBytes() []byte
|
||||||
FromBytes([]byte) error
|
FromBytes([]byte) error
|
||||||
@@ -41,11 +65,7 @@ func NewConnection(address string) *Connection {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func SendFrame(conn net.Conn, data *FrameWithPayload) error {
|
func SendFrame(conn net.Conn, data *FrameWithPayload) error {
|
||||||
_, err := conn.Write(header[:])
|
err := binary.Write(conn, binary.LittleEndian, data.Frame)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = binary.Write(conn, binary.LittleEndian, data.Frame)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -53,68 +73,67 @@ func SendFrame(conn net.Conn, data *FrameWithPayload) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Connection) CallAsync(msg FrameType, data FrameData, ch chan<- *FrameWithPayload) error {
|
func (c *Connection) CallAsync(msg FrameType, payload []byte, ch chan<- FrameWithPayload) (net.Conn, error) {
|
||||||
conn, err := net.Dial("tcp", c.address)
|
conn, err := net.Dial("tcp", c.address)
|
||||||
go WaitForFrame(conn, ch)
|
go WaitForFrame(conn, ch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return conn, err
|
||||||
}
|
|
||||||
payload := data.ToBytes()
|
|
||||||
toSend := &FrameWithPayload{
|
|
||||||
Frame: Frame{
|
|
||||||
Id: c.count,
|
|
||||||
Type: msg,
|
|
||||||
StatusCode: 0,
|
|
||||||
Length: uint32(len(payload)),
|
|
||||||
},
|
|
||||||
Payload: payload,
|
|
||||||
}
|
}
|
||||||
|
toSend := MakeFrameWithPayload(msg, 1, payload)
|
||||||
|
|
||||||
err = SendFrame(conn, toSend)
|
err = SendFrame(conn, &toSend)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
close(ch)
|
close(ch)
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.count++
|
c.count++
|
||||||
return nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Connection) Call(msg FrameType, data FrameData) (*FrameWithPayload, error) {
|
func (c *Connection) Call(msg FrameType, data []byte) (*FrameWithPayload, error) {
|
||||||
ch := make(chan *FrameWithPayload, 1)
|
ch := make(chan FrameWithPayload, 1)
|
||||||
c.CallAsync(msg, data, ch)
|
conn, err := c.CallAsync(msg, data, ch)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
select {
|
select {
|
||||||
case ret := <-ch:
|
case ret := <-ch:
|
||||||
return ret, nil
|
return &ret, nil
|
||||||
case <-time.After(5 * time.Second):
|
case <-time.After(MaxCallDuration):
|
||||||
return nil, fmt.Errorf("timeout")
|
return nil, fmt.Errorf("timeout")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WaitForFrame(conn net.Conn, resultChan chan<- *FrameWithPayload) error {
|
func WaitForFrame(conn net.Conn, resultChan chan<- FrameWithPayload) error {
|
||||||
defer conn.Close()
|
|
||||||
var err error
|
var err error
|
||||||
|
var frame Frame
|
||||||
r := bufio.NewReader(conn)
|
r := bufio.NewReader(conn)
|
||||||
h := make([]byte, 4)
|
|
||||||
r.Read(h)
|
err = binary.Read(r, binary.LittleEndian, &frame)
|
||||||
if h[0] == header[0] && h[1] == header[1] && h[2] == header[2] && h[3] == header[3] {
|
if err != nil {
|
||||||
frame := Frame{}
|
return err
|
||||||
err = binary.Read(r, binary.LittleEndian, &frame)
|
}
|
||||||
|
if frame.IsValid() {
|
||||||
payload := make([]byte, frame.Length)
|
payload := make([]byte, frame.Length)
|
||||||
_, err = r.Read(payload)
|
_, err = r.Read(payload)
|
||||||
resultChan <- &FrameWithPayload{
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
resultChan <- FrameWithPayload{
|
||||||
Frame: frame,
|
Frame: frame,
|
||||||
Payload: payload,
|
Payload: payload,
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
resultChan <- nil
|
return fmt.Errorf("checksum mismatch")
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type GenericListener struct {
|
type GenericListener struct {
|
||||||
Closed bool
|
Closed bool
|
||||||
handlers map[FrameType]func(*FrameWithPayload, chan<- *FrameWithPayload) error
|
handlers map[FrameType]func(*FrameWithPayload, chan<- FrameWithPayload) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Connection) Listen() (*GenericListener, error) {
|
func (c *Connection) Listen() (*GenericListener, error) {
|
||||||
@@ -123,7 +142,7 @@ func (c *Connection) Listen() (*GenericListener, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
ret := &GenericListener{
|
ret := &GenericListener{
|
||||||
handlers: make(map[FrameType]func(*FrameWithPayload, chan<- *FrameWithPayload) error),
|
handlers: make(map[FrameType]func(*FrameWithPayload, chan<- FrameWithPayload) error),
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
for !ret.Closed {
|
for !ret.Closed {
|
||||||
@@ -137,36 +156,44 @@ func (c *Connection) Listen() (*GenericListener, error) {
|
|||||||
return ret, nil
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
MaxCallDuration = 500 * time.Millisecond
|
||||||
|
)
|
||||||
|
|
||||||
func (l *GenericListener) HandleConnection(conn net.Conn) {
|
func (l *GenericListener) HandleConnection(conn net.Conn) {
|
||||||
ch := make(chan *FrameWithPayload, 1)
|
ch := make(chan FrameWithPayload, 1)
|
||||||
go WaitForFrame(conn, ch)
|
go WaitForFrame(conn, ch)
|
||||||
select {
|
select {
|
||||||
case frame := <-ch:
|
case frame := <-ch:
|
||||||
go l.HandleFrame(conn, frame)
|
go l.HandleFrame(conn, &frame)
|
||||||
case <-time.After(1 * time.Second):
|
case <-time.After(MaxCallDuration):
|
||||||
close(ch)
|
close(ch)
|
||||||
log.Printf("Timeout waiting for frame\n")
|
log.Printf("Timeout waiting for frame\n")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *GenericListener) AddHandler(msg FrameType, handler func(*FrameWithPayload, chan<- *FrameWithPayload) error) {
|
func (l *GenericListener) AddHandler(msg FrameType, handler func(*FrameWithPayload, chan<- FrameWithPayload) error) {
|
||||||
l.handlers[msg] = handler
|
l.handlers[msg] = handler
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *GenericListener) HandleFrame(conn net.Conn, frame *FrameWithPayload) {
|
func (l *GenericListener) HandleFrame(conn net.Conn, frame *FrameWithPayload) {
|
||||||
handler, ok := l.handlers[frame.Type]
|
handler, ok := l.handlers[frame.Type]
|
||||||
defer conn.Close()
|
|
||||||
if ok {
|
if ok {
|
||||||
go func() {
|
go func() {
|
||||||
resultChan := make(chan *FrameWithPayload, 1)
|
resultChan := make(chan FrameWithPayload, 1)
|
||||||
defer close(resultChan)
|
defer close(resultChan)
|
||||||
err := handler(frame, resultChan)
|
err := handler(frame, resultChan)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Error handling frame: %v\n", err)
|
log.Fatalf("Error handling frame: %v\n", err)
|
||||||
}
|
}
|
||||||
SendFrame(conn, <-resultChan)
|
result := <-resultChan
|
||||||
|
err = SendFrame(conn, &result)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error sending frame: %v\n", err)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
} else {
|
} else {
|
||||||
|
conn.Close()
|
||||||
log.Fatalf("No handler for frame type %d\n", frame.Type)
|
log.Fatalf("No handler for frame type %d\n", frame.Type)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,37 +2,19 @@ package main
|
|||||||
|
|
||||||
import "testing"
|
import "testing"
|
||||||
|
|
||||||
type StringData string
|
|
||||||
|
|
||||||
func (s StringData) ToBytes() []byte {
|
|
||||||
return []byte(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s StringData) FromBytes(data []byte) error {
|
|
||||||
s = StringData(data)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGenericConnection(t *testing.T) {
|
func TestGenericConnection(t *testing.T) {
|
||||||
conn := NewConnection("localhost:51337")
|
conn := NewConnection("localhost:51337")
|
||||||
listener, err := conn.Listen()
|
listener, err := conn.Listen()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Error listening: %v\n", err)
|
t.Errorf("Error listening: %v\n", err)
|
||||||
}
|
}
|
||||||
listener.AddHandler(1, func(input *FrameWithPayload, resultChan chan<- *FrameWithPayload) error {
|
datta := []byte("Hello, world!")
|
||||||
payload := []byte("Hello, world!")
|
listener.AddHandler(1, func(input *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
|
||||||
resultChan <- &FrameWithPayload{
|
|
||||||
Frame: Frame{
|
resultChan <- MakeFrameWithPayload(2, 200, datta)
|
||||||
Type: 2,
|
|
||||||
Id: input.Id,
|
|
||||||
StatusCode: 200,
|
|
||||||
Length: uint32(len("Hello, world!")),
|
|
||||||
},
|
|
||||||
Payload: payload,
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
r, err := conn.Call(1, StringData("Hello, world!"))
|
r, err := conn.Call(1, datta)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Error calling: %v\n", err)
|
t.Errorf("Error calling: %v\n", err)
|
||||||
}
|
}
|
||||||
@@ -40,9 +22,9 @@ func TestGenericConnection(t *testing.T) {
|
|||||||
t.Errorf("Expected type 2, got %d\n", r.Type)
|
t.Errorf("Expected type 2, got %d\n", r.Type)
|
||||||
}
|
}
|
||||||
i := 100
|
i := 100
|
||||||
results := make(chan *FrameWithPayload, i)
|
results := make(chan FrameWithPayload, i)
|
||||||
for i > 0 {
|
for i > 0 {
|
||||||
conn.CallAsync(1, StringData("Hello, world!"), results)
|
go conn.CallAsync(1, datta, results)
|
||||||
i--
|
i--
|
||||||
}
|
}
|
||||||
for i < 100 {
|
for i < 100 {
|
||||||
|
|||||||
@@ -1,161 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"encoding/binary"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Server struct {
|
|
||||||
*TCPServerMux
|
|
||||||
}
|
|
||||||
|
|
||||||
func Listen(address string) (*Server, error) {
|
|
||||||
listener, err := net.Listen("tcp", address)
|
|
||||||
server := &Server{
|
|
||||||
NewTCPServerMux(),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
conn, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error accepting connection: %v\n", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
go server.HandleConnection(conn)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return server, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type TCPServerMux struct {
|
|
||||||
mu sync.RWMutex
|
|
||||||
sendMux sync.Mutex
|
|
||||||
listeners map[PoolMessage]func(data []byte) error
|
|
||||||
functions map[PoolMessage]func(data []byte) (PoolMessage, []byte, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewTCPServerMux() *TCPServerMux {
|
|
||||||
m := &TCPServerMux{
|
|
||||||
mu: sync.RWMutex{},
|
|
||||||
listeners: make(map[PoolMessage]func(data []byte) error),
|
|
||||||
functions: make(map[PoolMessage]func(data []byte) (PoolMessage, []byte, error)),
|
|
||||||
}
|
|
||||||
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *TCPServerMux) handleListener(messageType PoolMessage, data []byte) (bool, error) {
|
|
||||||
m.mu.RLock()
|
|
||||||
handler, ok := m.listeners[messageType]
|
|
||||||
m.mu.RUnlock()
|
|
||||||
if ok {
|
|
||||||
err := handler(data)
|
|
||||||
if err != nil {
|
|
||||||
return true, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *TCPServerMux) handleFunction(connection net.Conn, messageType PoolMessage, data []byte) (bool, error) {
|
|
||||||
m.mu.RLock()
|
|
||||||
function, ok := m.functions[messageType]
|
|
||||||
m.mu.RUnlock()
|
|
||||||
m.sendMux.Lock()
|
|
||||||
defer m.sendMux.Unlock()
|
|
||||||
if ok {
|
|
||||||
connection.Write(header[:])
|
|
||||||
responseType, responseData, err := function(data)
|
|
||||||
if err != nil {
|
|
||||||
errData := []byte(err.Error())
|
|
||||||
err = binary.Write(connection, binary.LittleEndian, Packet{
|
|
||||||
Version: CurrentPacketVersion,
|
|
||||||
MessageType: responseType,
|
|
||||||
StatusCode: 500,
|
|
||||||
DataLength: uint32(len(errData)),
|
|
||||||
})
|
|
||||||
_, err = connection.Write(errData)
|
|
||||||
return true, err
|
|
||||||
}
|
|
||||||
err = binary.Write(connection, binary.LittleEndian, Packet{
|
|
||||||
Version: CurrentPacketVersion,
|
|
||||||
MessageType: responseType,
|
|
||||||
StatusCode: 200,
|
|
||||||
DataLength: uint32(len(responseData)),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return true, err
|
|
||||||
}
|
|
||||||
packetsSent.Inc()
|
|
||||||
_, err = connection.Write(responseData)
|
|
||||||
return true, err
|
|
||||||
} else {
|
|
||||||
log.Printf("No pool handler for type: %d\n", messageType)
|
|
||||||
}
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *TCPServerMux) HandleConnection(connection net.Conn) error {
|
|
||||||
|
|
||||||
defer connection.Close()
|
|
||||||
var packet Packet
|
|
||||||
reader := bufio.NewReader(connection)
|
|
||||||
for {
|
|
||||||
err := ReadPacket(reader, &packet)
|
|
||||||
if err != nil {
|
|
||||||
if err == io.EOF {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
log.Printf("Error receiving packet: %v\n", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if packet.Version != CurrentPacketVersion {
|
|
||||||
log.Printf("Incorrect package version: %v\n", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data, err := GetPacketData(reader, packet.DataLength)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error receiving packet data: %v\n", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
go m.HandleData(connection, packet.MessageType, data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *TCPServerMux) HandleData(connection net.Conn, t PoolMessage, data []byte) {
|
|
||||||
// listener := m.listeners[t]
|
|
||||||
// handler := m.functions[t]
|
|
||||||
status, err := m.handleListener(t, data)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error handling listener: %v\n", err)
|
|
||||||
}
|
|
||||||
if !status {
|
|
||||||
status, err = m.handleFunction(connection, t, data)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error handling function: %v\n", err)
|
|
||||||
}
|
|
||||||
if !status {
|
|
||||||
log.Printf("Unknown message type: %d\n", t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *TCPServerMux) ListenFor(messageType PoolMessage, handler func(data []byte) error) {
|
|
||||||
m.mu.Lock()
|
|
||||||
m.listeners[messageType] = handler
|
|
||||||
m.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *TCPServerMux) HandleCall(messageType PoolMessage, handler func(data []byte) (PoolMessage, []byte, error)) {
|
|
||||||
m.mu.Lock()
|
|
||||||
m.functions[messageType] = handler
|
|
||||||
m.mu.Unlock()
|
|
||||||
}
|
|
||||||
54
tcp_test.go
54
tcp_test.go
@@ -1,54 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestTcpHelpers(t *testing.T) {
|
|
||||||
|
|
||||||
server, err := Listen("localhost:51337")
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Error listening: %v\n", err)
|
|
||||||
}
|
|
||||||
client, err := Dial("localhost:51337")
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Error dialing: %v\n", err)
|
|
||||||
}
|
|
||||||
var messageData string
|
|
||||||
server.ListenFor(1, func(data []byte) error {
|
|
||||||
log.Printf("Received message: %s\n", string(data))
|
|
||||||
messageData = string(data)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
server.HandleCall(2, func(data []byte) (PoolMessage, []byte, error) {
|
|
||||||
log.Printf("Received call: %s\n", string(data))
|
|
||||||
return 3, []byte("Hello, client!"), nil
|
|
||||||
})
|
|
||||||
server.HandleCall(Ping, func(data []byte) (PoolMessage, []byte, error) {
|
|
||||||
return Pong, nil, nil
|
|
||||||
})
|
|
||||||
|
|
||||||
client.SendPacket(1, []byte("Hello, world!"))
|
|
||||||
answer, err := client.Call(2, 3, []byte("Hello, server!"))
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Error calling: %v\n", err)
|
|
||||||
}
|
|
||||||
for i := 0; i < 100; i++ {
|
|
||||||
_, err = client.Call(Ping, Pong, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Error calling: %v\n", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_, err = client.Call(Ping, Pong, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Error calling: %v\n", err)
|
|
||||||
}
|
|
||||||
client.Close()
|
|
||||||
if string(answer.Data) != "Hello, client!" {
|
|
||||||
t.Errorf("Expected answer 'Hello, client!', got %s\n", string(answer.Data))
|
|
||||||
}
|
|
||||||
if messageData != "Hello, world!" {
|
|
||||||
t.Errorf("Expected message 'Hello, world!', got %s\n", messageData)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user