major refactor
This commit is contained in:
@@ -54,8 +54,8 @@ type CartGrain struct {
|
||||
|
||||
type Grain interface {
|
||||
GetId() CartId
|
||||
HandleMessage(message *Message, isReplay bool) (*CallResult, error)
|
||||
GetCurrentState() (*CallResult, error)
|
||||
HandleMessage(message *Message, isReplay bool) (*FrameWithPayload, error)
|
||||
GetCurrentState() (*FrameWithPayload, error)
|
||||
}
|
||||
|
||||
func (c *CartGrain) GetId() CartId {
|
||||
@@ -69,12 +69,14 @@ func (c *CartGrain) GetLastChange() int64 {
|
||||
return *c.storageMessages[len(c.storageMessages)-1].TimeStamp
|
||||
}
|
||||
|
||||
func (c *CartGrain) GetCurrentState() (*CallResult, error) {
|
||||
func (c *CartGrain) GetCurrentState() (*FrameWithPayload, error) {
|
||||
result, err := json.Marshal(c)
|
||||
return &CallResult{
|
||||
StatusCode: 200,
|
||||
Data: result,
|
||||
}, err
|
||||
if err != nil {
|
||||
ret := MakeFrameWithPayload(0, 400, []byte(err.Error()))
|
||||
return &ret, nil
|
||||
}
|
||||
ret := MakeFrameWithPayload(0, 200, result)
|
||||
return &ret, nil
|
||||
}
|
||||
|
||||
func getItemData(sku string, qty int) (*messages.AddItem, error) {
|
||||
@@ -108,7 +110,7 @@ func getItemData(sku string, qty int) (*messages.AddItem, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *CartGrain) AddItem(sku string, qty int) (*CallResult, error) {
|
||||
func (c *CartGrain) AddItem(sku string, qty int) (*FrameWithPayload, error) {
|
||||
cartItem, err := getItemData(sku, qty)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -180,7 +182,7 @@ func (c *CartGrain) FindItemWithSku(sku string) (*CartItem, bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (c *CartGrain) HandleMessage(message *Message, isReplay bool) (*CallResult, error) {
|
||||
func (c *CartGrain) HandleMessage(message *Message, isReplay bool) (*FrameWithPayload, error) {
|
||||
if message.TimeStamp == nil {
|
||||
now := time.Now().Unix()
|
||||
message.TimeStamp = &now
|
||||
@@ -305,8 +307,10 @@ func (c *CartGrain) HandleMessage(message *Message, isReplay bool) (*CallResult,
|
||||
c.mu.Unlock()
|
||||
}
|
||||
result, err := json.Marshal(c)
|
||||
return &CallResult{
|
||||
return &FrameWithPayload{
|
||||
Frame: Frame{
|
||||
StatusCode: 200,
|
||||
Data: result,
|
||||
},
|
||||
Payload: result,
|
||||
}, err
|
||||
}
|
||||
|
||||
@@ -1,163 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type CartPacketQueue struct {
|
||||
mu sync.RWMutex
|
||||
expectedPackages map[CartMessage]*CartListener
|
||||
}
|
||||
|
||||
const CurrentPacketVersion = 2
|
||||
|
||||
type CartListener map[CartId]Listener
|
||||
|
||||
func NewCartPacketQueue(connection *PersistentConnection) *CartPacketQueue {
|
||||
queue := &CartPacketQueue{
|
||||
expectedPackages: make(map[CartMessage]*CartListener),
|
||||
}
|
||||
go queue.HandleConnection(connection)
|
||||
return queue
|
||||
}
|
||||
|
||||
func (p *CartPacketQueue) RemoveListeners() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
for _, l := range p.expectedPackages {
|
||||
for _, l := range *l {
|
||||
close(l.Chan)
|
||||
}
|
||||
}
|
||||
p.expectedPackages = make(map[CartMessage]*CartListener)
|
||||
}
|
||||
|
||||
func (p *CartPacketQueue) HandleConnection(connection *PersistentConnection) error {
|
||||
defer p.RemoveListeners()
|
||||
defer connection.Close()
|
||||
var packet CartPacket
|
||||
reader := bufio.NewReader(connection)
|
||||
for {
|
||||
err := ReadCartPacket(reader, &packet)
|
||||
if err != nil {
|
||||
log.Printf("Error receiving packet: %v\n", err)
|
||||
return connection.HandleConnectionError(err)
|
||||
}
|
||||
if packet.Version != CurrentPacketVersion {
|
||||
log.Printf("Incorrect version: %v\n", packet.Version)
|
||||
return connection.HandleConnectionError(fmt.Errorf("incorrect version: %d", packet.Version))
|
||||
}
|
||||
if packet.DataLength == 0 {
|
||||
go p.HandleData(packet.MessageType, packet.Id, CallResult{
|
||||
StatusCode: packet.StatusCode,
|
||||
Data: []byte{},
|
||||
})
|
||||
continue
|
||||
}
|
||||
data, err := GetPacketData(reader, packet.DataLength)
|
||||
if err != nil {
|
||||
log.Printf("Error receiving packet data: %v\n", err)
|
||||
return connection.HandleConnectionError(err)
|
||||
}
|
||||
go p.HandleData(packet.MessageType, packet.Id, CallResult{
|
||||
StatusCode: packet.StatusCode,
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (p *CartPacketQueue) HandleData(t CartMessage, id CartId, data CallResult) {
|
||||
p.getListener(t, id, func(l *Listener) {
|
||||
l.Chan <- data
|
||||
l.Count--
|
||||
})
|
||||
// p.mu.Lock()
|
||||
// defer p.mu.Unlock()
|
||||
// pl, ok := p.expectedPackages[t]
|
||||
// if ok {
|
||||
// l, ok := (*pl)[id]
|
||||
// if ok {
|
||||
// l.Chan <- data
|
||||
// l.Count--
|
||||
// if l.Count == 0 {
|
||||
// close(l.Chan)
|
||||
// delete(*pl, id)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
func (p *CartPacketQueue) getListener(t CartMessage, id CartId, fn func(*Listener)) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
pl, ok := p.expectedPackages[t]
|
||||
if ok {
|
||||
l, ok := (*pl)[id]
|
||||
if ok {
|
||||
fn(&l)
|
||||
if l.Count == 0 {
|
||||
close(l.Chan)
|
||||
delete(*pl, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func CallResultWithTimeout(onTimeout func() CallResult) chan CallResult {
|
||||
ch := make(chan CallResult, 1)
|
||||
resultCh := make(chan CallResult, 1)
|
||||
select {
|
||||
case ret := <-resultCh:
|
||||
ch <- ret
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
ch <- onTimeout()
|
||||
}
|
||||
return ch
|
||||
}
|
||||
|
||||
func (p *CartPacketQueue) MakeChannel(messageType CartMessage, id CartId) chan CallResult {
|
||||
return CallResultWithTimeout(func() CallResult {
|
||||
p.getListener(messageType, id, func(l *Listener) {
|
||||
l.Count--
|
||||
})
|
||||
return CallResult{
|
||||
StatusCode: 504,
|
||||
Data: []byte("timeout cart call"),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *CartPacketQueue) Expect(messageType CartMessage, id CartId) <-chan CallResult {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
l, ok := p.expectedPackages[messageType]
|
||||
if ok {
|
||||
if idl, idOk := (*l)[id]; idOk {
|
||||
idl.Count++
|
||||
return idl.Chan
|
||||
}
|
||||
ch := p.MakeChannel(messageType, id)
|
||||
|
||||
(*l)[id] = Listener{
|
||||
Chan: ch,
|
||||
Count: 1,
|
||||
}
|
||||
|
||||
return ch
|
||||
}
|
||||
|
||||
ch := p.MakeChannel(messageType, id)
|
||||
p.expectedPackages[messageType] = &CartListener{
|
||||
id: Listener{
|
||||
Chan: ch,
|
||||
Count: 1,
|
||||
},
|
||||
}
|
||||
|
||||
return ch
|
||||
|
||||
}
|
||||
@@ -27,8 +27,8 @@ var (
|
||||
)
|
||||
|
||||
type GrainPool interface {
|
||||
Process(id CartId, messages ...Message) (*CallResult, error)
|
||||
Get(id CartId) (*CallResult, error)
|
||||
Process(id CartId, messages ...Message) (*FrameWithPayload, error)
|
||||
Get(id CartId) (*FrameWithPayload, error)
|
||||
}
|
||||
|
||||
type Ttl struct {
|
||||
@@ -142,23 +142,29 @@ func (p *GrainLocalPool) GetGrain(id CartId) (*CartGrain, error) {
|
||||
return grain, err
|
||||
}
|
||||
|
||||
func (p *GrainLocalPool) Process(id CartId, messages ...Message) ([]byte, error) {
|
||||
func (p *GrainLocalPool) Process(id CartId, messages ...Message) (*FrameWithPayload, error) {
|
||||
grain, err := p.GetGrain(id)
|
||||
var result *FrameWithPayload
|
||||
if err == nil && grain != nil {
|
||||
for _, message := range messages {
|
||||
_, err = grain.HandleMessage(&message, false)
|
||||
result, err = grain.HandleMessage(&message, false)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return result, err
|
||||
}
|
||||
return json.Marshal(grain)
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (p *GrainLocalPool) Get(id CartId) ([]byte, error) {
|
||||
func (p *GrainLocalPool) Get(id CartId) (*FrameWithPayload, error) {
|
||||
grain, err := p.GetGrain(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(grain)
|
||||
data, err := json.Marshal(grain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret := MakeFrameWithPayload(0, 200, data)
|
||||
return &ret, nil
|
||||
}
|
||||
|
||||
122
packet-queue.go
122
packet-queue.go
@@ -1,122 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type PacketQueue struct {
|
||||
mu sync.RWMutex
|
||||
expectedPackages map[PoolMessage]*Listener
|
||||
}
|
||||
|
||||
type CallResult struct {
|
||||
StatusCode uint32
|
||||
Data []byte
|
||||
}
|
||||
|
||||
type Listener struct {
|
||||
Count int
|
||||
Chan chan CallResult
|
||||
}
|
||||
|
||||
func NewPacketQueue(connection *PersistentConnection) *PacketQueue {
|
||||
queue := &PacketQueue{
|
||||
expectedPackages: make(map[PoolMessage]*Listener),
|
||||
}
|
||||
go queue.HandleConnection(connection)
|
||||
return queue
|
||||
}
|
||||
|
||||
func (p *PacketQueue) RemoveListeners() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
for _, l := range p.expectedPackages {
|
||||
close(l.Chan)
|
||||
}
|
||||
p.expectedPackages = make(map[PoolMessage]*Listener)
|
||||
}
|
||||
|
||||
func (p *PacketQueue) HandleConnection(connection *PersistentConnection) error {
|
||||
defer connection.Close()
|
||||
defer p.RemoveListeners()
|
||||
var packet Packet
|
||||
reader := bufio.NewReader(connection)
|
||||
|
||||
for {
|
||||
err := ReadPacket(reader, &packet)
|
||||
if err != nil {
|
||||
return connection.HandleConnectionError(err)
|
||||
}
|
||||
if packet.Version != CurrentPacketVersion {
|
||||
log.Printf("Incorrect packet version: %v\n", packet.Version)
|
||||
return connection.HandleConnectionError(fmt.Errorf("incorrect packet version: %d", packet.Version))
|
||||
}
|
||||
if packet.DataLength == 0 {
|
||||
go p.HandleData(packet.MessageType, CallResult{
|
||||
StatusCode: packet.StatusCode,
|
||||
Data: []byte{},
|
||||
})
|
||||
continue
|
||||
}
|
||||
data, err := GetPacketData(reader, packet.DataLength)
|
||||
if err != nil {
|
||||
log.Printf("Error receiving packet data: %v\n", err)
|
||||
return connection.HandleConnectionError(err)
|
||||
} else {
|
||||
go p.HandleData(packet.MessageType, CallResult{
|
||||
StatusCode: packet.StatusCode,
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PacketQueue) HandleData(t PoolMessage, data CallResult) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
l, ok := p.expectedPackages[t]
|
||||
if ok {
|
||||
l.Chan <- data
|
||||
l.Count--
|
||||
if l.Count == 0 {
|
||||
close(l.Chan)
|
||||
delete(p.expectedPackages, t)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PacketQueue) Expect(messageType PoolMessage) <-chan CallResult {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
l, ok := p.expectedPackages[messageType]
|
||||
if ok {
|
||||
l.Count++
|
||||
return l.Chan
|
||||
}
|
||||
|
||||
ch := make(chan CallResult, 1)
|
||||
go func() {
|
||||
time.Sleep(time.Millisecond * 300)
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
ch <- CallResult{
|
||||
StatusCode: 504,
|
||||
Data: []byte("timeout cart call"),
|
||||
}
|
||||
|
||||
close(ch)
|
||||
}()
|
||||
p.expectedPackages[messageType] = &Listener{
|
||||
Count: 1,
|
||||
Chan: ch,
|
||||
}
|
||||
|
||||
return ch
|
||||
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestQueue(t *testing.T) {
|
||||
localPool := NewGrainLocalPool(100, time.Minute, func(id CartId) (*CartGrain, error) {
|
||||
return &CartGrain{
|
||||
Id: id,
|
||||
storageMessages: []Message{},
|
||||
Items: []*CartItem{},
|
||||
TotalPrice: 0,
|
||||
}, nil
|
||||
})
|
||||
pool, err := NewSyncedPool(localPool, "localhost", nil)
|
||||
if err != nil {
|
||||
t.Errorf("Error creating pool: %v", err)
|
||||
}
|
||||
|
||||
err = pool.AddRemote("localhost")
|
||||
if err != nil {
|
||||
t.Errorf("Error adding remote: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
130
packet.go
130
packet.go
@@ -1,85 +1,77 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
)
|
||||
|
||||
type CartMessage uint32
|
||||
type PackageVersion uint32
|
||||
|
||||
const (
|
||||
RemoteGetState = CartMessage(0x01)
|
||||
RemoteHandleMutation = CartMessage(0x02)
|
||||
ResponseBody = CartMessage(0x03)
|
||||
RemoteGetStateReply = CartMessage(0x04)
|
||||
RemoteHandleMutationReply = CartMessage(0x05)
|
||||
RemoteGetState = FrameType(0x01)
|
||||
RemoteHandleMutation = FrameType(0x02)
|
||||
ResponseBody = FrameType(0x03)
|
||||
RemoteGetStateReply = FrameType(0x04)
|
||||
RemoteHandleMutationReply = FrameType(0x05)
|
||||
)
|
||||
|
||||
type CartPacket struct {
|
||||
Version PackageVersion
|
||||
MessageType CartMessage
|
||||
DataLength uint32
|
||||
StatusCode uint32
|
||||
Id CartId
|
||||
}
|
||||
// type CartPacket struct {
|
||||
// Version PackageVersion
|
||||
// MessageType CartMessage
|
||||
// DataLength uint32
|
||||
// StatusCode uint32
|
||||
// Id CartId
|
||||
// }
|
||||
|
||||
type Packet struct {
|
||||
Version PackageVersion
|
||||
MessageType PoolMessage
|
||||
DataLength uint32
|
||||
StatusCode uint32
|
||||
}
|
||||
// type Packet struct {
|
||||
// Version PackageVersion
|
||||
// MessageType PoolMessage
|
||||
// DataLength uint32
|
||||
// StatusCode uint32
|
||||
// }
|
||||
|
||||
var headerData = make([]byte, 4)
|
||||
// var headerData = make([]byte, 4)
|
||||
|
||||
func matchHeader(conn io.Reader) error {
|
||||
// func matchHeader(conn io.Reader) error {
|
||||
|
||||
pos := 0
|
||||
for pos < 4 {
|
||||
// pos := 0
|
||||
// for pos < 4 {
|
||||
|
||||
l, err := conn.Read(headerData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i := 0; i < l; i++ {
|
||||
if headerData[i] == header[pos] {
|
||||
pos++
|
||||
if pos == 4 {
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
pos = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// l, err := conn.Read(headerData)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// for i := 0; i < l; i++ {
|
||||
// if headerData[i] == header[pos] {
|
||||
// pos++
|
||||
// if pos == 4 {
|
||||
// return nil
|
||||
// }
|
||||
// } else {
|
||||
// pos = 0
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// return nil
|
||||
// }
|
||||
|
||||
func ReadPacket(conn io.Reader, packet *Packet) error {
|
||||
err := matchHeader(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return binary.Read(conn, binary.LittleEndian, packet)
|
||||
}
|
||||
// func ReadPacket(conn io.Reader, packet *Packet) error {
|
||||
// err := matchHeader(conn)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// return binary.Read(conn, binary.LittleEndian, packet)
|
||||
// }
|
||||
|
||||
func ReadCartPacket(conn io.Reader, packet *CartPacket) error {
|
||||
err := matchHeader(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return binary.Read(conn, binary.LittleEndian, packet)
|
||||
}
|
||||
// func ReadCartPacket(conn io.Reader, packet *CartPacket) error {
|
||||
// err := matchHeader(conn)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// return binary.Read(conn, binary.LittleEndian, packet)
|
||||
// }
|
||||
|
||||
func GetPacketData(conn io.Reader, len uint32) ([]byte, error) {
|
||||
if len == 0 {
|
||||
return []byte{}, nil
|
||||
}
|
||||
data := make([]byte, len)
|
||||
_, err := conn.Read(data)
|
||||
return data, err
|
||||
}
|
||||
// func GetPacketData(conn io.Reader, len uint32) ([]byte, error) {
|
||||
// if len == 0 {
|
||||
// return []byte{}, nil
|
||||
// }
|
||||
// data := make([]byte, len)
|
||||
// _, err := conn.Read(data)
|
||||
// return data, err
|
||||
// }
|
||||
|
||||
// func ReceivePacket(conn io.Reader) (uint32, []byte, error) {
|
||||
// var packet Packet
|
||||
|
||||
@@ -55,7 +55,7 @@ func ErrorHandler(fn func(w http.ResponseWriter, r *http.Request) error) func(w
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PoolServer) WriteResult(w http.ResponseWriter, result *CallResult) error {
|
||||
func (s *PoolServer) WriteResult(w http.ResponseWriter, result *FrameWithPayload) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("X-Pod-Name", s.pod_name)
|
||||
if result.StatusCode != 200 {
|
||||
@@ -65,11 +65,11 @@ func (s *PoolServer) WriteResult(w http.ResponseWriter, result *CallResult) erro
|
||||
} else {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
w.Write([]byte(result.Data))
|
||||
w.Write([]byte(result.Payload))
|
||||
return nil
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, err := w.Write(result.Data)
|
||||
_, err := w.Write(result.Payload)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -46,8 +46,8 @@ func (p *RemoteGrainPool) Delete(id CartId) {
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
func (p *RemoteGrainPool) Process(id CartId, messages ...Message) (*CallResult, error) {
|
||||
var result *CallResult
|
||||
func (p *RemoteGrainPool) Process(id CartId, messages ...Message) (*FrameWithPayload, error) {
|
||||
var result *FrameWithPayload
|
||||
grain, err := p.findOrCreateGrain(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -58,7 +58,7 @@ func (p *RemoteGrainPool) Process(id CartId, messages ...Message) (*CallResult,
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (p *RemoteGrainPool) Get(id CartId) (*CallResult, error) {
|
||||
func (p *RemoteGrainPool) Get(id CartId) (*FrameWithPayload, error) {
|
||||
grain, err := p.findOrCreateGrain(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -3,7 +3,6 @@ package main
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
@@ -13,6 +12,25 @@ func (id CartId) String() string {
|
||||
return strings.Trim(string(id[:]), "\x00")
|
||||
}
|
||||
|
||||
type CartIdPayload struct {
|
||||
Id CartId
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func MakeCartInnerFrame(id CartId, payload []byte) []byte {
|
||||
return append(id[:], payload...)
|
||||
}
|
||||
|
||||
func GetCartFrame(data []byte) (*CartIdPayload, error) {
|
||||
if len(data) < 16 {
|
||||
return nil, fmt.Errorf("data too short")
|
||||
}
|
||||
return &CartIdPayload{
|
||||
Id: CartId(data[:16]),
|
||||
Data: data[16:],
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ToCartId(id string) CartId {
|
||||
var result [16]byte
|
||||
copy(result[:], []byte(id))
|
||||
@@ -20,21 +38,16 @@ func ToCartId(id string) CartId {
|
||||
}
|
||||
|
||||
type RemoteGrain struct {
|
||||
*CartClient
|
||||
*Connection
|
||||
Id CartId
|
||||
Host string
|
||||
}
|
||||
|
||||
func NewRemoteGrain(id CartId, host string) (*RemoteGrain, error) {
|
||||
client, err := CartDial(fmt.Sprintf("%s:1337", host))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &RemoteGrain{
|
||||
Id: id,
|
||||
Host: host,
|
||||
CartClient: client,
|
||||
Connection: NewConnection(fmt.Sprintf("%s:1337", host)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -49,47 +62,35 @@ var (
|
||||
})
|
||||
)
|
||||
|
||||
var start time.Time
|
||||
// var start time.Time
|
||||
|
||||
func MeasureLatency(fn func() (*CallResult, error)) (*CallResult, error) {
|
||||
start = time.Now()
|
||||
data, err := fn()
|
||||
if err != nil {
|
||||
return data, err
|
||||
}
|
||||
elapsed := time.Since(start).Milliseconds()
|
||||
go func() {
|
||||
remoteCartLatency.Add(float64(elapsed))
|
||||
remoteCartCallsTotal.Inc()
|
||||
}()
|
||||
return data, nil
|
||||
}
|
||||
// func MeasureLatency(fn func() (*CallResult, error)) (*CallResult, error) {
|
||||
// start = time.Now()
|
||||
// data, err := fn()
|
||||
// if err != nil {
|
||||
// return data, err
|
||||
// }
|
||||
// elapsed := time.Since(start).Milliseconds()
|
||||
// go func() {
|
||||
// remoteCartLatency.Add(float64(elapsed))
|
||||
// remoteCartCallsTotal.Inc()
|
||||
// }()
|
||||
// return data, nil
|
||||
// }
|
||||
|
||||
func (g *RemoteGrain) HandleMessage(message *Message, isReplay bool) (*CallResult, error) {
|
||||
func (g *RemoteGrain) HandleMessage(message *Message, isReplay bool) (*FrameWithPayload, error) {
|
||||
|
||||
data, err := GetData(message.Write)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reply, err := MeasureLatency(func() (*CallResult, error) {
|
||||
return g.Call(RemoteHandleMutation, g.Id, RemoteHandleMutationReply, data)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return reply, err
|
||||
}
|
||||
|
||||
func (g *RemoteGrain) Close() {
|
||||
g.CartClient.PersistentConnection.Close()
|
||||
return g.Call(RemoteHandleMutation, MakeCartInnerFrame(g.Id, data))
|
||||
}
|
||||
|
||||
func (g *RemoteGrain) GetId() CartId {
|
||||
return g.Id
|
||||
}
|
||||
|
||||
func (g *RemoteGrain) GetCurrentState() (*CallResult, error) {
|
||||
return MeasureLatency(func() (*CallResult, error) { return g.Call(RemoteGetState, g.Id, RemoteGetStateReply, []byte{}) })
|
||||
func (g *RemoteGrain) GetCurrentState() (*FrameWithPayload, error) {
|
||||
return g.Call(RemoteGetState, MakeCartInnerFrame(g.Id, nil))
|
||||
}
|
||||
|
||||
@@ -7,13 +7,13 @@ import (
|
||||
)
|
||||
|
||||
type RemoteHost struct {
|
||||
*Client
|
||||
*Connection
|
||||
Host string
|
||||
MissedPings int
|
||||
}
|
||||
|
||||
func (h *RemoteHost) IsHealthy() bool {
|
||||
return !h.PersistentConnection.Dead && h.MissedPings < 3
|
||||
return h.MissedPings < 3
|
||||
}
|
||||
|
||||
func (h *RemoteHost) Initialize(p *SyncedPool) {
|
||||
@@ -38,15 +38,11 @@ func (h *RemoteHost) Initialize(p *SyncedPool) {
|
||||
}
|
||||
|
||||
func (h *RemoteHost) Ping() error {
|
||||
_, err := h.Call(Ping, Pong, []byte{})
|
||||
result, err := h.Call(Ping, nil)
|
||||
|
||||
if err != nil {
|
||||
if err != nil || result.StatusCode != 200 || result.Type != Pong {
|
||||
h.MissedPings++
|
||||
log.Printf("Error pinging remote %s, missed pings: %d", h.Host, h.MissedPings)
|
||||
if !h.IsHealthy() {
|
||||
h.Close()
|
||||
return fmt.Errorf("remote %s is dead", h.Host)
|
||||
}
|
||||
} else {
|
||||
h.MissedPings = 0
|
||||
}
|
||||
@@ -54,28 +50,28 @@ func (h *RemoteHost) Ping() error {
|
||||
}
|
||||
|
||||
func (h *RemoteHost) Negotiate(knownHosts []string) ([]string, error) {
|
||||
reply, err := h.Call(RemoteNegotiate, RemoteNegotiateResponse, []byte(strings.Join(knownHosts, ";")))
|
||||
reply, err := h.Call(RemoteNegotiate, []byte(strings.Join(knownHosts, ";")))
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if reply.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("remote returned error on negotiate: %s", string(reply.Data))
|
||||
return nil, fmt.Errorf("remote returned error on negotiate: %s", string(reply.Payload))
|
||||
}
|
||||
|
||||
return strings.Split(string(reply.Data), ";"), nil
|
||||
return strings.Split(string(reply.Payload), ";"), nil
|
||||
}
|
||||
|
||||
func (g *RemoteHost) GetCartMappings() ([]CartId, error) {
|
||||
reply, err := g.Call(GetCartIds, CartIdsResponse, []byte{})
|
||||
reply, err := g.Call(GetCartIds, []byte{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if reply.StatusCode != 200 {
|
||||
log.Printf("Remote returned error on get cart mappings: %s", string(reply.Data))
|
||||
return nil, fmt.Errorf("remote returned error: %s", string(reply.Data))
|
||||
if reply.StatusCode != 200 || reply.Type != CartIdsResponse {
|
||||
log.Printf("Remote returned error on get cart mappings: %s", string(reply.Payload))
|
||||
return nil, fmt.Errorf("remote returned incorrect data")
|
||||
}
|
||||
parts := strings.Split(string(reply.Data), ";")
|
||||
parts := strings.Split(string(reply.Payload), ";")
|
||||
ids := make([]CartId, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
ids = append(ids, ToCartId(p))
|
||||
@@ -84,14 +80,11 @@ func (g *RemoteHost) GetCartMappings() ([]CartId, error) {
|
||||
}
|
||||
|
||||
func (r *RemoteHost) ConfirmChange(id CartId, host string) error {
|
||||
reply, err := r.Call(RemoteGrainChanged, AckChange, []byte(fmt.Sprintf("%s;%s", id, host)))
|
||||
reply, err := r.Call(RemoteGrainChanged, []byte(fmt.Sprintf("%s;%s", id, host)))
|
||||
|
||||
if err != nil {
|
||||
if err != nil || reply.StatusCode != 200 || reply.Type != AckChange {
|
||||
return err
|
||||
}
|
||||
if string(reply.Data) != "ok" {
|
||||
return fmt.Errorf("remote grain change failed %s", string(reply.Data))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
)
|
||||
|
||||
type GrainHandler struct {
|
||||
*CartServer
|
||||
*GenericListener
|
||||
pool *GrainLocalPool
|
||||
}
|
||||
|
||||
@@ -20,13 +20,14 @@ func (h *GrainHandler) GetState(id CartId, reply *Grain) error {
|
||||
}
|
||||
|
||||
func NewGrainHandler(pool *GrainLocalPool, listen string) (*GrainHandler, error) {
|
||||
server, err := CartListen(listen)
|
||||
conn := NewConnection(listen)
|
||||
server, err := conn.Listen()
|
||||
handler := &GrainHandler{
|
||||
CartServer: server,
|
||||
GenericListener: server,
|
||||
pool: pool,
|
||||
}
|
||||
server.HandleCall(RemoteHandleMutation, handler.RemoteHandleMessageHandler)
|
||||
server.HandleCall(RemoteGetState, handler.RemoteGetStateHandler)
|
||||
server.AddHandler(RemoteHandleMutation, handler.RemoteHandleMessageHandler)
|
||||
server.AddHandler(RemoteGetState, handler.RemoteGetStateHandler)
|
||||
return handler, err
|
||||
}
|
||||
|
||||
@@ -34,29 +35,36 @@ func (h *GrainHandler) IsHealthy() bool {
|
||||
return len(h.pool.grains) < h.pool.PoolSize
|
||||
}
|
||||
|
||||
func (h *GrainHandler) RemoteHandleMessageHandler(id CartId, data []byte) (CartMessage, []byte, error) {
|
||||
func (h *GrainHandler) RemoteHandleMessageHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
|
||||
cartData, err := GetCartFrame(data.Payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var msg Message
|
||||
err := ReadMessage(bytes.NewReader(data), &msg)
|
||||
err = ReadMessage(bytes.NewReader(cartData.Data), &msg)
|
||||
if err != nil {
|
||||
fmt.Println("Error reading message:", err)
|
||||
return RemoteHandleMutationReply, nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
replyData, err := h.pool.Process(id, msg)
|
||||
replyData, err := h.pool.Process(cartData.Id, msg)
|
||||
|
||||
if err != nil {
|
||||
fmt.Println("Error handling message:", err)
|
||||
}
|
||||
if err != nil {
|
||||
return RemoteHandleMutationReply, nil, err
|
||||
}
|
||||
return RemoteHandleMutationReply, replyData, nil
|
||||
resultChan <- *replyData
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *GrainHandler) RemoteGetStateHandler(id CartId, data []byte) (CartMessage, []byte, error) {
|
||||
reply, err := h.pool.Get(id)
|
||||
func (h *GrainHandler) RemoteGetStateHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
|
||||
cartData, err := GetCartFrame(data.Payload)
|
||||
if err != nil {
|
||||
return RemoteGetStateReply, nil, err
|
||||
return err
|
||||
}
|
||||
return RemoteGetStateReply, reply, nil
|
||||
reply, err := h.pool.Get(cartData.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resultChan <- *reply
|
||||
return nil
|
||||
}
|
||||
|
||||
118
synced-pool.go
118
synced-pool.go
@@ -22,7 +22,7 @@ type HealthHandler interface {
|
||||
}
|
||||
|
||||
type SyncedPool struct {
|
||||
*Server
|
||||
Server *GenericListener
|
||||
mu sync.RWMutex
|
||||
Hostname string
|
||||
local *GrainLocalPool
|
||||
@@ -61,11 +61,16 @@ var (
|
||||
})
|
||||
)
|
||||
|
||||
func (p *SyncedPool) PongHandler(data []byte) (PoolMessage, []byte, error) {
|
||||
return Pong, data, nil
|
||||
var (
|
||||
PongResponse = MakeFrameWithPayload(Pong, 200, nil)
|
||||
)
|
||||
|
||||
func (p *SyncedPool) PongHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
|
||||
resultChan <- PongResponse
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SyncedPool) GetCartIdHandler(data []byte) (PoolMessage, []byte, error) {
|
||||
func (p *SyncedPool) GetCartIdHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
|
||||
ids := make([]string, 0, len(p.local.grains))
|
||||
for id := range p.local.grains {
|
||||
if p.local.grains[id] == nil {
|
||||
@@ -78,45 +83,45 @@ func (p *SyncedPool) GetCartIdHandler(data []byte) (PoolMessage, []byte, error)
|
||||
ids = append(ids, s)
|
||||
}
|
||||
log.Printf("Returning %d cart ids\n", len(ids))
|
||||
return CartIdsResponse, []byte(strings.Join(ids, ";")), nil
|
||||
resultChan <- MakeFrameWithPayload(CartIdsResponse, 200, []byte(strings.Join(ids, ";")))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SyncedPool) NegotiateHandler(data []byte) (PoolMessage, []byte, error) {
|
||||
func (p *SyncedPool) NegotiateHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
|
||||
negotiationCount.Inc()
|
||||
log.Printf("Handling negotiation\n")
|
||||
for _, host := range p.ExcludeKnown(strings.Split(string(data), ";")) {
|
||||
for _, host := range p.ExcludeKnown(strings.Split(string(data.Payload), ";")) {
|
||||
if host == "" {
|
||||
continue
|
||||
}
|
||||
go p.AddRemote(host)
|
||||
|
||||
}
|
||||
|
||||
return RemoteNegotiateResponse, []byte("ok"), nil
|
||||
resultChan <- MakeFrameWithPayload(RemoteNegotiateResponse, 200, []byte("ok"))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SyncedPool) GrainOwnerChangeHandler(data []byte) (PoolMessage, []byte, error) {
|
||||
func (p *SyncedPool) GrainOwnerChangeHandler(data *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
|
||||
grainSyncCount.Inc()
|
||||
|
||||
idAndHostParts := strings.Split(string(data), ";")
|
||||
idAndHostParts := strings.Split(string(data.Payload), ";")
|
||||
if len(idAndHostParts) != 2 {
|
||||
log.Printf("Invalid remote grain change message\n")
|
||||
return AckChange, []byte("incorrect"), fmt.Errorf("invalid remote grain change message")
|
||||
resultChan <- MakeFrameWithPayload(AckError, 400, []byte("invalid"))
|
||||
return nil
|
||||
}
|
||||
id := ToCartId(idAndHostParts[0])
|
||||
host := idAndHostParts[1]
|
||||
log.Printf("Handling remote grain owner change to %s for id %s\n", host, id)
|
||||
for _, r := range p.remotes {
|
||||
if r.Host == host && r.IsHealthy() {
|
||||
// log.Printf("Remote grain %s changed to %s\n", id, host)
|
||||
|
||||
go p.SpawnRemoteGrain(id, host)
|
||||
|
||||
return AckChange, []byte("ok"), nil
|
||||
break
|
||||
}
|
||||
}
|
||||
go p.AddRemote(host)
|
||||
return AckChange, []byte("ok"), nil
|
||||
resultChan <- MakeFrameWithPayload(AckChange, 200, []byte("ok"))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SyncedPool) RemoveRemoteGrain(id CartId) {
|
||||
@@ -142,12 +147,12 @@ func (p *SyncedPool) SpawnRemoteGrain(id CartId, host string) {
|
||||
log.Printf("Error creating remote grain %v\n", err)
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
<-remote.PersistentConnection.Died
|
||||
p.RemoveRemoteGrain(id)
|
||||
p.HandleHostError(host)
|
||||
log.Printf("Remote grain %s died, host: %s\n", id.String(), host)
|
||||
}()
|
||||
// go func() {
|
||||
// <-remote.Died
|
||||
// p.RemoveRemoteGrain(id)
|
||||
// p.HandleHostError(host)
|
||||
// log.Printf("Remote grain %s died, host: %s\n", id.String(), host)
|
||||
// }()
|
||||
|
||||
p.mu.Lock()
|
||||
p.remoteIndex[id] = remote
|
||||
@@ -159,8 +164,6 @@ func (p *SyncedPool) HandleHostError(host string) {
|
||||
if r.Host == host {
|
||||
if !r.IsHealthy() {
|
||||
p.RemoveHost(r)
|
||||
} else {
|
||||
r.ErrorCount++
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -169,8 +172,8 @@ func (p *SyncedPool) HandleHostError(host string) {
|
||||
|
||||
func NewSyncedPool(local *GrainLocalPool, hostname string, discovery Discovery) (*SyncedPool, error) {
|
||||
listen := fmt.Sprintf("%s:1338", hostname)
|
||||
|
||||
server, err := Listen(listen)
|
||||
conn := NewConnection(listen)
|
||||
server, err := conn.Listen()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -186,10 +189,10 @@ func NewSyncedPool(local *GrainLocalPool, hostname string, discovery Discovery)
|
||||
remoteIndex: make(map[CartId]*RemoteGrain),
|
||||
}
|
||||
|
||||
server.HandleCall(Ping, pool.PongHandler)
|
||||
server.HandleCall(GetCartIds, pool.GetCartIdHandler)
|
||||
server.HandleCall(RemoteNegotiate, pool.NegotiateHandler)
|
||||
server.HandleCall(RemoteGrainChanged, pool.GrainOwnerChangeHandler)
|
||||
server.AddHandler(Ping, pool.PongHandler)
|
||||
server.AddHandler(GetCartIds, pool.GetCartIdHandler)
|
||||
server.AddHandler(RemoteNegotiate, pool.NegotiateHandler)
|
||||
server.AddHandler(RemoteGrainChanged, pool.GrainOwnerChangeHandler)
|
||||
|
||||
if discovery != nil {
|
||||
go func() {
|
||||
@@ -259,18 +262,11 @@ func (p *SyncedPool) ExcludeKnown(hosts []string) []string {
|
||||
}
|
||||
|
||||
func (p *SyncedPool) RemoveHost(host *RemoteHost) {
|
||||
if p.remotes[host.Host] == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
h := p.remotes[host.Host]
|
||||
if h != nil {
|
||||
h.Close()
|
||||
}
|
||||
delete(p.remotes, host.Host)
|
||||
|
||||
p.mu.Unlock()
|
||||
p.RemoveHostMappedCarts(host)
|
||||
connectedRemotes.Set(float64(len(p.remotes)))
|
||||
}
|
||||
|
||||
@@ -279,24 +275,21 @@ func (p *SyncedPool) RemoveHostMappedCarts(host *RemoteHost) {
|
||||
defer p.mu.Unlock()
|
||||
for id, r := range p.remoteIndex {
|
||||
if r.Host == host.Host {
|
||||
p.remoteIndex[id].Close()
|
||||
delete(p.remoteIndex, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type PoolMessage uint32
|
||||
|
||||
const (
|
||||
RemoteNegotiate = PoolMessage(3)
|
||||
RemoteGrainChanged = PoolMessage(4)
|
||||
AckChange = PoolMessage(5)
|
||||
//AckError = PoolMessage(6)
|
||||
Ping = PoolMessage(7)
|
||||
Pong = PoolMessage(8)
|
||||
GetCartIds = PoolMessage(9)
|
||||
CartIdsResponse = PoolMessage(10)
|
||||
RemoteNegotiateResponse = PoolMessage(11)
|
||||
RemoteNegotiate = FrameType(3)
|
||||
RemoteGrainChanged = FrameType(4)
|
||||
AckChange = FrameType(5)
|
||||
AckError = FrameType(6)
|
||||
Ping = FrameType(7)
|
||||
Pong = FrameType(8)
|
||||
GetCartIds = FrameType(9)
|
||||
CartIdsResponse = FrameType(10)
|
||||
RemoteNegotiateResponse = FrameType(11)
|
||||
)
|
||||
|
||||
func (p *SyncedPool) Negotiate() {
|
||||
@@ -377,25 +370,22 @@ func (p *SyncedPool) AddRemote(host string) error {
|
||||
if host == "" || p.IsKnown(host) || hasHost {
|
||||
return nil
|
||||
}
|
||||
client, err := Dial(fmt.Sprintf("%s:1338", host))
|
||||
if err != nil {
|
||||
client := NewConnection(fmt.Sprintf("%s:1338", host))
|
||||
response, err := client.Call(Ping, nil)
|
||||
if err != nil || response.StatusCode != 200 || response.Type != Pong {
|
||||
log.Printf("Error connecting to remote %s: %v\n", host, err)
|
||||
return err
|
||||
}
|
||||
|
||||
remote := RemoteHost{
|
||||
Client: client,
|
||||
Connection: client,
|
||||
MissedPings: 0,
|
||||
Host: host,
|
||||
}
|
||||
p.mu.Lock()
|
||||
p.remotes[host] = &remote
|
||||
p.mu.Unlock()
|
||||
go func() {
|
||||
<-remote.PersistentConnection.Died
|
||||
log.Printf("Removing host, remote died %s", host)
|
||||
p.RemoveHost(&remote)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for range time.Tick(time.Second * 3) {
|
||||
|
||||
@@ -450,9 +440,9 @@ func (p *SyncedPool) getGrain(id CartId) (Grain, error) {
|
||||
return localGrain, nil
|
||||
}
|
||||
|
||||
func (p *SyncedPool) Process(id CartId, messages ...Message) (*CallResult, error) {
|
||||
func (p *SyncedPool) Process(id CartId, messages ...Message) (*FrameWithPayload, error) {
|
||||
pool, err := p.getGrain(id)
|
||||
var res *CallResult
|
||||
var res *FrameWithPayload
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -465,7 +455,7 @@ func (p *SyncedPool) Process(id CartId, messages ...Message) (*CallResult, error
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (p *SyncedPool) Get(id CartId) (*CallResult, error) {
|
||||
func (p *SyncedPool) Get(id CartId) (*FrameWithPayload, error) {
|
||||
grain, err := p.getGrain(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"log"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type CartClient struct {
|
||||
*CartTCPClient
|
||||
}
|
||||
|
||||
func CartDial(address string) (*CartClient, error) {
|
||||
|
||||
mux, err := NewCartTCPClient(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client := &CartClient{
|
||||
CartTCPClient: mux,
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *Client) Close() {
|
||||
log.Printf("Closing connection to %s\n", c.PersistentConnection.address)
|
||||
c.PersistentConnection.Close()
|
||||
}
|
||||
|
||||
type CartTCPClient struct {
|
||||
PersistentConnection *PersistentConnection
|
||||
sendMux sync.Mutex
|
||||
ErrorCount int
|
||||
address string
|
||||
*CartPacketQueue
|
||||
}
|
||||
|
||||
func NewCartTCPClient(address string) (*CartTCPClient, error) {
|
||||
connection, err := NewPersistentConnection(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &CartTCPClient{
|
||||
ErrorCount: 0,
|
||||
PersistentConnection: connection,
|
||||
address: address,
|
||||
CartPacketQueue: NewCartPacketQueue(connection),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *CartTCPClient) SendPacket(messageType CartMessage, id CartId, data []byte) error {
|
||||
m.sendMux.Lock()
|
||||
defer m.sendMux.Unlock()
|
||||
m.PersistentConnection.Conn.Write(header[:])
|
||||
err := binary.Write(m.PersistentConnection, binary.LittleEndian, CartPacket{
|
||||
Version: CurrentPacketVersion,
|
||||
MessageType: messageType,
|
||||
DataLength: uint32(len(data)),
|
||||
Id: id,
|
||||
})
|
||||
if err != nil {
|
||||
return m.PersistentConnection.HandleConnectionError(err)
|
||||
}
|
||||
_, err = m.PersistentConnection.Write(data)
|
||||
return m.PersistentConnection.HandleConnectionError(err)
|
||||
}
|
||||
|
||||
func (m *CartTCPClient) call(messageType CartMessage, id CartId, responseType CartMessage, data []byte) (*CallResult, error) {
|
||||
packetChan := m.Expect(responseType, id)
|
||||
err := m.SendPacket(messageType, id, data)
|
||||
if err != nil {
|
||||
return nil, m.PersistentConnection.HandleConnectionError(err)
|
||||
}
|
||||
|
||||
ret := <-packetChan
|
||||
return &ret, nil
|
||||
}
|
||||
|
||||
func isRetirableError(err error) bool {
|
||||
log.Printf("is retryable error: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *CartTCPClient) Call(messageType CartMessage, id CartId, responseType CartMessage, data []byte) (*CallResult, error) {
|
||||
retries := 0
|
||||
result, err := m.call(messageType, id, responseType, data)
|
||||
for err != nil && retries < 3 {
|
||||
if !isRetirableError(err) {
|
||||
break
|
||||
}
|
||||
retries++
|
||||
log.Printf("Retrying call to %d\n", messageType)
|
||||
result, err = m.call(messageType, id, responseType, data)
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
@@ -1,160 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type CartServer struct {
|
||||
*TCPCartServerMux
|
||||
}
|
||||
|
||||
func CartListen(address string) (*CartServer, error) {
|
||||
listener, err := net.Listen("tcp", address)
|
||||
server := &CartServer{
|
||||
NewCartTCPServerMux(),
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Printf("Error accepting connection: %v\n", err)
|
||||
continue
|
||||
}
|
||||
go server.HandleConnection(conn)
|
||||
}
|
||||
}()
|
||||
return server, nil
|
||||
}
|
||||
|
||||
type TCPCartServerMux struct {
|
||||
mu sync.RWMutex
|
||||
sendMux sync.Mutex
|
||||
listeners map[CartMessage]func(CartId, []byte) error
|
||||
functions map[CartMessage]func(CartId, []byte) (CartMessage, []byte, error)
|
||||
}
|
||||
|
||||
func NewCartTCPServerMux() *TCPCartServerMux {
|
||||
m := &TCPCartServerMux{
|
||||
mu: sync.RWMutex{},
|
||||
listeners: make(map[CartMessage]func(CartId, []byte) error),
|
||||
functions: make(map[CartMessage]func(CartId, []byte) (CartMessage, []byte, error)),
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *TCPCartServerMux) handleListener(messageType CartMessage, id CartId, data []byte) (bool, error) {
|
||||
m.mu.RLock()
|
||||
handler, ok := m.listeners[messageType]
|
||||
m.mu.RUnlock()
|
||||
if ok {
|
||||
err := handler(id, data)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *TCPCartServerMux) handleFunction(connection net.Conn, messageType CartMessage, id CartId, data []byte) (bool, error) {
|
||||
m.mu.RLock()
|
||||
fn, ok := m.functions[messageType]
|
||||
m.mu.RUnlock()
|
||||
m.sendMux.Lock()
|
||||
defer m.sendMux.Unlock()
|
||||
if ok {
|
||||
responseType, responseData, err := fn(id, data)
|
||||
connection.Write(header[:])
|
||||
if err != nil {
|
||||
errData := []byte(err.Error())
|
||||
err = binary.Write(connection, binary.LittleEndian, CartPacket{
|
||||
Version: CurrentPacketVersion,
|
||||
MessageType: responseType,
|
||||
DataLength: uint32(len(errData)),
|
||||
StatusCode: 500,
|
||||
Id: id,
|
||||
})
|
||||
_, err = connection.Write(errData)
|
||||
return true, err
|
||||
}
|
||||
err = binary.Write(connection, binary.LittleEndian, CartPacket{
|
||||
Version: CurrentPacketVersion,
|
||||
MessageType: responseType,
|
||||
DataLength: uint32(len(responseData)),
|
||||
StatusCode: 200,
|
||||
Id: id,
|
||||
})
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
packetsSent.Inc()
|
||||
_, err = connection.Write(responseData)
|
||||
return true, err
|
||||
} else {
|
||||
log.Printf("No cart handler for type: %d\n", messageType)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *TCPCartServerMux) HandleConnection(connection net.Conn) error {
|
||||
var packet CartPacket
|
||||
var err error
|
||||
defer connection.Close()
|
||||
reader := bufio.NewReader(connection)
|
||||
for {
|
||||
err = ReadCartPacket(reader, &packet)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
log.Printf("Error receiving packet: %v\n", err)
|
||||
return err
|
||||
}
|
||||
if packet.Version != CurrentPacketVersion {
|
||||
log.Printf("Incorrect packet version: %d\n", packet.Version)
|
||||
continue
|
||||
}
|
||||
data, err := GetPacketData(reader, packet.DataLength)
|
||||
if err != nil {
|
||||
log.Printf("Error getting packet data: %v\n", err)
|
||||
}
|
||||
go m.HandleData(connection, packet.MessageType, packet.Id, data)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *TCPCartServerMux) HandleData(connection net.Conn, t CartMessage, id CartId, data []byte) {
|
||||
status, err := m.handleListener(t, id, data)
|
||||
if err != nil {
|
||||
log.Printf("Error handling listener: %v\n", err)
|
||||
}
|
||||
if !status {
|
||||
status, err = m.handleFunction(connection, t, id, data)
|
||||
if err != nil {
|
||||
log.Printf("Error handling function: %v\n", err)
|
||||
}
|
||||
if !status {
|
||||
log.Printf("Unknown message type: %d\n", t)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *TCPCartServerMux) ListenFor(messageType CartMessage, handler func(CartId, []byte) error) {
|
||||
m.mu.Lock()
|
||||
m.listeners[messageType] = handler
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *TCPCartServerMux) HandleCall(messageType CartMessage, handler func(CartId, []byte) (CartMessage, []byte, error)) {
|
||||
m.mu.Lock()
|
||||
m.functions[messageType] = handler
|
||||
m.mu.Unlock()
|
||||
}
|
||||
@@ -1,57 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCartTcpHelpers(t *testing.T) {
|
||||
|
||||
server, err := CartListen("localhost:51337")
|
||||
if err != nil {
|
||||
t.Errorf("Error listening: %v\n", err)
|
||||
}
|
||||
client, err := CartDial("localhost:51337")
|
||||
if err != nil {
|
||||
t.Errorf("Error dialing: %v\n", err)
|
||||
}
|
||||
var messageData string
|
||||
server.ListenFor(1, func(id CartId, data []byte) error {
|
||||
log.Printf("Received message: %s\n", string(data))
|
||||
messageData = string(data)
|
||||
return nil
|
||||
})
|
||||
server.HandleCall(666, func(id CartId, data []byte) (CartMessage, []byte, error) {
|
||||
log.Printf("Received 666 call: %s\n", string(data))
|
||||
return 3, []byte("Hello, client!"), fmt.Errorf("Det blev fel")
|
||||
})
|
||||
server.HandleCall(2, func(id CartId, data []byte) (CartMessage, []byte, error) {
|
||||
log.Printf("Received 2 call: %s\n", string(data))
|
||||
return 4, []byte("Hello, client!"), nil
|
||||
})
|
||||
// server.HandleCall(Ping, func(id CartId, data []byte) (CartMessage, []byte, error) {
|
||||
// return Pong, nil, nil
|
||||
// })
|
||||
id := ToCartId("kalle")
|
||||
client.SendPacket(1, id, []byte("Hello, world!"))
|
||||
answer, err := client.Call(2, id, 4, []byte("Hello, server!"))
|
||||
if err != nil {
|
||||
t.Errorf("Error calling: %v\n", err)
|
||||
}
|
||||
s, err := client.Call(666, id, 3, []byte("Hello, server!"))
|
||||
client.PersistentConnection.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Error calling: %v\n", err)
|
||||
}
|
||||
if s.StatusCode != 500 {
|
||||
t.Errorf("Expected 500, got %d\n", s.StatusCode)
|
||||
}
|
||||
|
||||
if string(answer.Data) != "Hello, client!" {
|
||||
t.Errorf("Expected answer 'Hello, client!', got %s\n", string(answer.Data))
|
||||
}
|
||||
if messageData != "Hello, world!" {
|
||||
t.Errorf("Expected message 'Hello, world!', got %s\n", messageData)
|
||||
}
|
||||
}
|
||||
144
tcp-client.go
144
tcp-client.go
@@ -1,144 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
*TCPClient
|
||||
}
|
||||
|
||||
func Dial(address string) (*Client, error) {
|
||||
|
||||
mux, err := NewTCPClient(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client := &Client{
|
||||
TCPClient: mux,
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
type TCPClient struct {
|
||||
PersistentConnection *PersistentConnection
|
||||
sendMux sync.Mutex
|
||||
ErrorCount int
|
||||
address string
|
||||
*PacketQueue
|
||||
}
|
||||
|
||||
type PersistentConnection struct {
|
||||
net.Conn
|
||||
Died chan bool
|
||||
Dead bool
|
||||
address string
|
||||
}
|
||||
|
||||
func NewPersistentConnection(address string) (*PersistentConnection, error) {
|
||||
|
||||
p := &PersistentConnection{
|
||||
Died: make(chan bool, 1),
|
||||
Dead: false,
|
||||
address: address,
|
||||
}
|
||||
err := p.Connect()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (m *PersistentConnection) Connect() error {
|
||||
fails := 0
|
||||
for {
|
||||
connection, err := net.Dial("tcp", m.address)
|
||||
if err != nil {
|
||||
log.Printf("Can't connect to %s: %v, count: %d", m.address, err, fails)
|
||||
fails++
|
||||
if fails > 15 {
|
||||
log.Printf("Too many connection failures, closing connection to %s", m.address)
|
||||
m.Died <- true
|
||||
m.Dead = true
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
m.Conn = connection
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Millisecond * 300)
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *PersistentConnection) Close() {
|
||||
log.Printf("Closing connection to %s\n", m.address)
|
||||
m.Conn.Close()
|
||||
m.Died <- true
|
||||
m.Dead = true
|
||||
}
|
||||
|
||||
func (m *PersistentConnection) HandleConnectionError(err error) error {
|
||||
if err != nil {
|
||||
log.Printf("Error from to %s: %v\n", m.address, err)
|
||||
m.Conn.Close()
|
||||
m.Connect()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func NewTCPClient(address string) (*TCPClient, error) {
|
||||
|
||||
connection, err := NewPersistentConnection(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TCPClient{
|
||||
ErrorCount: 0,
|
||||
PersistentConnection: connection,
|
||||
address: address,
|
||||
PacketQueue: NewPacketQueue(connection),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type PacketHeader [4]byte
|
||||
|
||||
var (
|
||||
header = PacketHeader([4]byte{0x01, 0x02, 0x03, 0x04})
|
||||
)
|
||||
|
||||
func (m *TCPClient) SendPacket(messageType PoolMessage, data []byte) error {
|
||||
m.sendMux.Lock()
|
||||
defer m.sendMux.Unlock()
|
||||
m.PersistentConnection.Write(header[:])
|
||||
err := binary.Write(m.PersistentConnection, binary.LittleEndian, Packet{
|
||||
Version: CurrentPacketVersion,
|
||||
MessageType: messageType,
|
||||
StatusCode: 0,
|
||||
DataLength: uint32(len(data)),
|
||||
})
|
||||
if err != nil {
|
||||
return m.PersistentConnection.HandleConnectionError(err)
|
||||
}
|
||||
_, err = m.PersistentConnection.Write(data)
|
||||
return m.PersistentConnection.HandleConnectionError(err)
|
||||
}
|
||||
|
||||
func (m *TCPClient) Call(messageType PoolMessage, responseType PoolMessage, data []byte) (*CallResult, error) {
|
||||
packetChan := m.Expect(responseType)
|
||||
err := m.SendPacket(messageType, data)
|
||||
if err != nil {
|
||||
m.RemoveListeners()
|
||||
return nil, m.PersistentConnection.HandleConnectionError(err)
|
||||
}
|
||||
|
||||
ret := <-packetChan
|
||||
return &ret, nil
|
||||
|
||||
}
|
||||
@@ -15,12 +15,23 @@ type Connection struct {
|
||||
}
|
||||
|
||||
type FrameType uint32
|
||||
type StatusCode uint32
|
||||
type CheckSum uint32
|
||||
|
||||
type Frame struct {
|
||||
Id uint64
|
||||
Type FrameType
|
||||
StatusCode uint32
|
||||
StatusCode StatusCode
|
||||
Length uint32
|
||||
Checksum CheckSum
|
||||
}
|
||||
|
||||
func (f *Frame) IsValid() bool {
|
||||
return f.Checksum == MakeChecksum(f.Type, f.StatusCode, f.Length)
|
||||
}
|
||||
|
||||
func MakeChecksum(msg FrameType, statusCode StatusCode, length uint32) CheckSum {
|
||||
sum := CheckSum((uint32(msg) + uint32(statusCode) + length) / 8)
|
||||
return sum
|
||||
}
|
||||
|
||||
type FrameWithPayload struct {
|
||||
@@ -28,6 +39,19 @@ type FrameWithPayload struct {
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
func MakeFrameWithPayload(msg FrameType, statusCode StatusCode, payload []byte) FrameWithPayload {
|
||||
len := uint32(len(payload))
|
||||
return FrameWithPayload{
|
||||
Frame: Frame{
|
||||
Type: msg,
|
||||
StatusCode: 0,
|
||||
Length: len,
|
||||
Checksum: MakeChecksum(msg, 0, len),
|
||||
},
|
||||
Payload: payload,
|
||||
}
|
||||
}
|
||||
|
||||
type FrameData interface {
|
||||
ToBytes() []byte
|
||||
FromBytes([]byte) error
|
||||
@@ -41,11 +65,7 @@ func NewConnection(address string) *Connection {
|
||||
}
|
||||
|
||||
func SendFrame(conn net.Conn, data *FrameWithPayload) error {
|
||||
_, err := conn.Write(header[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = binary.Write(conn, binary.LittleEndian, data.Frame)
|
||||
err := binary.Write(conn, binary.LittleEndian, data.Frame)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -53,68 +73,67 @@ func SendFrame(conn net.Conn, data *FrameWithPayload) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Connection) CallAsync(msg FrameType, data FrameData, ch chan<- *FrameWithPayload) error {
|
||||
func (c *Connection) CallAsync(msg FrameType, payload []byte, ch chan<- FrameWithPayload) (net.Conn, error) {
|
||||
conn, err := net.Dial("tcp", c.address)
|
||||
go WaitForFrame(conn, ch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
payload := data.ToBytes()
|
||||
toSend := &FrameWithPayload{
|
||||
Frame: Frame{
|
||||
Id: c.count,
|
||||
Type: msg,
|
||||
StatusCode: 0,
|
||||
Length: uint32(len(payload)),
|
||||
},
|
||||
Payload: payload,
|
||||
return conn, err
|
||||
}
|
||||
toSend := MakeFrameWithPayload(msg, 1, payload)
|
||||
|
||||
err = SendFrame(conn, toSend)
|
||||
err = SendFrame(conn, &toSend)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
close(ch)
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.count++
|
||||
return nil
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *Connection) Call(msg FrameType, data FrameData) (*FrameWithPayload, error) {
|
||||
ch := make(chan *FrameWithPayload, 1)
|
||||
c.CallAsync(msg, data, ch)
|
||||
func (c *Connection) Call(msg FrameType, data []byte) (*FrameWithPayload, error) {
|
||||
ch := make(chan FrameWithPayload, 1)
|
||||
conn, err := c.CallAsync(msg, data, ch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
select {
|
||||
case ret := <-ch:
|
||||
return ret, nil
|
||||
case <-time.After(5 * time.Second):
|
||||
return &ret, nil
|
||||
case <-time.After(MaxCallDuration):
|
||||
return nil, fmt.Errorf("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func WaitForFrame(conn net.Conn, resultChan chan<- *FrameWithPayload) error {
|
||||
defer conn.Close()
|
||||
func WaitForFrame(conn net.Conn, resultChan chan<- FrameWithPayload) error {
|
||||
var err error
|
||||
var frame Frame
|
||||
r := bufio.NewReader(conn)
|
||||
h := make([]byte, 4)
|
||||
r.Read(h)
|
||||
if h[0] == header[0] && h[1] == header[1] && h[2] == header[2] && h[3] == header[3] {
|
||||
frame := Frame{}
|
||||
|
||||
err = binary.Read(r, binary.LittleEndian, &frame)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if frame.IsValid() {
|
||||
payload := make([]byte, frame.Length)
|
||||
_, err = r.Read(payload)
|
||||
resultChan <- &FrameWithPayload{
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resultChan <- FrameWithPayload{
|
||||
Frame: frame,
|
||||
Payload: payload,
|
||||
}
|
||||
return err
|
||||
}
|
||||
resultChan <- nil
|
||||
return err
|
||||
return fmt.Errorf("checksum mismatch")
|
||||
}
|
||||
|
||||
type GenericListener struct {
|
||||
Closed bool
|
||||
handlers map[FrameType]func(*FrameWithPayload, chan<- *FrameWithPayload) error
|
||||
handlers map[FrameType]func(*FrameWithPayload, chan<- FrameWithPayload) error
|
||||
}
|
||||
|
||||
func (c *Connection) Listen() (*GenericListener, error) {
|
||||
@@ -123,7 +142,7 @@ func (c *Connection) Listen() (*GenericListener, error) {
|
||||
return nil, err
|
||||
}
|
||||
ret := &GenericListener{
|
||||
handlers: make(map[FrameType]func(*FrameWithPayload, chan<- *FrameWithPayload) error),
|
||||
handlers: make(map[FrameType]func(*FrameWithPayload, chan<- FrameWithPayload) error),
|
||||
}
|
||||
go func() {
|
||||
for !ret.Closed {
|
||||
@@ -137,36 +156,44 @@ func (c *Connection) Listen() (*GenericListener, error) {
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
const (
|
||||
MaxCallDuration = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
func (l *GenericListener) HandleConnection(conn net.Conn) {
|
||||
ch := make(chan *FrameWithPayload, 1)
|
||||
ch := make(chan FrameWithPayload, 1)
|
||||
go WaitForFrame(conn, ch)
|
||||
select {
|
||||
case frame := <-ch:
|
||||
go l.HandleFrame(conn, frame)
|
||||
case <-time.After(1 * time.Second):
|
||||
go l.HandleFrame(conn, &frame)
|
||||
case <-time.After(MaxCallDuration):
|
||||
close(ch)
|
||||
log.Printf("Timeout waiting for frame\n")
|
||||
}
|
||||
}
|
||||
|
||||
func (l *GenericListener) AddHandler(msg FrameType, handler func(*FrameWithPayload, chan<- *FrameWithPayload) error) {
|
||||
func (l *GenericListener) AddHandler(msg FrameType, handler func(*FrameWithPayload, chan<- FrameWithPayload) error) {
|
||||
l.handlers[msg] = handler
|
||||
}
|
||||
|
||||
func (l *GenericListener) HandleFrame(conn net.Conn, frame *FrameWithPayload) {
|
||||
handler, ok := l.handlers[frame.Type]
|
||||
defer conn.Close()
|
||||
if ok {
|
||||
go func() {
|
||||
resultChan := make(chan *FrameWithPayload, 1)
|
||||
resultChan := make(chan FrameWithPayload, 1)
|
||||
defer close(resultChan)
|
||||
err := handler(frame, resultChan)
|
||||
if err != nil {
|
||||
log.Fatalf("Error handling frame: %v\n", err)
|
||||
}
|
||||
SendFrame(conn, <-resultChan)
|
||||
result := <-resultChan
|
||||
err = SendFrame(conn, &result)
|
||||
if err != nil {
|
||||
log.Fatalf("Error sending frame: %v\n", err)
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
conn.Close()
|
||||
log.Fatalf("No handler for frame type %d\n", frame.Type)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,37 +2,19 @@ package main
|
||||
|
||||
import "testing"
|
||||
|
||||
type StringData string
|
||||
|
||||
func (s StringData) ToBytes() []byte {
|
||||
return []byte(s)
|
||||
}
|
||||
|
||||
func (s StringData) FromBytes(data []byte) error {
|
||||
s = StringData(data)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestGenericConnection(t *testing.T) {
|
||||
conn := NewConnection("localhost:51337")
|
||||
listener, err := conn.Listen()
|
||||
if err != nil {
|
||||
t.Errorf("Error listening: %v\n", err)
|
||||
}
|
||||
listener.AddHandler(1, func(input *FrameWithPayload, resultChan chan<- *FrameWithPayload) error {
|
||||
payload := []byte("Hello, world!")
|
||||
resultChan <- &FrameWithPayload{
|
||||
Frame: Frame{
|
||||
Type: 2,
|
||||
Id: input.Id,
|
||||
StatusCode: 200,
|
||||
Length: uint32(len("Hello, world!")),
|
||||
},
|
||||
Payload: payload,
|
||||
}
|
||||
datta := []byte("Hello, world!")
|
||||
listener.AddHandler(1, func(input *FrameWithPayload, resultChan chan<- FrameWithPayload) error {
|
||||
|
||||
resultChan <- MakeFrameWithPayload(2, 200, datta)
|
||||
return nil
|
||||
})
|
||||
r, err := conn.Call(1, StringData("Hello, world!"))
|
||||
r, err := conn.Call(1, datta)
|
||||
if err != nil {
|
||||
t.Errorf("Error calling: %v\n", err)
|
||||
}
|
||||
@@ -40,9 +22,9 @@ func TestGenericConnection(t *testing.T) {
|
||||
t.Errorf("Expected type 2, got %d\n", r.Type)
|
||||
}
|
||||
i := 100
|
||||
results := make(chan *FrameWithPayload, i)
|
||||
results := make(chan FrameWithPayload, i)
|
||||
for i > 0 {
|
||||
conn.CallAsync(1, StringData("Hello, world!"), results)
|
||||
go conn.CallAsync(1, datta, results)
|
||||
i--
|
||||
}
|
||||
for i < 100 {
|
||||
|
||||
@@ -1,161 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
*TCPServerMux
|
||||
}
|
||||
|
||||
func Listen(address string) (*Server, error) {
|
||||
listener, err := net.Listen("tcp", address)
|
||||
server := &Server{
|
||||
NewTCPServerMux(),
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Printf("Error accepting connection: %v\n", err)
|
||||
continue
|
||||
}
|
||||
go server.HandleConnection(conn)
|
||||
}
|
||||
}()
|
||||
return server, nil
|
||||
}
|
||||
|
||||
type TCPServerMux struct {
|
||||
mu sync.RWMutex
|
||||
sendMux sync.Mutex
|
||||
listeners map[PoolMessage]func(data []byte) error
|
||||
functions map[PoolMessage]func(data []byte) (PoolMessage, []byte, error)
|
||||
}
|
||||
|
||||
func NewTCPServerMux() *TCPServerMux {
|
||||
m := &TCPServerMux{
|
||||
mu: sync.RWMutex{},
|
||||
listeners: make(map[PoolMessage]func(data []byte) error),
|
||||
functions: make(map[PoolMessage]func(data []byte) (PoolMessage, []byte, error)),
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *TCPServerMux) handleListener(messageType PoolMessage, data []byte) (bool, error) {
|
||||
m.mu.RLock()
|
||||
handler, ok := m.listeners[messageType]
|
||||
m.mu.RUnlock()
|
||||
if ok {
|
||||
err := handler(data)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *TCPServerMux) handleFunction(connection net.Conn, messageType PoolMessage, data []byte) (bool, error) {
|
||||
m.mu.RLock()
|
||||
function, ok := m.functions[messageType]
|
||||
m.mu.RUnlock()
|
||||
m.sendMux.Lock()
|
||||
defer m.sendMux.Unlock()
|
||||
if ok {
|
||||
connection.Write(header[:])
|
||||
responseType, responseData, err := function(data)
|
||||
if err != nil {
|
||||
errData := []byte(err.Error())
|
||||
err = binary.Write(connection, binary.LittleEndian, Packet{
|
||||
Version: CurrentPacketVersion,
|
||||
MessageType: responseType,
|
||||
StatusCode: 500,
|
||||
DataLength: uint32(len(errData)),
|
||||
})
|
||||
_, err = connection.Write(errData)
|
||||
return true, err
|
||||
}
|
||||
err = binary.Write(connection, binary.LittleEndian, Packet{
|
||||
Version: CurrentPacketVersion,
|
||||
MessageType: responseType,
|
||||
StatusCode: 200,
|
||||
DataLength: uint32(len(responseData)),
|
||||
})
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
packetsSent.Inc()
|
||||
_, err = connection.Write(responseData)
|
||||
return true, err
|
||||
} else {
|
||||
log.Printf("No pool handler for type: %d\n", messageType)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *TCPServerMux) HandleConnection(connection net.Conn) error {
|
||||
|
||||
defer connection.Close()
|
||||
var packet Packet
|
||||
reader := bufio.NewReader(connection)
|
||||
for {
|
||||
err := ReadPacket(reader, &packet)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
log.Printf("Error receiving packet: %v\n", err)
|
||||
return err
|
||||
}
|
||||
if packet.Version != CurrentPacketVersion {
|
||||
log.Printf("Incorrect package version: %v\n", err)
|
||||
continue
|
||||
}
|
||||
data, err := GetPacketData(reader, packet.DataLength)
|
||||
if err != nil {
|
||||
log.Printf("Error receiving packet data: %v\n", err)
|
||||
return err
|
||||
}
|
||||
go m.HandleData(connection, packet.MessageType, data)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *TCPServerMux) HandleData(connection net.Conn, t PoolMessage, data []byte) {
|
||||
// listener := m.listeners[t]
|
||||
// handler := m.functions[t]
|
||||
status, err := m.handleListener(t, data)
|
||||
if err != nil {
|
||||
log.Printf("Error handling listener: %v\n", err)
|
||||
}
|
||||
if !status {
|
||||
status, err = m.handleFunction(connection, t, data)
|
||||
if err != nil {
|
||||
log.Printf("Error handling function: %v\n", err)
|
||||
}
|
||||
if !status {
|
||||
log.Printf("Unknown message type: %d\n", t)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *TCPServerMux) ListenFor(messageType PoolMessage, handler func(data []byte) error) {
|
||||
m.mu.Lock()
|
||||
m.listeners[messageType] = handler
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *TCPServerMux) HandleCall(messageType PoolMessage, handler func(data []byte) (PoolMessage, []byte, error)) {
|
||||
m.mu.Lock()
|
||||
m.functions[messageType] = handler
|
||||
m.mu.Unlock()
|
||||
}
|
||||
54
tcp_test.go
54
tcp_test.go
@@ -1,54 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTcpHelpers(t *testing.T) {
|
||||
|
||||
server, err := Listen("localhost:51337")
|
||||
if err != nil {
|
||||
t.Errorf("Error listening: %v\n", err)
|
||||
}
|
||||
client, err := Dial("localhost:51337")
|
||||
if err != nil {
|
||||
t.Errorf("Error dialing: %v\n", err)
|
||||
}
|
||||
var messageData string
|
||||
server.ListenFor(1, func(data []byte) error {
|
||||
log.Printf("Received message: %s\n", string(data))
|
||||
messageData = string(data)
|
||||
return nil
|
||||
})
|
||||
server.HandleCall(2, func(data []byte) (PoolMessage, []byte, error) {
|
||||
log.Printf("Received call: %s\n", string(data))
|
||||
return 3, []byte("Hello, client!"), nil
|
||||
})
|
||||
server.HandleCall(Ping, func(data []byte) (PoolMessage, []byte, error) {
|
||||
return Pong, nil, nil
|
||||
})
|
||||
|
||||
client.SendPacket(1, []byte("Hello, world!"))
|
||||
answer, err := client.Call(2, 3, []byte("Hello, server!"))
|
||||
if err != nil {
|
||||
t.Errorf("Error calling: %v\n", err)
|
||||
}
|
||||
for i := 0; i < 100; i++ {
|
||||
_, err = client.Call(Ping, Pong, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Error calling: %v\n", err)
|
||||
}
|
||||
}
|
||||
_, err = client.Call(Ping, Pong, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Error calling: %v\n", err)
|
||||
}
|
||||
client.Close()
|
||||
if string(answer.Data) != "Hello, client!" {
|
||||
t.Errorf("Expected answer 'Hello, client!', got %s\n", string(answer.Data))
|
||||
}
|
||||
if messageData != "Hello, world!" {
|
||||
t.Errorf("Expected message 'Hello, world!', got %s\n", messageData)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user