package main

import (
	"bufio"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"sync"
	"time"
)

// Connection is an endpoint of the framed request/response protocol in this
// file. The same address is dialed by CallAsync/Call and bound by Listen.
type Connection struct {
	address string
	count   uint64 // incremented per successful dial in CallAsync; not synchronized
}

// FrameType identifies the kind of request; the listener dispatches on it.
type FrameType uint32

// StatusCode is carried in every frame header.
type StatusCode uint32

// CheckSum is the header checksum produced by MakeChecksum.
type CheckSum uint32

// Frame is the fixed-size header sent before every payload. It is encoded with
// encoding/binary in little-endian order as four uint32 fields (16 bytes).
type Frame struct {
	Type       FrameType
	StatusCode StatusCode
	Length     uint32
	Checksum   CheckSum
}

// IsValid reports whether the stored checksum matches MakeChecksum computed
// over the header's other fields.
func (f *Frame) IsValid() bool {
	return f.Checksum == MakeChecksum(f.Type, f.StatusCode, f.Length)
}

// MakeChecksum derives the header checksum: the uint32 sum of the three
// fields (wrapping on overflow) divided by 8. This only catches gross header
// corruption; many different headers share a checksum. Sender and receiver
// both compute it, so any change must be made on both sides at once.
func MakeChecksum(msg FrameType, statusCode StatusCode, length uint32) CheckSum {
	return CheckSum((uint32(msg) + uint32(statusCode) + length) / 8)
}

// FrameWithPayload is a header together with the payload bytes that follow it
// on the wire.
type FrameWithPayload struct {
	Frame
	Payload []byte
}

// MakeFrameWithPayload builds a frame whose Length and Checksum are derived
// from payload. Length is uint32, so payloads of 4 GiB or more would be
// truncated in the header.
func MakeFrameWithPayload(msg FrameType, statusCode StatusCode, payload []byte) FrameWithPayload {
	n := uint32(len(payload))
	return FrameWithPayload{
		Frame: Frame{
			Type:       msg,
			StatusCode: statusCode,
			Length:     n,
			Checksum:   MakeChecksum(msg, statusCode, n),
		},
		Payload: payload,
	}
}

// FrameData describes a value that can be converted to and from a byte slice.
type FrameData interface {
	ToBytes() []byte
	FromBytes([]byte) error
}

// NewConnection returns a Connection for address with a zero call count.
func NewConnection(address string) *Connection {
	return &Connection{
		count:   0,
		address: address,
	}
}

// SendFrame writes the header (little-endian) followed by the payload to conn
// and flushes. The header is written as-is; Length is not recomputed from
// Payload.
func SendFrame(conn net.Conn, data *FrameWithPayload) error {
	w := bufio.NewWriter(conn)
	if err := binary.Write(w, binary.LittleEndian, data.Frame); err != nil {
		return err
	}
	if _, err := w.Write(data.Payload); err != nil {
		return err
	}
	// Buffered bytes only reach conn on Flush, so its error is the one that
	// reports a failed network write.
	return w.Flush()
}

// CallAsync dials c.address and starts two goroutines: one that waits for a
// single reply frame and sends it on ch, and one that sends the request
// (status code 1) built from msg and payload. Send failures are only logged.
//
// The caller owns the returned conn and must close it; closing it unblocks the
// reply goroutine. The reply goroutine does a blocking send on ch, so ch should
// be buffered (or drained) to avoid leaking that goroutine.
func (c *Connection) CallAsync(msg FrameType, payload []byte, ch chan<- FrameWithPayload) (net.Conn, error) {
	conn, err := net.Dial("tcp", c.address)
	if err != nil {
		return nil, err
	}
	go WaitForFrame(conn, ch)
	go func(toSend FrameWithPayload) {
		// Goroutine-local err: do not write to the enclosing function's err.
		if err := SendFrame(conn, &toSend); err != nil {
			log.Printf("Error sending frame: %v\n", err)
		}
	}(MakeFrameWithPayload(msg, 1, payload))
	c.count++
	return conn, nil
}

// Call sends a request and waits up to MaxCallDuration for the reply.
// It returns an error if dialing fails or no valid reply arrives in time.
func (c *Connection) Call(msg FrameType, data []byte) (*FrameWithPayload, error) {
	// Buffered so the reply goroutine can always deliver its one frame, even
	// after a timeout. The channel is deliberately never closed: that goroutine
	// may still send on it after we return, and a send on a closed channel
	// panics.
	ch := make(chan FrameWithPayload, 1)
	conn, err := c.CallAsync(msg, data, ch)
	if err != nil {
		return nil, err
	}
	// Closing conn unblocks the reply goroutine if it is still reading.
	defer conn.Close()

	select {
	case ret := <-ch:
		return &ret, nil
	case <-time.After(MaxCallDuration):
		return nil, errors.New("timeout")
	}
}

// WaitForFrame reads one header and its payload from conn. If the checksum
// matches, it sends the frame on resultChan (a blocking send) and returns nil.
// It returns the read error, or a "checksum mismatch" error (also logged).
func WaitForFrame(conn net.Conn, resultChan chan<- FrameWithPayload) error {
	r := bufio.NewReader(conn)
	var frame Frame
	if err := binary.Read(r, binary.LittleEndian, &frame); err != nil {
		return err
	}
	if !frame.IsValid() {
		log.Println("Checksum mismatch")
		return fmt.Errorf("checksum mismatch")
	}
	// NOTE(review): frame.Length is peer-controlled and only guarded by the
	// weak checksum; consider capping it before allocating.
	payload := make([]byte, frame.Length)
	// A single Read may return fewer bytes than requested (e.g. only what is
	// currently buffered); ReadFull keeps reading until the payload is complete.
	if _, err := io.ReadFull(r, payload); err != nil {
		return err
	}
	resultChan <- FrameWithPayload{
		Frame:   frame,
		Payload: payload,
	}
	return nil
}

// GenericListener dispatches incoming frames to handlers registered per
// FrameType. A handler receives the request and a channel on which it should
// send exactly one reply frame.
type GenericListener struct {
	// StopListener is checked by the accept loop before each Accept. It is a
	// plain bool (reads/writes are not synchronized) and does not interrupt an
	// Accept that is already blocked.
	StopListener bool

	// mu guards handlers: AddHandler may run while connection goroutines are
	// already looking handlers up in HandleFrame.
	mu       sync.RWMutex
	handlers map[FrameType]func(*FrameWithPayload, chan<- FrameWithPayload) error
}

// Listen binds c.address and starts a background accept loop that serves each
// connection with HandleConnection. Handlers may be added with AddHandler
// after Listen returns. The loop closes the listening socket once it observes
// StopListener set; an Accept error terminates the process via log.Fatalf.
func (c *Connection) Listen() (*GenericListener, error) {
	l, err := net.Listen("tcp", c.address)
	if err != nil {
		return nil, err
	}
	ret := &GenericListener{
		handlers: make(map[FrameType]func(*FrameWithPayload, chan<- FrameWithPayload) error),
	}
	go func() {
		defer l.Close()
		for !ret.StopListener {
			connection, err := l.Accept()
			if err != nil {
				log.Fatalf("Error accepting connection: %v\n", err)
			}
			go ret.HandleConnection(connection)
		}
	}()
	return ret, nil
}

const (
	// MaxCallDuration bounds how long a client waits for a reply and how long
	// the server waits for a request frame or a handler result.
	MaxCallDuration = 2500 * time.Millisecond
)

// HandleConnection reads one request frame from conn (bounded by
// MaxCallDuration) and hands it to HandleFrame. On any failure the problem is
// logged and conn is closed; errors never terminate the process, because they
// are triggered by remote peers.
func (l *GenericListener) HandleConnection(conn net.Conn) {
	// The read deadline bounds WaitForFrame, so it can run synchronously and
	// report its error directly instead of racing a separate timer.
	if err := conn.SetReadDeadline(time.Now().Add(MaxCallDuration)); err != nil {
		log.Printf("Error setting read deadline: %v\n", err)
		conn.Close()
		return
	}
	// Buffered: WaitForFrame sends before returning, so this cannot block.
	ch := make(chan FrameWithPayload, 1)
	if err := WaitForFrame(conn, ch); err != nil {
		if errors.Is(err, os.ErrDeadlineExceeded) {
			log.Printf("Timeout waiting for frame\n")
		} else {
			log.Printf("Error reading frame: %v\n", err)
		}
		conn.Close()
		return
	}
	frame := <-ch
	if err := l.HandleFrame(conn, &frame); err != nil {
		log.Printf("Error handling frame: %v\n", err)
	}
}

// AddHandler registers handler for frames of type msg, replacing any previous
// handler for that type. It is safe to call while the listener is serving.
func (l *GenericListener) AddHandler(msg FrameType, handler func(*FrameWithPayload, chan<- FrameWithPayload) error) {
	l.mu.Lock()
	defer l.mu.Unlock()
	l.handlers[msg] = handler
}

// HandleFrame looks up the handler for frame.Type. If none exists it closes
// conn and returns an error. Otherwise it runs the handler in a new goroutine,
// sends the handler's reply on conn (waiting at most MaxCallDuration for it),
// and then closes conn. Handler and send failures are logged.
func (l *GenericListener) HandleFrame(conn net.Conn, frame *FrameWithPayload) error {
	l.mu.RLock()
	handler, ok := l.handlers[frame.Type]
	l.mu.RUnlock()
	if !ok {
		conn.Close()
		return fmt.Errorf("no handler for frame type %d", frame.Type)
	}
	go func() {
		// From here on this goroutine owns conn.
		defer conn.Close()
		// Buffered so a handler can send its reply before returning. Not
		// closed: a handler that replies from its own goroutine could
		// otherwise send on a closed channel.
		resultChan := make(chan FrameWithPayload, 1)
		if err := handler(frame, resultChan); err != nil {
			log.Printf("Error handling frame: %v\n", err)
			return
		}
		select {
		case result := <-resultChan:
			if err := SendFrame(conn, &result); err != nil {
				log.Printf("Error sending frame: %v\n", err)
			}
		case <-time.After(MaxCallDuration):
			// The client gives up after MaxCallDuration as well.
			log.Printf("Timeout waiting for handler result\n")
		}
	}()
	return nil
}