233 lines
4.9 KiB
Go
233 lines
4.9 KiB
Go
package main
|
|
|
|
import (
	"bufio"
	"encoding/binary"
	"fmt"
	"io"
	"log"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"github.com/yudhasubki/netpool"
)
|
|
|
|
type Connection struct {
|
|
address string
|
|
pool netpool.Netpooler
|
|
count uint64
|
|
}
|
|
|
|
// Header field types of a Frame. Each one occupies 32 bits on the wire.
type (
	// FrameType identifies the kind of request; the listener picks the
	// handler registered for it.
	FrameType uint32
	// StatusCode carries a request/response status (CallAsync sends 1 and
	// the listener uses 500 to report handler errors).
	StatusCode uint32
	// CheckSum is the header checksum produced by MakeChecksum.
	CheckSum uint32
)
|
|
|
|
type Frame struct {
|
|
Type FrameType
|
|
StatusCode StatusCode
|
|
Length uint32
|
|
Checksum CheckSum
|
|
}
|
|
|
|
func (f *Frame) IsValid() bool {
|
|
return f.Checksum == MakeChecksum(f.Type, f.StatusCode, f.Length)
|
|
}
|
|
|
|
func MakeChecksum(msg FrameType, statusCode StatusCode, length uint32) CheckSum {
|
|
sum := CheckSum((uint32(msg) + uint32(statusCode) + length) / 8)
|
|
return sum
|
|
}
|
|
|
|
type FrameWithPayload struct {
|
|
Frame
|
|
Payload []byte
|
|
}
|
|
|
|
func MakeFrameWithPayload(msg FrameType, statusCode StatusCode, payload []byte) FrameWithPayload {
|
|
len := uint32(len(payload))
|
|
return FrameWithPayload{
|
|
Frame: Frame{
|
|
Type: msg,
|
|
StatusCode: statusCode,
|
|
Length: len,
|
|
Checksum: MakeChecksum(msg, statusCode, len),
|
|
},
|
|
Payload: payload,
|
|
}
|
|
}
|
|
|
|
// FrameData is implemented by values that convert to and from a byte slice,
// presumably a frame payload.
//
// NOTE(review): nothing in the visible part of this file uses FrameData, so
// its exact contract (e.g. whether FromBytes may retain the slice) is not
// established here; confirm with its implementers.
type FrameData interface {
	// ToBytes encodes the value into bytes.
	ToBytes() []byte
	// FromBytes decodes the value from bytes, returning an error on failure.
	FromBytes([]byte) error
}
|
|
|
|
func NewConnection(address string, pool netpool.Netpooler) *Connection {
|
|
return &Connection{
|
|
count: 0,
|
|
pool: pool,
|
|
address: address,
|
|
}
|
|
}
|
|
|
|
func SendFrame(conn net.Conn, data *FrameWithPayload) error {
|
|
|
|
err := binary.Write(conn, binary.LittleEndian, data.Frame)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = conn.Write(data.Payload)
|
|
|
|
return err
|
|
}
|
|
|
|
// CallAsync sends a frame of type msg (status code 1) carrying payload over a
// connection taken from the pool. It then starts a goroutine that reads one
// reply frame and delivers it on ch.
//
// The returned conn is the pooled connection used for the call. The caller
// must hand it back to the pool once the reply has arrived, as Call does, and
// should pass along the returned error so a broken connection can be
// discarded.
//
// A non-nil error means either the pool could not provide a connection or the
// request could not be written. In the latter case no reader is started and
// nothing will arrive on ch. If reading the reply fails, the error is logged
// and nothing is sent on ch.
func (c *Connection) CallAsync(msg FrameType, payload []byte, ch chan<- FrameWithPayload) (net.Conn, error) {
	conn, err := c.pool.Get()
	if err != nil {
		return conn, err
	}

	// Counted atomically: a pooled client is expected to be used from
	// several goroutines at once.
	atomic.AddUint64(&c.count, 1)

	// Send synchronously so the write error reaches the caller instead of
	// racing on a shared variable inside a goroutine.
	request := MakeFrameWithPayload(msg, 1, payload)
	if err := SendFrame(conn, &request); err != nil {
		return conn, fmt.Errorf("sending frame: %w", err)
	}

	// Any reply is buffered by the connection until this reader picks it up.
	go func() {
		if err := WaitForFrame(conn, ch); err != nil {
			log.Printf("Error receiving frame: %v\n", err)
		}
	}()

	return conn, nil
}
|
|
|
|
func (c *Connection) Call(msg FrameType, data []byte) (*FrameWithPayload, error) {
|
|
ch := make(chan FrameWithPayload, 1)
|
|
conn, err := c.CallAsync(msg, data, ch)
|
|
defer c.pool.Put(conn, err)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
defer close(ch)
|
|
|
|
ret := <-ch
|
|
return &ret, nil
|
|
// select {
|
|
// case ret := <-ch:
|
|
// return &ret, nil
|
|
// case <-time.After(MaxCallDuration):
|
|
// return nil, fmt.Errorf("timeout waiting for frame")
|
|
// }
|
|
}
|
|
|
|
// WaitForFrame reads one frame (header, then payload) from conn and delivers it
// on resultChan.
//
// Errors:
//   - header read failure: the error is returned.
//   - checksum mismatch: "checksum mismatch" is logged and returned.
//   - payload read failure: conn is closed and the error is returned.
//
// In every error case nothing is sent on resultChan. On success it returns nil
// after the send.
//
// NOTE(review): frame.Length comes straight off the wire and is used as an
// allocation size. The checksum is derived only from the header fields
// themselves, so it does not bound this; a misbehaving peer can request an
// allocation of up to 4 GiB.
func WaitForFrame(conn net.Conn, resultChan chan<- FrameWithPayload) error {
	var frame Frame
	if err := binary.Read(conn, binary.LittleEndian, &frame); err != nil {
		return err
	}

	if !frame.IsValid() {
		log.Println("Checksum mismatch")
		return fmt.Errorf("checksum mismatch")
	}

	payload := make([]byte, frame.Length)
	// A single Read may return fewer than len(payload) bytes (TCP is a byte
	// stream), so keep reading until the whole payload has arrived.
	if _, err := io.ReadFull(conn, payload); err != nil {
		conn.Close()
		return err
	}

	resultChan <- FrameWithPayload{
		Frame:   frame,
		Payload: payload,
	}
	return nil
}
|
|
|
|
type GenericListener struct {
|
|
StopListener bool
|
|
handlers map[FrameType]func(*FrameWithPayload, chan<- FrameWithPayload) error
|
|
}
|
|
|
|
func (c *Connection) Listen() (*GenericListener, error) {
|
|
l, err := net.Listen("tcp", c.address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ret := &GenericListener{
|
|
handlers: make(map[FrameType]func(*FrameWithPayload, chan<- FrameWithPayload) error),
|
|
}
|
|
go func() {
|
|
for !ret.StopListener {
|
|
connection, err := l.Accept()
|
|
if err != nil {
|
|
log.Printf("Error accepting connection: %v\n", err)
|
|
continue
|
|
}
|
|
go ret.HandleConnection(connection)
|
|
}
|
|
}()
|
|
return ret, nil
|
|
}
|
|
|
|
const (
|
|
MaxCallDuration = 300 * time.Millisecond
|
|
ListenerKeepalive = 5 * time.Second
|
|
)
|
|
|
|
func (l *GenericListener) HandleConnection(conn net.Conn) {
|
|
var err error
|
|
var frame Frame
|
|
log.Printf("Server Connection accepted: %s\n", conn.RemoteAddr().String())
|
|
b := bufio.NewReader(conn)
|
|
for err != io.EOF {
|
|
|
|
err = binary.Read(b, binary.LittleEndian, &frame)
|
|
|
|
if err == nil && frame.IsValid() {
|
|
payload := make([]byte, frame.Length)
|
|
_, err = b.Read(payload)
|
|
if err == nil {
|
|
err = l.HandleFrame(conn, &FrameWithPayload{
|
|
Frame: frame,
|
|
Payload: payload,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
conn.Close()
|
|
log.Printf("Server Connection closed")
|
|
}
|
|
|
|
func (l *GenericListener) AddHandler(msg FrameType, handler func(*FrameWithPayload, chan<- FrameWithPayload) error) {
|
|
l.handlers[msg] = handler
|
|
}
|
|
|
|
func (l *GenericListener) HandleFrame(conn net.Conn, frame *FrameWithPayload) error {
|
|
handler, ok := l.handlers[frame.Type]
|
|
if ok {
|
|
go func() {
|
|
resultChan := make(chan FrameWithPayload, 1)
|
|
defer close(resultChan)
|
|
err := handler(frame, resultChan)
|
|
if err != nil {
|
|
errFrame := MakeFrameWithPayload(frame.Type, 500, []byte(err.Error()))
|
|
SendFrame(conn, &errFrame)
|
|
log.Printf("Handler returned error: %s", err)
|
|
return
|
|
}
|
|
result := <-resultChan
|
|
err = SendFrame(conn, &result)
|
|
if err != nil {
|
|
log.Printf("Error sending frame: %s", err)
|
|
}
|
|
}()
|
|
} else {
|
|
conn.Close()
|
|
return fmt.Errorf("no handler for frame type %d", frame.Type)
|
|
}
|
|
return nil
|
|
}
|