diff --git a/cart-grain.go b/cart-grain.go index 50ab0a0..67b1814 100644 --- a/cart-grain.go +++ b/cart-grain.go @@ -13,6 +13,20 @@ import ( type CartId [16]byte +func (id CartId) MarshalJSON() ([]byte, error) { + return json.Marshal(id.String()) +} + +func (id *CartId) UnmarshalJSON(data []byte) error { + var str string + err := json.Unmarshal(data, &str) + if err != nil { + return err + } + copy(id[:], []byte(str)) + return nil +} + type CartItem struct { Sku string `json:"sku"` Name string `json:"name"` diff --git a/data/1.prot b/data/1.prot new file mode 100644 index 0000000..9557f27 Binary files /dev/null and b/data/1.prot differ diff --git a/data/4.prot b/data/4.prot new file mode 100644 index 0000000..927516f Binary files /dev/null and b/data/4.prot differ diff --git a/data/state.gob b/data/state.gob new file mode 100644 index 0000000..b481792 Binary files /dev/null and b/data/state.gob differ diff --git a/data/state.gob.bak b/data/state.gob.bak new file mode 100644 index 0000000..f590136 Binary files /dev/null and b/data/state.gob.bak differ diff --git a/disk-storage.go b/disk-storage.go index d230ec7..ec26330 100644 --- a/disk-storage.go +++ b/disk-storage.go @@ -1,7 +1,7 @@ package main import ( - "encoding/json" + "encoding/gob" "errors" "fmt" "log" @@ -84,7 +84,7 @@ func (s *DiskStorage) saveState() error { return err } defer file.Close() - err = json.NewEncoder(file).Encode(s.LastSaves) + err = gob.NewEncoder(file).Encode(s.LastSaves) if err != nil { return err } @@ -94,12 +94,12 @@ func (s *DiskStorage) saveState() error { } func (s *DiskStorage) loadState() error { - file, err := os.Open("data/state.json") + file, err := os.Open(s.stateFile) if err != nil { return err } defer file.Close() - return json.NewDecoder(file).Decode(&s.LastSaves) + return gob.NewDecoder(file).Decode(&s.LastSaves) } func (s *DiskStorage) Store(id CartId, grain *CartGrain) error { diff --git a/grain-pool.go b/grain-pool.go index b331948..448af04 100644 --- a/grain-pool.go +++ b/grain-pool.go @@ -65,7 +65,7 @@ func (p *GrainLocalPool) GetGrains() map[CartId]*CartGrain { return p.grains } -func (p *GrainLocalPool) Process(id CartId, messages ...Message) (interface{}, error) { +func (p *GrainLocalPool) GetGrain(id CartId) (*CartGrain, error) { var err error grain, ok := p.grains[id] if !ok { @@ -81,6 +81,11 @@ func (p *GrainLocalPool) Process(id CartId, messages ...Message) (interface{}, e p.grains[id] = grain } + return grain, err +} + +func (p *GrainLocalPool) Process(id CartId, messages ...Message) (interface{}, error) { + grain, err := p.GetGrain(id) if err == nil && grain != nil { for _, message := range messages { _, err = grain.HandleMessage(&message, false) @@ -90,9 +95,5 @@ func (p *GrainLocalPool) Process(id CartId, messages ...Message) (interface{}, e } func (p *GrainLocalPool) Get(id CartId) (Grain, error) { - grain, ok := p.grains[id] - if !ok { - return nil, fmt.Errorf("grain not found") - } - return grain, nil + return p.GetGrain(id) } diff --git a/main.go b/main.go index ee351e9..0bb6c56 100644 --- a/main.go +++ b/main.go @@ -82,7 +82,7 @@ func (a *App) HandleSave(w http.ResponseWriter, r *http.Request) { func main() { // Create a new instance of the server - storage, err := NewDiskStorage("data/state.json") + storage, err := NewDiskStorage("data/state.gob") if err != nil { log.Printf("Error loading state: %v\n", err) } @@ -102,7 +102,7 @@ func main() { mux := http.NewServeMux() mux.HandleFunc("GET /api/{id}", app.HandleGet) mux.HandleFunc("GET /api/{id}/add/{sku}", app.HandleAddSku) - mux.HandleFunc("GET /remote/{id}", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("GET /remote/{id}/add", func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") ts := time.Now().Unix() data, err := remotePool.Process(ToCartId(id), Message{ @@ -117,7 +117,19 @@ func main() { } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(data) + w.Write(data) + }) + mux.HandleFunc("GET /remote/{id}", func(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + data, err := remotePool.Get(ToCartId(id)) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(err.Error())) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(data) }) mux.HandleFunc("GET /save", app.HandleSave) http.ListenAndServe(":8080", mux) diff --git a/message.go b/message.go index d38ab27..fc40015 100644 --- a/message.go +++ b/message.go @@ -42,42 +42,6 @@ func GetData(fn func(w io.Writer) error) ([]byte, error) { return b, nil } -// func (w *MessageWriter) WriteUint64(value uint64) (int, error) { -// bytes := make([]byte, 8) -// binary.LittleEndian.PutUint64(bytes, value) -// return w.Write(bytes) -// } - -// func (w *MessageWriter) WriteInt64(value int64) (int, error) { -// return w.WriteUint64(uint64(value)) -// } - -// func (w *MessageWriter) WriteMessage(m *Message) (int, error) { -// var i, l int -// var err error -// i, err = w.WriteUint64(m.Type) -// l += i -// i, err = w.WriteInt64(*m.TimeStamp) -// l += i -// var messageBytes []byte -// var err error -// if m.Type == AddRequestType { -// messageBytes, err = proto.Marshal(m.Content.(*messages.AddRequest)) -// } else if m.Type == AddItemType { -// messageBytes, err = proto.Marshal(m.Content.(*messages.AddItem)) -// } else { -// return fmt.Errorf("unknown message type") -// } -// if err != nil { -// return err -// } -// if err := w.WriteUint64(uint64(len(messageBytes))); err != nil { -// return err -// } -// _, err = w.Write(messageBytes) -// return err -// } - func (m Message) Write(w io.Writer) error { data, err := GetData(func(wr io.Writer) error { if m.Type == AddRequestType { @@ -140,6 +104,8 @@ func MessageFromReader(reader io.Reader, m *Message) error { if err != nil { return err } + m.Type = header.Type + m.TimeStamp = &header.TimeStamp return nil } diff --git a/packet.go b/packet.go new file mode 100644 index 0000000..9d047fa --- /dev/null +++ b/packet.go @@ -0,0 +1,76 @@ +package main + +import ( + "encoding/binary" + "io" +) + +const ( + RemoteGetState = uint16(0x01) + RemoteHandleMessage = uint16(0x02) + ResponseBody = uint16(0x03) +) + +type CartPacket struct { + Version uint16 + MessageType uint16 + Id CartId + DataLength uint16 +} + +type ResponsePacket struct { + Version uint16 + MessageType uint16 + DataLength uint16 +} + +func SendCartPacket(conn io.Writer, id CartId, messageType uint16, datafn func(w io.Writer) error) error { + data, err := GetData(datafn) + if err != nil { + return err + } + binary.Write(conn, binary.LittleEndian, CartPacket{ + Version: 2, + MessageType: messageType, + Id: id, + DataLength: uint16(len(data)), + }) + _, err = conn.Write(data) + return err +} + +func SendPacket(conn io.Writer, messageType uint16, datafn func(w io.Writer) error) error { + data, err := GetData(datafn) + if err != nil { + return err + } + binary.Write(conn, binary.LittleEndian, ResponsePacket{ + Version: 1, + MessageType: messageType, + DataLength: uint16(len(data)), + }) + _, err = conn.Write(data) + return err +} + +// func ReceiveCartPacket(conn io.Reader) (CartPacket, []byte, error) { +// var packet CartPacket +// err := binary.Read(conn, binary.LittleEndian, &packet) +// if err != nil { +// return packet, nil, err +// } +// data := make([]byte, packet.DataLength) +// _, err = conn.Read(data) +// return packet, data, err +// } + +func ReceivePacket(conn io.Reader) (uint16, []byte, error) { + var packet ResponsePacket + err := binary.Read(conn, binary.LittleEndian, &packet) + if err != nil { + return packet.MessageType, nil, err + } + data := make([]byte, packet.DataLength) + _, err = conn.Read(data) + return packet.MessageType, data, err +} diff --git a/rpc-pool.go b/rpc-pool.go index 9232418..fd00d89 100644 --- a/rpc-pool.go +++ b/rpc-pool.go @@ -1,13 +1,9 @@ package main import ( - "encoding/binary" + "io" "net" -) - -const ( - RemoteGetState = uint16(0x01) - RemoteHandleMessage = uint16(0x02) + "strings" ) type RemoteGrainPool struct { @@ -16,7 +12,7 @@ type RemoteGrainPool struct { } func (id CartId) String() string { - return string(id[:]) + return strings.Trim(string(id[:]), "\x00") } func ToCartId(id string) CartId { @@ -49,46 +45,30 @@ func (g *RemoteGrain) Connect() error { return nil } -type Packet struct { - Version uint16 - MessageType uint16 - Id CartId - DataLength uint16 -} - -func (g *RemoteGrain) SendPacket(messageType uint16, data []byte) error { - binary.Write(g.client, binary.LittleEndian, Packet{ - Version: 2, - MessageType: messageType, - Id: g.Id, - DataLength: uint16(len(data)), - }) - return binary.Write(g.client, binary.LittleEndian, data) -} - func (g *RemoteGrain) HandleMessage(message *Message, isReplay bool) ([]byte, error) { - data, err := GetData(message.Write) + + err := SendCartPacket(g.client, g.Id, RemoteHandleMessage, message.Write) if err != nil { return nil, err } - err = g.SendPacket(RemoteHandleMessage, data) - result := make([]byte, 65535) - g.client.Read(result) - return result, err + _, data, err := ReceivePacket(g.client) + return data, err } func (g *RemoteGrain) GetId() CartId { return g.Id } -func (g *RemoteGrain) GetCurrentState() (Grain, error) { +func (g *RemoteGrain) GetCurrentState() ([]byte, error) { - var reply CartGrain - err := g.SendPacket(RemoteGetState, nil) + err := SendCartPacket(g.client, g.Id, RemoteGetState, func(w io.Writer) error { + return nil + }) if err != nil { return nil, err } - return &reply, err + _, data, err := ReceivePacket(g.client) + return data, err } func NewRemoteGrainPool(addr ...string) *RemoteGrainPool { @@ -106,8 +86,8 @@ func (p *RemoteGrainPool) findRemoteGrain(id CartId) *RemoteGrain { return &grain } -func (p *RemoteGrainPool) Process(id CartId, messages ...Message) (interface{}, error) { - var result interface{} +func (p *RemoteGrainPool) Process(id CartId, messages ...Message) ([]byte, error) { + var result []byte var err error grain := p.findRemoteGrain(id) if grain == nil { @@ -121,7 +101,7 @@ func (p *RemoteGrainPool) Process(id CartId, messages ...Message) (interface{}, return result, err } -func (p *RemoteGrainPool) Get(id CartId) (Grain, error) { +func (p *RemoteGrainPool) Get(id CartId) ([]byte, error) { grain := p.findRemoteGrain(id) if grain == nil { return nil, nil diff --git a/rpc-server.go b/rpc-server.go index 2d479ec..fbb38c0 100644 --- a/rpc-server.go +++ b/rpc-server.go @@ -2,6 +2,7 @@ package main import ( "encoding/binary" + "encoding/json" "fmt" "io" "net" @@ -49,7 +50,7 @@ func (h *GrainHandler) handleClient(conn net.Conn) { fmt.Println("Handling client connection") defer conn.Close() - var packet Packet + var packet CartPacket for { for { @@ -74,10 +75,37 @@ func (h *GrainHandler) handleClient(conn net.Conn) { fmt.Println("Error reading message:", err) } fmt.Printf("Message: %s, %v\n", packet.Id.String(), msg) + grain, err := h.pool.Get(packet.Id) + if err != nil { + fmt.Println("Error getting grain:", err) + } + _, err = grain.HandleMessage(&msg, false) + if err != nil { + fmt.Println("Error handling message:", err) + } + SendPacket(conn, ResponseBody, func(w io.Writer) error { + data, err := json.Marshal(grain) + if err != nil { + return err + } + w.Write(data) + return nil + }) case RemoteGetState: fmt.Printf("Package: %s %v\n", packet.Id.String(), packet) - + grain, err := h.pool.Get(packet.Id) + if err != nil { + fmt.Println("Error getting grain:", err) + } + SendPacket(conn, ResponseBody, func(w io.Writer) error { + data, err := json.Marshal(grain) + if err != nil { + return err + } + w.Write(data) + return nil + }) } }