diff --git a/cart-packet-queue.go b/cart-packet-queue.go index 04e1e1a..de6ded8 100644 --- a/cart-packet-queue.go +++ b/cart-packet-queue.go @@ -2,9 +2,8 @@ package main import ( "bufio" - "io" + "fmt" "log" - "net" "sync" "time" ) @@ -18,7 +17,7 @@ const CurrentPacketVersion = 2 type CartListener map[CartId]Listener -func NewCartPacketQueue(connection net.Conn) *CartPacketQueue { +func NewCartPacketQueue(connection *PersistentConnection) *CartPacketQueue { queue := &CartPacketQueue{ expectedPackages: make(map[uint32]*CartListener), @@ -38,7 +37,7 @@ func (p *CartPacketQueue) RemoveListeners() { p.expectedPackages = make(map[uint32]*CartListener) } -func (p *CartPacketQueue) HandleConnection(connection net.Conn) error { +func (p *CartPacketQueue) HandleConnection(connection *PersistentConnection) error { defer p.RemoveListeners() defer connection.Close() var packet CartPacket @@ -47,15 +46,13 @@ func (p *CartPacketQueue) HandleConnection(connection net.Conn) error { for { err := ReadCartPacket(reader, &packet) if err != nil { - if err == io.EOF { - return nil - } + log.Printf("Error receiving packet: %v\n", err) - return err + return connection.HandleConnectionError(err) } if packet.Version != CurrentPacketVersion { log.Printf("Incorrect version: %v\n", packet.Version) - return nil + return connection.HandleConnectionError(fmt.Errorf("incorrect version: %d", packet.Version)) } if packet.DataLength == 0 { go p.HandleData(packet.MessageType, packet.Id, CallResult{ @@ -67,7 +64,7 @@ func (p *CartPacketQueue) HandleConnection(connection net.Conn) error { data, err := GetPacketData(reader, packet.DataLength) if err != nil { log.Printf("Error receiving packet data: %v\n", err) - return err + return connection.HandleConnectionError(err) } go p.HandleData(packet.MessageType, packet.Id, CallResult{ StatusCode: packet.StatusCode, @@ -121,106 +118,3 @@ func (p *CartPacketQueue) Expect(messageType uint32, id CartId) <-chan CallResul return ch } - -// package main - -// import ( -// "fmt" -// "io" -// "log" -// "net" -// "sync" -// "time" -// ) - -// type CartPacketWithData struct { -// MessageType uint32 -// Id CartId -// Added time.Time -// Consumed bool -// Data []byte -// } - -// type CartPacketQueue struct { -// mu sync.RWMutex -// Packets []CartPacketWithData -// //connection net.Conn -// } - -// const cartCap = 150 - -// func NewCartPacketQueue(connection net.Conn) *CartPacketQueue { - -// queue := &CartPacketQueue{ -// Packets: make([]CartPacketWithData, 0, cartCap), -// //connection: connection, -// } -// go func() { -// defer connection.Close() -// var packet CartPacket -// for { -// err := ReadCartPacket(connection, &packet) -// if err != nil { -// if err == io.EOF { -// return -// } -// log.Printf("Error receiving packet: %v\n", err) -// //return -// } - -// data, err := GetPacketData(connection, packet.DataLength) -// if err != nil { -// log.Printf("Error receiving packet data: %v\n", err) -// return -// } -// go queue.HandleData(packet.MessageType, packet.Id, data) -// } -// }() -// return queue -// } - -// func (p *CartPacketQueue) HandleData(t uint32, id CartId, data []byte) { -// ts := time.Now() -// l := make([]CartPacketWithData, 0, cartCap) -// p.mu.RLock() -// breakAt := ts.Add(-time.Millisecond * 250) -// for _, packet := range p.Packets { -// if !packet.Consumed && packet.Added.After(breakAt) { -// l = append(l, packet) -// if len(l) >= cartCap { -// break -// } -// } -// } -// p.mu.RUnlock() -// p.mu.Lock() -// p.Packets = append([]CartPacketWithData{ -// { -// MessageType: t, -// Id: id, -// Added: ts, -// Data: data, -// }, -// }, l...) -// p.mu.Unlock() -// } - -// func (p *CartPacketQueue) Expect(messageType uint32, id CartId, timeToWait time.Duration) (*CartPacketWithData, error) { -// start := time.Now().Add(-time.Millisecond) - -// for { -// if time.Since(start) > timeToWait { -// return nil, fmt.Errorf("timeout waiting for message type %d", messageType) -// } -// p.mu.RLock() -// for _, packet := range p.Packets { -// if !packet.Consumed && packet.MessageType == messageType && packet.Id == id && packet.Added.After(start) { -// packet.Consumed = true -// p.mu.RUnlock() -// return &packet, nil -// } -// } -// p.mu.RUnlock() -// time.Sleep(time.Millisecond * 2) -// } -// } diff --git a/packet-queue.go b/packet-queue.go index d395db5..18a14be 100644 --- a/packet-queue.go +++ b/packet-queue.go @@ -3,11 +3,8 @@ package main import ( "bufio" "fmt" - "io" "log" - "net" "sync" - "time" ) type PacketQueue struct { @@ -25,7 +22,7 @@ type Listener struct { Chan chan CallResult } -func NewPacketQueue(connection net.Conn) *PacketQueue { +func NewPacketQueue(connection *PersistentConnection) *PacketQueue { queue := &PacketQueue{ expectedPackages: make(map[uint32]*Listener), } @@ -42,24 +39,20 @@ func (p *PacketQueue) RemoveListeners() { p.expectedPackages = make(map[uint32]*Listener) } -func (p *PacketQueue) HandleConnection(connection net.Conn) error { +func (p *PacketQueue) HandleConnection(connection *PersistentConnection) error { defer connection.Close() defer p.RemoveListeners() var packet Packet reader := bufio.NewReader(connection) - connection.SetReadDeadline(time.Now().Add(time.Millisecond * 200)) + for { err := ReadPacket(reader, &packet) if err != nil { - if err == io.EOF { - return nil - } - log.Printf("Error receiving packet: %v\n", err) - return err + return connection.HandleConnectionError(err) } if packet.Version != CurrentPacketVersion { log.Printf("Incorrect packet version: %v\n", packet.Version) - return fmt.Errorf("incorrect packet version: %d", packet.Version) + return connection.HandleConnectionError(fmt.Errorf("incorrect packet version: %d", packet.Version)) } if packet.DataLength == 0 { go p.HandleData(packet.MessageType, CallResult{ @@ -71,7 +64,7 @@ func (p *PacketQueue) HandleConnection(connection net.Conn) error { data, err := GetPacketData(reader, packet.DataLength) if err != nil { log.Printf("Error receiving packet data: %v\n", err) - return err + return connection.HandleConnectionError(err) } else { go p.HandleData(packet.MessageType, CallResult{ StatusCode: packet.StatusCode, diff --git a/tcp-cart-client.go b/tcp-cart-client.go index cf0cc4e..09ac905 100644 --- a/tcp-cart-client.go +++ b/tcp-cart-client.go @@ -3,7 +3,7 @@ package main import ( "encoding/binary" "fmt" - "net" + "log" "time" ) @@ -28,76 +28,51 @@ func (c *Client) Close() { } type CartTCPClient struct { - net.Conn + *PersistentConnection ErrorCount int address string *CartPacketQueue } func NewCartTCPClient(address string) (*CartTCPClient, error) { - connection, err := net.Dial("tcp", address) + connection, err := NewPersistentConnection(address) if err != nil { return nil, err } return &CartTCPClient{ - ErrorCount: 0, - Conn: connection, - address: address, - CartPacketQueue: NewCartPacketQueue(connection), + ErrorCount: 0, + PersistentConnection: connection, + address: address, + CartPacketQueue: NewCartPacketQueue(connection), }, nil } -func (m *CartTCPClient) Connect() error { - if m.Conn == nil { - connection, err := net.Dial("tcp", m.address) - if err != nil { - m.ErrorCount++ - return err - } - m.ErrorCount = 0 - m.Conn = connection - } - return nil -} - func (m *CartTCPClient) SendPacket(messageType uint32, id CartId, data []byte) error { - err := m.Connect() - if err != nil { - return err - } - err = binary.Write(m.Conn, binary.LittleEndian, CartPacket{ + err := binary.Write(m.Conn, binary.LittleEndian, CartPacket{ Version: CurrentPacketVersion, MessageType: messageType, DataLength: uint32(len(data)), Id: id, }) if err != nil { - return err + return m.HandleConnectionError(err) } _, err = m.Conn.Write(data) - m.Conn.SetDeadline(time.Now().Add(time.Second * 10)) - return err + return m.HandleConnectionError(err) } -// func (m *CartTCPClient) SendPacketFn(messageType uint16, id CartId, datafn func(w io.Writer) error) error { -// data, err := GetData(datafn) -// if err != nil { -// return err -// } -// return m.SendPacket(messageType, id, data) -// } - func (m *CartTCPClient) Call(messageType uint32, id CartId, responseType uint32, data []byte) (*CallResult, error) { packetChan := m.Expect(responseType, id) err := m.SendPacket(messageType, id, data) if err != nil { - return nil, err + return nil, m.HandleConnectionError(err) } select { case ret := <-packetChan: return &ret, nil - case <-time.After(time.Second * 10): - return nil, fmt.Errorf("timeout") + case <-time.After(time.Second): + log.Printf("Timeout waiting for cart response to message type %d\n", responseType) + return nil, m.HandleConnectionError(fmt.Errorf("timeout")) } } diff --git a/tcp-client.go b/tcp-client.go index 6e2c1ea..8146525 100644 --- a/tcp-client.go +++ b/tcp-client.go @@ -3,6 +3,7 @@ package main import ( "encoding/binary" "fmt" + "log" "net" "time" ) @@ -24,54 +25,66 @@ func Dial(address string) (*Client, error) { } type TCPClient struct { - net.Conn + *PersistentConnection ErrorCount int address string *PacketQueue } -func NewTCPClient(address string) (*TCPClient, error) { +type PersistentConnection struct { + net.Conn + address string +} + +func NewPersistentConnection(address string) (*PersistentConnection, error) { connection, err := net.Dial("tcp", address) if err != nil { return nil, err } - return &TCPClient{ - ErrorCount: 0, - Conn: connection, - address: address, - PacketQueue: NewPacketQueue(connection), + return &PersistentConnection{ + Conn: connection, + address: address, }, nil } -func (m *TCPClient) Connect() error { - if m.Conn == nil { - connection, err := net.Dial("tcp", m.address) - if err != nil { - return err - } - m.ErrorCount = 0 - m.Conn = connection +func (m *PersistentConnection) Connect() error { + connection, err := net.Dial("tcp", m.address) + if err != nil { + return err } + m.Conn = connection return nil } -func (m *TCPClient) HandleConnectionError(err error) error { +func (m *PersistentConnection) Close() { + m.Conn.Close() +} + +func (m *PersistentConnection) HandleConnectionError(err error) error { if err != nil { - m.ErrorCount++ + m.Conn.Close() + m.Connect() } return err } -func (m *TCPClient) Close() { - m.Conn.Close() +func NewTCPClient(address string) (*TCPClient, error) { + + connection, err := NewPersistentConnection(address) + if err != nil { + return nil, err + } + return &TCPClient{ + ErrorCount: 0, + PersistentConnection: connection, + address: address, + PacketQueue: NewPacketQueue(connection), + }, nil } func (m *TCPClient) SendPacket(messageType uint32, data []byte) error { - err := m.Connect() - if err != nil { - return err - } - err = binary.Write(m.Conn, binary.LittleEndian, Packet{ + + err := binary.Write(m.Conn, binary.LittleEndian, Packet{ Version: CurrentPacketVersion, MessageType: messageType, StatusCode: 0, @@ -81,29 +94,22 @@ func (m *TCPClient) SendPacket(messageType uint32, data []byte) error { return m.HandleConnectionError(err) } _, err = m.Conn.Write(data) - m.Conn.SetDeadline(time.Now().Add(time.Second * 10)) return m.HandleConnectionError(err) } -// func (m *TCPClient) SendPacketFn(messageType uint32, datafn func(w io.Writer) error) error { -// data, err := GetData(datafn) -// if err != nil { -// return err -// } -// return m.SendPacket(messageType, data) -// } - func (m *TCPClient) Call(messageType uint32, responseType uint32, data []byte) (*CallResult, error) { packetChan := m.Expect(responseType) err := m.SendPacket(messageType, data) if err != nil { - return nil, err + m.RemoveListeners() + return nil, m.HandleConnectionError(err) } select { case ret := <-packetChan: return &ret, nil - case <-time.After(time.Second * 10): - return nil, fmt.Errorf("timeout") + case <-time.After(time.Second): + log.Printf("Timeout waiting for cart response to message type %d\n", responseType) + return nil, m.HandleConnectionError(fmt.Errorf("timeout")) } }