diff --git a/cart-packet-queue.go b/cart-packet-queue.go index d6af474..5584dd5 100644 --- a/cart-packet-queue.go +++ b/cart-packet-queue.go @@ -5,6 +5,7 @@ import ( "fmt" "log" "sync" + "time" ) type CartPacketQueue struct { @@ -70,14 +71,34 @@ func (p *CartPacketQueue) HandleConnection(connection *PersistentConnection) err } func (p *CartPacketQueue) HandleData(t CartMessage, id CartId, data CallResult) { + p.getListener(t, id, func(l *Listener) { + l.Chan <- data + l.Count-- + }) + // p.mu.Lock() + // defer p.mu.Unlock() + // pl, ok := p.expectedPackages[t] + // if ok { + // l, ok := (*pl)[id] + // if ok { + // l.Chan <- data + // l.Count-- + // if l.Count == 0 { + // close(l.Chan) + // delete(*pl, id) + // } + // } + // } +} + +func (p *CartPacketQueue) getListener(t CartMessage, id CartId, fn func(*Listener)) { p.mu.Lock() defer p.mu.Unlock() pl, ok := p.expectedPackages[t] if ok { l, ok := (*pl)[id] if ok { - l.Chan <- data - l.Count-- + fn(&l) if l.Count == 0 { close(l.Chan) delete(*pl, id) @@ -86,6 +107,30 @@ func (p *CartPacketQueue) HandleData(t CartMessage, id CartId, data CallResult) } } +func CallResultWithTimeout(onTimeout func() CallResult) chan CallResult { + ch := make(chan CallResult, 1) + resultCh := make(chan CallResult, 1) + select { + case ret := <-resultCh: + ch <- ret + case <-time.After(300 * time.Millisecond): + ch <- onTimeout() + } + return ch +} + +func (p *CartPacketQueue) MakeChannel(messageType CartMessage, id CartId) chan CallResult { + return CallResultWithTimeout(func() CallResult { + p.getListener(messageType, id, func(l *Listener) { + l.Count-- + }) + return CallResult{ + StatusCode: 504, + Data: []byte("timeout cart call"), + } + }) +} + func (p *CartPacketQueue) Expect(messageType CartMessage, id CartId) <-chan CallResult { p.mu.Lock() defer p.mu.Unlock() @@ -95,15 +140,17 @@ func (p *CartPacketQueue) Expect(messageType CartMessage, id CartId) <-chan Call idl.Count++ return idl.Chan } - ch := make(chan CallResult) + ch := p.MakeChannel(messageType, id) + (*l)[id] = Listener{ Chan: ch, Count: 1, } + return ch } - ch := make(chan CallResult) + ch := p.MakeChannel(messageType, id) p.expectedPackages[messageType] = &CartListener{ id: Listener{ Chan: ch, diff --git a/packet-queue.go b/packet-queue.go index c16a80f..ad37026 100644 --- a/packet-queue.go +++ b/packet-queue.go @@ -5,6 +5,7 @@ import ( "fmt" "log" "sync" + "time" ) type PacketQueue struct { @@ -98,7 +99,19 @@ func (p *PacketQueue) Expect(messageType PoolMessage) <-chan CallResult { return l.Chan } - ch := make(chan CallResult) + ch := make(chan CallResult, 1) + go func() { + time.Sleep(time.Millisecond * 300) + p.mu.Lock() + defer p.mu.Unlock() + + ch <- CallResult{ + StatusCode: 504, + Data: []byte("timeout cart call"), + } + + close(ch) + }() p.expectedPackages[messageType] = &Listener{ Count: 1, Chan: ch, diff --git a/tcp-cart-client.go b/tcp-cart-client.go index 77d5b5f..694cdae 100644 --- a/tcp-cart-client.go +++ b/tcp-cart-client.go @@ -2,10 +2,8 @@ package main import ( "encoding/binary" - "fmt" "log" "sync" - "time" ) type CartClient struct { @@ -73,13 +71,9 @@ func (m *CartTCPClient) call(messageType CartMessage, id CartId, responseType Ca if err != nil { return nil, m.PersistentConnection.HandleConnectionError(err) } - select { - case ret := <-packetChan: - return &ret, nil - case <-time.After(time.Millisecond * 300): - log.Printf("Timeout waiting for cart response to message type %d\n", responseType) - return nil, m.PersistentConnection.HandleConnectionError(fmt.Errorf("timeout")) - } + + ret := <-packetChan + return &ret, nil } func isRetirableError(err error) bool { diff --git a/tcp-client.go b/tcp-client.go index 6cbeeb1..498085b 100644 --- a/tcp-client.go +++ b/tcp-client.go @@ -2,7 +2,6 @@ package main import ( "encoding/binary" - "fmt" "log" "net" "sync" @@ -139,11 +138,7 @@ func (m *TCPClient) Call(messageType PoolMessage, responseType PoolMessage, data return nil, m.PersistentConnection.HandleConnectionError(err) } - select { - case ret := <-packetChan: - return &ret, nil - case <-time.After(time.Second): - log.Printf("Timeout waiting for cart response to message type %d\n", responseType) - return nil, m.PersistentConnection.HandleConnectionError(fmt.Errorf("timeout")) - } + ret := <-packetChan + return &ret, nil + } diff --git a/tcp-connection.go b/tcp-connection.go new file mode 100644 index 0000000..618c43a --- /dev/null +++ b/tcp-connection.go @@ -0,0 +1,172 @@ +package main + +import ( + "bufio" + "encoding/binary" + "fmt" + "log" + "net" + "time" +) + +type Connection struct { + address string + count uint64 +} + +type FrameType uint32 + +type Frame struct { + Id uint64 + Type FrameType + StatusCode uint32 + Length uint32 +} + +type FrameWithPayload struct { + Frame + Payload []byte +} + +type FrameData interface { + ToBytes() []byte + FromBytes([]byte) error +} + +func NewConnection(address string) *Connection { + return &Connection{ + count: 0, + address: address, + } +} + +func SendFrame(conn net.Conn, data *FrameWithPayload) error { + _, err := conn.Write(header[:]) + if err != nil { + return err + } + err = binary.Write(conn, binary.LittleEndian, data.Frame) + if err != nil { + return err + } + _, err = conn.Write(data.Payload) + return err +} + +func (c *Connection) CallAsync(msg FrameType, data FrameData, ch chan<- *FrameWithPayload) error { + conn, err := net.Dial("tcp", c.address) + go WaitForFrame(conn, ch) + if err != nil { + return err + } + payload := data.ToBytes() + toSend := &FrameWithPayload{ + Frame: Frame{ + Id: c.count, + Type: msg, + StatusCode: 0, + Length: uint32(len(payload)), + }, + Payload: payload, + } + + err = SendFrame(conn, toSend) + if err != nil { + close(ch) + return err + } + + c.count++ + return nil +} + +func (c *Connection) Call(msg FrameType, data FrameData) (*FrameWithPayload, error) { + ch := make(chan *FrameWithPayload, 1) + c.CallAsync(msg, data, ch) + select { + case ret := <-ch: + return ret, nil + case <-time.After(5 * time.Second): + return nil, fmt.Errorf("timeout") + } +} + +func WaitForFrame(conn net.Conn, resultChan chan<- *FrameWithPayload) error { + defer conn.Close() + var err error + r := bufio.NewReader(conn) + h := make([]byte, 4) + r.Read(h) + if h[0] == header[0] && h[1] == header[1] && h[2] == header[2] && h[3] == header[3] { + frame := Frame{} + err = binary.Read(r, binary.LittleEndian, &frame) + payload := make([]byte, frame.Length) + _, err = r.Read(payload) + resultChan <- &FrameWithPayload{ + Frame: frame, + Payload: payload, + } + return err + } + resultChan <- nil + return err +} + +type GenericListener struct { + Closed bool + handlers map[FrameType]func(*FrameWithPayload, chan<- *FrameWithPayload) error +} + +func (c *Connection) Listen() (*GenericListener, error) { + l, err := net.Listen("tcp", c.address) + if err != nil { + return nil, err + } + ret := &GenericListener{ + handlers: make(map[FrameType]func(*FrameWithPayload, chan<- *FrameWithPayload) error), + } + go func() { + for !ret.Closed { + connection, err := l.Accept() + if err != nil { + log.Fatalf("Error accepting connection: %v\n", err) + } + go ret.HandleConnection(connection) + } + }() + return ret, nil +} + +func (l *GenericListener) HandleConnection(conn net.Conn) { + ch := make(chan *FrameWithPayload, 1) + go WaitForFrame(conn, ch) + select { + case frame := <-ch: + go l.HandleFrame(conn, frame) + case <-time.After(1 * time.Second): + close(ch) + log.Printf("Timeout waiting for frame\n") + } +} + +func (l *GenericListener) AddHandler(msg FrameType, handler func(*FrameWithPayload, chan<- *FrameWithPayload) error) { + l.handlers[msg] = handler +} + +func (l *GenericListener) HandleFrame(conn net.Conn, frame *FrameWithPayload) { + handler, ok := l.handlers[frame.Type] + defer conn.Close() + if ok { + go func() { + resultChan := make(chan *FrameWithPayload, 1) + defer close(resultChan) + err := handler(frame, resultChan) + if err != nil { + log.Fatalf("Error handling frame: %v\n", err) + } + SendFrame(conn, <-resultChan) + }() + } else { + log.Fatalf("No handler for frame type %d\n", frame.Type) + } +} diff --git a/tcp-connection_test.go b/tcp-connection_test.go new file mode 100644 index 0000000..d208f37 --- /dev/null +++ b/tcp-connection_test.go @@ -0,0 +1,56 @@ +package main + +import "testing" + +type StringData string + +func (s StringData) ToBytes() []byte { + return []byte(s) +} + +func (s StringData) FromBytes(data []byte) error { + s = StringData(data) + return nil +} + +func TestGenericConnection(t *testing.T) { + conn := NewConnection("localhost:51337") + listener, err := conn.Listen() + if err != nil { + t.Errorf("Error listening: %v\n", err) + } + listener.AddHandler(1, func(input *FrameWithPayload, resultChan chan<- *FrameWithPayload) error { + payload := []byte("Hello, world!") + resultChan <- &FrameWithPayload{ + Frame: Frame{ + Type: 2, + Id: input.Id, + StatusCode: 200, + Length: uint32(len("Hello, world!")), + }, + Payload: payload, + } + return nil + }) + r, err := conn.Call(1, StringData("Hello, world!")) + if err != nil { + t.Errorf("Error calling: %v\n", err) + } + if r.Type != 2 { + t.Errorf("Expected type 2, got %d\n", r.Type) + } + i := 100 + results := make(chan *FrameWithPayload, i) + for i > 0 { + conn.CallAsync(1, StringData("Hello, world!"), results) + i-- + } + for i < 100 { + r := <-results + if r.Type != 2 { + t.Errorf("Expected type 2, got %d\n", r.Type) + } + i++ + } + +}