package main import ( "bufio" "encoding/binary" "fmt" "io" "log" "net" "time" "github.com/yudhasubki/netpool" ) type Connection struct { address string pool netpool.Netpooler count uint64 } type FrameType uint32 type StatusCode uint32 type CheckSum uint32 type Frame struct { Type FrameType StatusCode StatusCode Length uint32 Checksum CheckSum } func (f *Frame) IsValid() bool { return f.Checksum == MakeChecksum(f.Type, f.StatusCode, f.Length) } func MakeChecksum(msg FrameType, statusCode StatusCode, length uint32) CheckSum { sum := CheckSum((uint32(msg) + uint32(statusCode) + length) / 8) return sum } type FrameWithPayload struct { Frame Payload []byte } func MakeFrameWithPayload(msg FrameType, statusCode StatusCode, payload []byte) FrameWithPayload { len := uint32(len(payload)) return FrameWithPayload{ Frame: Frame{ Type: msg, StatusCode: statusCode, Length: len, Checksum: MakeChecksum(msg, statusCode, len), }, Payload: payload, } } type FrameData interface { ToBytes() []byte FromBytes([]byte) error } func NewConnection(address string, pool netpool.Netpooler) *Connection { return &Connection{ count: 0, pool: pool, address: address, } } func SendFrame(conn net.Conn, data *FrameWithPayload) error { err := binary.Write(conn, binary.LittleEndian, data.Frame) if err != nil { return err } _, err = conn.Write(data.Payload) return err } func (c *Connection) CallAsync(msg FrameType, payload []byte, ch chan<- FrameWithPayload) (net.Conn, error) { conn, err := c.pool.Get() //conn, err := net.Dial("tcp", c.address) if err != nil { return conn, err } go WaitForFrame(conn, ch) go func(toSend FrameWithPayload) { err = SendFrame(conn, &toSend) if err != nil { log.Printf("Error sending frame: %v\n", err) //close(ch) //conn.Close() } }(MakeFrameWithPayload(msg, 1, payload)) c.count++ return conn, err } func (c *Connection) Call(msg FrameType, data []byte) (*FrameWithPayload, error) { ch := make(chan FrameWithPayload, 1) conn, err := c.CallAsync(msg, data, ch) if err != nil { return nil, err } defer c.pool.Put(conn, err) // conn.Close() defer close(ch) ret := <-ch return &ret, nil // select { // case ret := <-ch: // return &ret, nil // case <-time.After(MaxCallDuration): // return nil, fmt.Errorf("timeout waiting for frame") // } } func WaitForFrame(conn net.Conn, resultChan chan<- FrameWithPayload) error { var err error var frame Frame err = binary.Read(conn, binary.LittleEndian, &frame) if err != nil { return err } if frame.IsValid() { payload := make([]byte, frame.Length) _, err = conn.Read(payload) if err != nil { conn.Close() return err } resultChan <- FrameWithPayload{ Frame: frame, Payload: payload, } return err } log.Println("Checksum mismatch") return fmt.Errorf("checksum mismatch") } type GenericListener struct { StopListener bool handlers map[FrameType]func(*FrameWithPayload, chan<- FrameWithPayload) error } func (c *Connection) Listen() (*GenericListener, error) { l, err := net.Listen("tcp", c.address) if err != nil { return nil, err } ret := &GenericListener{ handlers: make(map[FrameType]func(*FrameWithPayload, chan<- FrameWithPayload) error), } go func() { for !ret.StopListener { connection, err := l.Accept() if err != nil { log.Printf("Error accepting connection: %v\n", err) continue } go ret.HandleConnection(connection) } }() return ret, nil } const ( MaxCallDuration = 300 * time.Millisecond ListenerKeepalive = 5 * time.Second ) func (l *GenericListener) HandleConnection(conn net.Conn) { var err error var frame Frame b := bufio.NewReader(conn) for err != io.EOF { err = binary.Read(b, binary.LittleEndian, &frame) if err == nil && frame.IsValid() { payload := make([]byte, frame.Length) _, err = b.Read(payload) if err == nil { err = l.HandleFrame(conn, &FrameWithPayload{ Frame: frame, Payload: payload, }) } } } } func (l *GenericListener) AddHandler(msg FrameType, handler func(*FrameWithPayload, chan<- FrameWithPayload) error) { l.handlers[msg] = handler } func (l *GenericListener) HandleFrame(conn net.Conn, frame *FrameWithPayload) error { handler, ok := l.handlers[frame.Type] if ok { go func() { resultChan := make(chan FrameWithPayload, 1) defer close(resultChan) err := handler(frame, resultChan) if err != nil { errFrame := MakeFrameWithPayload(frame.Type, 500, []byte(err.Error())) SendFrame(conn, &errFrame) log.Printf("Handler returned error: %s", err) return } result := <-resultChan err = SendFrame(conn, &result) if err != nil { log.Printf("Error sending frame: %s", err) } }() } else { conn.Close() return fmt.Errorf("no handler for frame type %d", frame.Type) } return nil }