diff --git a/cart-packet-queue.go b/cart-packet-queue.go index e420a31..ddc8abc 100644 --- a/cart-packet-queue.go +++ b/cart-packet-queue.go @@ -1,102 +1,186 @@ package main import ( - "fmt" "io" "log" "net" "sync" - "time" ) -type CartPacketWithData struct { - MessageType uint32 - Id CartId - Added time.Time - Consumed bool - Data []byte -} - type CartPacketQueue struct { - mu sync.RWMutex - Packets []CartPacketWithData - //connection net.Conn + mu sync.RWMutex + expectedPackages map[uint32]*CartListener } -const cartCap = 150 +type CartListener map[CartId]chan []byte func NewCartPacketQueue(connection net.Conn) *CartPacketQueue { queue := &CartPacketQueue{ - Packets: make([]CartPacketWithData, 0, cartCap), - //connection: connection, + expectedPackages: make(map[uint32]*CartListener), } - go func() { - defer connection.Close() - var packet CartPacket - for { - err := ReadCartPacket(connection, &packet) - if err != nil { - if err == io.EOF { - return - } - log.Printf("Error receiving packet: %v\n", err) - //return - } - - data, err := GetPacketData(connection, packet.DataLength) - if err != nil { - log.Printf("Error receiving packet data: %v\n", err) - return - } - go queue.HandleData(packet.MessageType, packet.Id, data) - } - }() + go queue.HandleConnection(connection) return queue } -func (p *CartPacketQueue) HandleData(t uint32, id CartId, data []byte) { - ts := time.Now() - l := make([]CartPacketWithData, 0, cartCap) - p.mu.RLock() - breakAt := ts.Add(-time.Millisecond * 250) - for _, packet := range p.Packets { - if !packet.Consumed && packet.Added.After(breakAt) { - l = append(l, packet) - if len(l) >= cartCap { - break - } - } - } - p.mu.RUnlock() - p.mu.Lock() - p.Packets = append([]CartPacketWithData{ - { - MessageType: t, - Id: id, - Added: ts, - Data: data, - }, - }, l...) - p.mu.Unlock() -} - -func (p *CartPacketQueue) Expect(messageType uint32, id CartId, timeToWait time.Duration) (*CartPacketWithData, error) { - start := time.Now().Add(-time.Millisecond) +func (p *CartPacketQueue) HandleConnection(connection net.Conn) error { + defer connection.Close() + var packet CartPacket for { - if time.Since(start) > timeToWait { - return nil, fmt.Errorf("timeout waiting for message type %d", messageType) - } - p.mu.RLock() - for _, packet := range p.Packets { - if !packet.Consumed && packet.MessageType == messageType && packet.Id == id && packet.Added.After(start) { - packet.Consumed = true - p.mu.RUnlock() - return &packet, nil + err := ReadCartPacket(connection, &packet) + if err != nil { + if err == io.EOF { + return nil } + log.Printf("Error receiving packet: %v\n", err) + return err } - p.mu.RUnlock() - time.Sleep(time.Millisecond * 2) + data, err := GetPacketData(connection, packet.DataLength) + if err != nil { + log.Printf("Error receiving packet data: %v\n", err) + return err + } + go p.HandleData(packet.MessageType, packet.Id, data) } } + +func (p *CartPacketQueue) HandleData(t uint32, id CartId, data []byte) { + p.mu.Lock() + defer p.mu.Unlock() + l, ok := p.expectedPackages[t] + if ok { + ch, ok := (*l)[id] + if ok { + ch <- data + close(ch) + delete(*l, id) + } + } + data = nil +} + +func (p *CartPacketQueue) Expect(messageType uint32, id CartId) <-chan []byte { + p.mu.Lock() + defer p.mu.Unlock() + l, ok := p.expectedPackages[messageType] + if ok { + if ch, idOk := (*l)[id]; idOk { + return ch + } + ch := make(chan []byte) + (*l)[id] = ch + return ch + } + + ch := make(chan []byte) + p.expectedPackages[messageType] = &CartListener{ + id: ch, + } + + return ch + +} + +// package main + +// import ( +// "fmt" +// "io" +// "log" +// "net" +// "sync" +// "time" +// ) + +// type CartPacketWithData struct { +// MessageType uint32 +// Id CartId +// Added time.Time +// Consumed bool +// Data []byte +// } + +// type CartPacketQueue struct { +// mu sync.RWMutex +// Packets []CartPacketWithData +// //connection net.Conn +// } + +// const cartCap = 150 + +// func NewCartPacketQueue(connection net.Conn) *CartPacketQueue { + +// queue := &CartPacketQueue{ +// Packets: make([]CartPacketWithData, 0, cartCap), +// //connection: connection, +// } +// go func() { +// defer connection.Close() +// var packet CartPacket +// for { +// err := ReadCartPacket(connection, &packet) +// if err != nil { +// if err == io.EOF { +// return +// } +// log.Printf("Error receiving packet: %v\n", err) +// //return +// } + +// data, err := GetPacketData(connection, packet.DataLength) +// if err != nil { +// log.Printf("Error receiving packet data: %v\n", err) +// return +// } +// go queue.HandleData(packet.MessageType, packet.Id, data) +// } +// }() +// return queue +// } + +// func (p *CartPacketQueue) HandleData(t uint32, id CartId, data []byte) { +// ts := time.Now() +// l := make([]CartPacketWithData, 0, cartCap) +// p.mu.RLock() +// breakAt := ts.Add(-time.Millisecond * 250) +// for _, packet := range p.Packets { +// if !packet.Consumed && packet.Added.After(breakAt) { +// l = append(l, packet) +// if len(l) >= cartCap { +// break +// } +// } +// } +// p.mu.RUnlock() +// p.mu.Lock() +// p.Packets = append([]CartPacketWithData{ +// { +// MessageType: t, +// Id: id, +// Added: ts, +// Data: data, +// }, +// }, l...) +// p.mu.Unlock() +// } + +// func (p *CartPacketQueue) Expect(messageType uint32, id CartId, timeToWait time.Duration) (*CartPacketWithData, error) { +// start := time.Now().Add(-time.Millisecond) + +// for { +// if time.Since(start) > timeToWait { +// return nil, fmt.Errorf("timeout waiting for message type %d", messageType) +// } +// p.mu.RLock() +// for _, packet := range p.Packets { +// if !packet.Consumed && packet.MessageType == messageType && packet.Id == id && packet.Added.After(start) { +// packet.Consumed = true +// p.mu.RUnlock() +// return &packet, nil +// } +// } +// p.mu.RUnlock() +// time.Sleep(time.Millisecond * 2) +// } +// } diff --git a/packet-queue.go b/packet-queue.go index c1b591b..4f641b1 100644 --- a/packet-queue.go +++ b/packet-queue.go @@ -1,98 +1,97 @@ package main import ( - "fmt" "io" "log" "net" "sync" - "time" ) -type PacketWithData struct { - MessageType uint32 - Added time.Time - Consumed bool - Data []byte -} +// type PacketWithData struct { +// MessageType uint32 +// Added time.Time +// Consumed bool +// Data []byte +// } type PacketQueue struct { - mu sync.RWMutex - Packets []PacketWithData + mu sync.RWMutex + expectedPackages map[uint32]*Listener + //Packets []PacketWithData //connection net.Conn } -const cap = 150 +//const cap = 150 + +type Listener struct { + Count int + Chan chan []byte +} func NewPacketQueue(connection net.Conn) *PacketQueue { queue := &PacketQueue{ - Packets: make([]PacketWithData, 0, cap), + expectedPackages: make(map[uint32]*Listener), + //Packets: make([]PacketWithData, 0, cap+1), //connection: connection, } - go func() { - defer connection.Close() - var packet Packet - for { - err := ReadPacket(connection, &packet) - if err != nil { - if err == io.EOF { - return - } - log.Printf("Error receiving packet: %v\n", err) - //return - } - data, err := GetPacketData(connection, packet.DataLength) - if err != nil { - log.Printf("Error receiving packet data: %v\n", err) - } - - go queue.HandleData(packet.MessageType, data) - } - }() + go queue.HandleConnection(connection) return queue } -func (p *PacketQueue) HandleData(t uint32, data []byte) { - ts := time.Now() - l := make([]PacketWithData, 0, cap) - p.mu.RLock() - breakAt := ts.Add(-time.Millisecond * 250) - for _, packet := range p.Packets { - if !packet.Consumed && packet.Added.After(breakAt) { - l = append(l, packet) - if len(l) >= cap { - break - } - } - } - p.mu.RUnlock() - p.mu.Lock() - p.Packets = append([]PacketWithData{ - { - MessageType: t, - Added: ts, - Data: data, - }, - }, l...) - p.mu.Unlock() -} - -func (p *PacketQueue) Expect(messageType uint32, timeToWait time.Duration) (*PacketWithData, error) { - start := time.Now().Add(-time.Millisecond) +func (p *PacketQueue) HandleConnection(connection net.Conn) error { + defer connection.Close() + var packet Packet for { - if time.Since(start) > timeToWait { - return nil, fmt.Errorf("timeout waiting for message type %d", messageType) - } - p.mu.RLock() - defer p.mu.RUnlock() - for _, packet := range p.Packets { - if !packet.Consumed && packet.MessageType == messageType && packet.Added.After(start) { - packet.Consumed = true - return &packet, nil + err := ReadPacket(connection, &packet) + if err != nil { + if err == io.EOF { + return nil } + log.Printf("Error receiving packet: %v\n", err) + return err } - time.Sleep(time.Millisecond * 4) + data, err := GetPacketData(connection, packet.DataLength) + if err != nil { + log.Printf("Error receiving packet data: %v\n", err) + return err + } + go p.HandleData(packet.MessageType, data) } } + +func (p *PacketQueue) HandleData(t uint32, data []byte) { + p.mu.Lock() + defer p.mu.Unlock() + l, ok := p.expectedPackages[t] + if ok { + l.Chan <- data + l.Count-- + if l.Count == 0 { + close(l.Chan) + delete(p.expectedPackages, t) + } + return + } + data = nil +} + +func (p *PacketQueue) Expect(messageType uint32) <-chan []byte { + p.mu.Lock() + defer p.mu.Unlock() + l, ok := p.expectedPackages[messageType] + if ok { + l.Count++ + return l.Chan + } + + ch := make(chan []byte) + p.expectedPackages[messageType] = &Listener{ + Count: 1, + Chan: ch, + } + + return ch + +} diff --git a/tcp-cart-client.go b/tcp-cart-client.go index cb46f5a..1602ff2 100644 --- a/tcp-cart-client.go +++ b/tcp-cart-client.go @@ -2,6 +2,7 @@ package main import ( "encoding/binary" + "fmt" "net" "time" ) @@ -91,13 +92,15 @@ func (m *CartTCPClient) SendPacket(messageType uint32, id CartId, data []byte) e // } func (m *CartTCPClient) Call(messageType uint32, id CartId, responseType uint32, data []byte) ([]byte, error) { + packetChan := m.Expect(responseType, id) err := m.SendPacket(messageType, id, data) if err != nil { return nil, err } - packet, err := m.Expect(responseType, id, time.Second) - if err != nil { - return nil, err + select { + case ret := <-packetChan: + return ret, nil + case <-time.After(3 * time.Second): + return nil, fmt.Errorf("timeout") } - return packet.Data, nil } diff --git a/tcp-cart-mux-server.go b/tcp-cart-mux-server.go index b1ab0da..2b30966 100644 --- a/tcp-cart-mux-server.go +++ b/tcp-cart-mux-server.go @@ -106,18 +106,22 @@ func (m *TCPCartServerMux) HandleConnection(connection net.Conn) error { if err != nil { log.Printf("Error getting packet data: %v\n", err) } - status, err := m.handleListener(packet.MessageType, packet.Id, data) + go m.HandleData(connection, packet.MessageType, packet.Id, data) + } +} + +func (m *TCPCartServerMux) HandleData(connection net.Conn, t uint32, id CartId, data []byte) { + status, err := m.handleListener(t, id, data) + if err != nil { + log.Printf("Error handling listener: %v\n", err) + } + if !status { + status, err = m.handleFunction(connection, t, id, data) if err != nil { - log.Printf("Error handling listener: %v\n", err) + log.Printf("Error handling function: %v\n", err) } if !status { - status, err = m.handleFunction(connection, packet.MessageType, packet.Id, data) - if err != nil { - log.Printf("Error handling function: %v\n", err) - } - if !status { - log.Printf("Unknown message type: %d\n", packet.MessageType) - } + log.Printf("Unknown message type: %d\n", t) } } } diff --git a/tcp-client.go b/tcp-client.go index 8b253f2..0f18a48 100644 --- a/tcp-client.go +++ b/tcp-client.go @@ -2,6 +2,7 @@ package main import ( "encoding/binary" + "fmt" "net" "time" ) @@ -94,13 +95,16 @@ func (m *TCPClient) SendPacket(messageType uint32, data []byte) error { // } func (m *TCPClient) Call(messageType uint32, responseType uint32, data []byte) ([]byte, error) { + packetChan := m.Expect(responseType) err := m.SendPacket(messageType, data) if err != nil { return nil, err } - packet, err := m.Expect(responseType, time.Second) - if err != nil { - return nil, err + + select { + case ret := <-packetChan: + return ret, nil + case <-time.After(3 * time.Second): + return nil, fmt.Errorf("timeout") } - return packet.Data, nil } diff --git a/tcp-mux-server.go b/tcp-mux-server.go index dc7ce0b..5536375 100644 --- a/tcp-mux-server.go +++ b/tcp-mux-server.go @@ -104,18 +104,24 @@ func (m *TCPServerMux) HandleConnection(connection net.Conn) error { if err != nil { log.Printf("Error receiving packet data: %v\n", err) } - status, err := m.handleListener(packet.MessageType, data) + go m.HandleData(connection, packet.MessageType, data) + } +} + +func (m *TCPServerMux) HandleData(connection net.Conn, t uint32, data []byte) { + // listener := m.listeners[t] + // handler := m.functions[t] + status, err := m.handleListener(t, data) + if err != nil { + log.Printf("Error handling listener: %v\n", err) + } + if !status { + status, err = m.handleFunction(connection, t, data) if err != nil { - log.Printf("Error handling listener: %v\n", err) + log.Printf("Error handling function: %v\n", err) } if !status { - status, err = m.handleFunction(connection, packet.MessageType, data) - if err != nil { - log.Printf("Error handling function: %v\n", err) - } - if !status { - log.Printf("Unknown message type: %d\n", packet.MessageType) - } + log.Printf("Unknown message type: %d\n", t) } } } diff --git a/tcp-mux_test.go b/tcp_test.go similarity index 98% rename from tcp-mux_test.go rename to tcp_test.go index bc9c908..1aafeb1 100644 --- a/tcp-mux_test.go +++ b/tcp_test.go @@ -38,6 +38,7 @@ func TestTcpHelpers(t *testing.T) { if err != nil { t.Errorf("Error calling: %v\n", err) } + client.Close() if string(answer) != "Hello, client!" { t.Errorf("Expected answer 'Hello, client!', got %s\n", string(answer)) }