173 lines
3.5 KiB
Go
173 lines
3.5 KiB
Go
package main
|
|
|
|
import (
	"bufio"
	"encoding/binary"
	"fmt"
	"io"
	"log"
	"net"
	"sync"
	"time"
)
|
|
|
|
// Protocol types shared by the client (CallAsync/Call) and the server
// (GenericListener) sides.
type (
	// Connection is an endpoint bound to one TCP address: CallAsync dials it
	// and Listen listens on it. count is used as the Id of the next request
	// frame sent through this Connection.
	Connection struct {
		address string
		count   uint64
	}

	// FrameType tags a frame; a GenericListener uses it to pick the handler
	// for an incoming request.
	FrameType uint32

	// Frame is the fixed-size descriptor sent before every payload. It is
	// encoded with encoding/binary in little-endian order, so the order and
	// width of these fields are part of the wire format.
	Frame struct {
		Id         uint64
		Type       FrameType
		StatusCode uint32
		// Length is the number of payload bytes that follow the descriptor.
		Length uint32
	}

	// FrameWithPayload is a Frame together with the payload bytes it describes.
	FrameWithPayload struct {
		Frame
		Payload []byte
	}

	// FrameData is a value that can be carried as a frame payload. ToBytes
	// produces the payload sent by CallAsync; FromBytes presumably decodes a
	// received payload back into the value (TODO confirm with implementers).
	FrameData interface {
		ToBytes() []byte
		FromBytes([]byte) error
	}
)
|
|
|
|
func NewConnection(address string) *Connection {
|
|
return &Connection{
|
|
count: 0,
|
|
address: address,
|
|
}
|
|
}
|
|
|
|
func SendFrame(conn net.Conn, data *FrameWithPayload) error {
|
|
_, err := conn.Write(header[:])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = binary.Write(conn, binary.LittleEndian, data.Frame)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = conn.Write(data.Payload)
|
|
return err
|
|
}
|
|
|
|
func (c *Connection) CallAsync(msg FrameType, data FrameData, ch chan<- *FrameWithPayload) error {
|
|
conn, err := net.Dial("tcp", c.address)
|
|
go WaitForFrame(conn, ch)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
payload := data.ToBytes()
|
|
toSend := &FrameWithPayload{
|
|
Frame: Frame{
|
|
Id: c.count,
|
|
Type: msg,
|
|
StatusCode: 0,
|
|
Length: uint32(len(payload)),
|
|
},
|
|
Payload: payload,
|
|
}
|
|
|
|
err = SendFrame(conn, toSend)
|
|
if err != nil {
|
|
close(ch)
|
|
return err
|
|
}
|
|
|
|
c.count++
|
|
return nil
|
|
}
|
|
|
|
func (c *Connection) Call(msg FrameType, data FrameData) (*FrameWithPayload, error) {
|
|
ch := make(chan *FrameWithPayload, 1)
|
|
c.CallAsync(msg, data, ch)
|
|
select {
|
|
case ret := <-ch:
|
|
return ret, nil
|
|
case <-time.After(5 * time.Second):
|
|
return nil, fmt.Errorf("timeout")
|
|
}
|
|
}
|
|
|
|
func WaitForFrame(conn net.Conn, resultChan chan<- *FrameWithPayload) error {
|
|
defer conn.Close()
|
|
var err error
|
|
r := bufio.NewReader(conn)
|
|
h := make([]byte, 4)
|
|
r.Read(h)
|
|
if h[0] == header[0] && h[1] == header[1] && h[2] == header[2] && h[3] == header[3] {
|
|
frame := Frame{}
|
|
err = binary.Read(r, binary.LittleEndian, &frame)
|
|
payload := make([]byte, frame.Length)
|
|
_, err = r.Read(payload)
|
|
resultChan <- &FrameWithPayload{
|
|
Frame: frame,
|
|
Payload: payload,
|
|
}
|
|
return err
|
|
}
|
|
resultChan <- nil
|
|
return err
|
|
}
|
|
|
|
type GenericListener struct {
|
|
Closed bool
|
|
handlers map[FrameType]func(*FrameWithPayload, chan<- *FrameWithPayload) error
|
|
}
|
|
|
|
func (c *Connection) Listen() (*GenericListener, error) {
|
|
l, err := net.Listen("tcp", c.address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ret := &GenericListener{
|
|
handlers: make(map[FrameType]func(*FrameWithPayload, chan<- *FrameWithPayload) error),
|
|
}
|
|
go func() {
|
|
for !ret.Closed {
|
|
connection, err := l.Accept()
|
|
if err != nil {
|
|
log.Fatalf("Error accepting connection: %v\n", err)
|
|
}
|
|
go ret.HandleConnection(connection)
|
|
}
|
|
}()
|
|
return ret, nil
|
|
}
|
|
|
|
func (l *GenericListener) HandleConnection(conn net.Conn) {
|
|
ch := make(chan *FrameWithPayload, 1)
|
|
go WaitForFrame(conn, ch)
|
|
select {
|
|
case frame := <-ch:
|
|
go l.HandleFrame(conn, frame)
|
|
case <-time.After(1 * time.Second):
|
|
close(ch)
|
|
log.Printf("Timeout waiting for frame\n")
|
|
}
|
|
}
|
|
|
|
func (l *GenericListener) AddHandler(msg FrameType, handler func(*FrameWithPayload, chan<- *FrameWithPayload) error) {
|
|
l.handlers[msg] = handler
|
|
}
|
|
|
|
func (l *GenericListener) HandleFrame(conn net.Conn, frame *FrameWithPayload) {
|
|
handler, ok := l.handlers[frame.Type]
|
|
defer conn.Close()
|
|
if ok {
|
|
go func() {
|
|
resultChan := make(chan *FrameWithPayload, 1)
|
|
defer close(resultChan)
|
|
err := handler(frame, resultChan)
|
|
if err != nil {
|
|
log.Fatalf("Error handling frame: %v\n", err)
|
|
}
|
|
SendFrame(conn, <-resultChan)
|
|
}()
|
|
} else {
|
|
log.Fatalf("No handler for frame type %d\n", frame.Type)
|
|
}
|
|
}
|