213 lines
4.4 KiB
Go
213 lines
4.4 KiB
Go
package main
|
|
|
|
import (
	"bufio"
	"encoding/binary"
	"fmt"
	"io"
	"log"
	"net"
	"time"
)
|
|
|
|
// Connection names a TCP endpoint and keeps a tally of the calls started
// through it.
type Connection struct {
	// address is handed unchanged to net.Dial (CallAsync) and net.Listen
	// (Listen).
	address string
	// count is bumped once for every CallAsync whose dial succeeded.
	count uint64
}
|
|
|
|
type FrameType uint32
|
|
type StatusCode uint32
|
|
type CheckSum uint32
|
|
|
|
type Frame struct {
|
|
Type FrameType
|
|
StatusCode StatusCode
|
|
Length uint32
|
|
Checksum CheckSum
|
|
}
|
|
|
|
func (f *Frame) IsValid() bool {
|
|
return f.Checksum == MakeChecksum(f.Type, f.StatusCode, f.Length)
|
|
}
|
|
|
|
func MakeChecksum(msg FrameType, statusCode StatusCode, length uint32) CheckSum {
|
|
sum := CheckSum((uint32(msg) + uint32(statusCode) + length) / 8)
|
|
return sum
|
|
}
|
|
|
|
type FrameWithPayload struct {
|
|
Frame
|
|
Payload []byte
|
|
}
|
|
|
|
func MakeFrameWithPayload(msg FrameType, statusCode StatusCode, payload []byte) FrameWithPayload {
|
|
len := uint32(len(payload))
|
|
return FrameWithPayload{
|
|
Frame: Frame{
|
|
Type: msg,
|
|
StatusCode: statusCode,
|
|
Length: len,
|
|
Checksum: MakeChecksum(msg, statusCode, len),
|
|
},
|
|
Payload: payload,
|
|
}
|
|
}
|
|
|
|
// FrameData is implemented by values that can serialize themselves to a byte
// slice and restore themselves from one.
//
// NOTE(review): nothing in this file implements or consumes the interface;
// presumably it describes typed frame payloads. Confirm against callers.
type FrameData interface {
	// FromBytes loads the receiver from the given bytes and reports any
	// decoding failure.
	FromBytes([]byte) error
	// ToBytes returns the receiver's serialized form.
	ToBytes() []byte
}
|
|
|
|
func NewConnection(address string) *Connection {
|
|
return &Connection{
|
|
count: 0,
|
|
address: address,
|
|
}
|
|
}
|
|
|
|
func SendFrame(conn net.Conn, data *FrameWithPayload) error {
|
|
w := bufio.NewWriter(conn)
|
|
err := binary.Write(w, binary.LittleEndian, data.Frame)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = w.Write(data.Payload)
|
|
w.Flush()
|
|
return err
|
|
}
|
|
|
|
func (c *Connection) CallAsync(msg FrameType, payload []byte, ch chan<- FrameWithPayload) (net.Conn, error) {
|
|
conn, err := net.Dial("tcp", c.address)
|
|
if err != nil {
|
|
return conn, err
|
|
}
|
|
go WaitForFrame(conn, ch)
|
|
|
|
go func(toSend FrameWithPayload) {
|
|
err = SendFrame(conn, &toSend)
|
|
if err != nil {
|
|
log.Printf("Error sending frame: %v\n", err)
|
|
//close(ch)
|
|
//conn.Close()
|
|
}
|
|
}(MakeFrameWithPayload(msg, 1, payload))
|
|
|
|
c.count++
|
|
return conn, nil
|
|
}
|
|
|
|
func (c *Connection) Call(msg FrameType, data []byte) (*FrameWithPayload, error) {
|
|
ch := make(chan FrameWithPayload, 1)
|
|
conn, err := c.CallAsync(msg, data, ch)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
defer conn.Close()
|
|
defer close(ch)
|
|
|
|
select {
|
|
case ret := <-ch:
|
|
return &ret, nil
|
|
case <-time.After(MaxCallDuration):
|
|
return nil, fmt.Errorf("timeout")
|
|
}
|
|
}
|
|
|
|
func WaitForFrame(conn net.Conn, resultChan chan<- FrameWithPayload) error {
|
|
var err error
|
|
var frame Frame
|
|
r := bufio.NewReader(conn)
|
|
|
|
err = binary.Read(r, binary.LittleEndian, &frame)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if frame.IsValid() {
|
|
payload := make([]byte, frame.Length)
|
|
_, err = r.Read(payload)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
resultChan <- FrameWithPayload{
|
|
Frame: frame,
|
|
Payload: payload,
|
|
}
|
|
return err
|
|
}
|
|
log.Println("Checksum mismatch")
|
|
return fmt.Errorf("checksum mismatch")
|
|
}
|
|
|
|
type GenericListener struct {
|
|
StopListener bool
|
|
handlers map[FrameType]func(*FrameWithPayload, chan<- FrameWithPayload) error
|
|
}
|
|
|
|
func (c *Connection) Listen() (*GenericListener, error) {
|
|
l, err := net.Listen("tcp", c.address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ret := &GenericListener{
|
|
handlers: make(map[FrameType]func(*FrameWithPayload, chan<- FrameWithPayload) error),
|
|
}
|
|
go func() {
|
|
for !ret.StopListener {
|
|
connection, err := l.Accept()
|
|
if err != nil {
|
|
log.Fatalf("Error accepting connection: %v\n", err)
|
|
}
|
|
go ret.HandleConnection(connection)
|
|
}
|
|
}()
|
|
return ret, nil
|
|
}
|
|
|
|
// MaxCallDuration is the longest time (2.5 seconds) that Call waits for a
// reply and that HandleConnection waits for an incoming frame.
const MaxCallDuration = 2500 * time.Millisecond
|
|
|
|
func (l *GenericListener) HandleConnection(conn net.Conn) {
|
|
ch := make(chan FrameWithPayload, 1)
|
|
conn.SetReadDeadline(time.Now().Add(MaxCallDuration))
|
|
go WaitForFrame(conn, ch)
|
|
select {
|
|
case frame := <-ch:
|
|
err := l.HandleFrame(conn, &frame)
|
|
if err != nil {
|
|
log.Fatalf("Error handling frame: %v\n", err)
|
|
}
|
|
case <-time.After(MaxCallDuration):
|
|
close(ch)
|
|
log.Printf("Timeout waiting for frame\n")
|
|
}
|
|
}
|
|
|
|
func (l *GenericListener) AddHandler(msg FrameType, handler func(*FrameWithPayload, chan<- FrameWithPayload) error) {
|
|
l.handlers[msg] = handler
|
|
}
|
|
|
|
func (l *GenericListener) HandleFrame(conn net.Conn, frame *FrameWithPayload) error {
|
|
handler, ok := l.handlers[frame.Type]
|
|
if ok {
|
|
go func() {
|
|
resultChan := make(chan FrameWithPayload, 1)
|
|
defer close(resultChan)
|
|
err := handler(frame, resultChan)
|
|
if err != nil {
|
|
log.Fatalf("Error handling frame: %v\n", err)
|
|
}
|
|
result := <-resultChan
|
|
err = SendFrame(conn, &result)
|
|
if err != nil {
|
|
log.Fatalf("Error sending frame: %v\n", err)
|
|
}
|
|
}()
|
|
} else {
|
|
conn.Close()
|
|
return fmt.Errorf("no handler for frame type %d", frame.Type)
|
|
}
|
|
return nil
|
|
}
|