diff --git a/cart-grain.go b/cart-grain.go index 817a97b..c33c899 100644 --- a/cart-grain.go +++ b/cart-grain.go @@ -3,7 +3,6 @@ package main import ( "encoding/json" "fmt" - "log" "time" messages "git.tornberg.me/go-cart-actor/proto" @@ -63,7 +62,6 @@ func getItemData(sku string) (*messages.AddItem, error) { price := 0 priceField, ok := item.Fields[4] if ok { - priceFloat, ok := priceField.(float64) if !ok { price, ok = priceField.(int) @@ -110,7 +108,6 @@ func (c *CartGrain) GetStorageMessage(since int64) []StorableMessage { } func (c *CartGrain) HandleMessage(message *Message, isReplay bool) ([]byte, error) { - log.Printf("Handling message %d", message.Type) if message.TimeStamp == nil { now := time.Now().Unix() message.TimeStamp = &now @@ -120,14 +117,14 @@ func (c *CartGrain) HandleMessage(message *Message, isReplay bool) ([]byte, erro case AddRequestType: msg, ok := message.Content.(*messages.AddRequest) if !ok { - err = fmt.Errorf("invalid content type") + err = fmt.Errorf("expected AddRequest") } else { return c.AddItem(msg.Sku) } case AddItemType: msg, ok := message.Content.(*messages.AddItem) if !ok { - err = fmt.Errorf("invalid content type") + err = fmt.Errorf("expected AddItem") } else { c.Items = append(c.Items, CartItem{ Sku: msg.Sku, @@ -138,7 +135,7 @@ func (c *CartGrain) HandleMessage(message *Message, isReplay bool) ([]byte, erro c.TotalPrice += msg.Price } default: - err = fmt.Errorf("unknown message type") + err = fmt.Errorf("unknown message type %d", message.Type) } if err != nil { return nil, err diff --git a/data/5.prot b/data/5.prot new file mode 100644 index 0000000..d45fb22 Binary files /dev/null and b/data/5.prot differ diff --git a/data/state.gob b/data/state.gob index b481792..f1608e0 100644 Binary files a/data/state.gob and b/data/state.gob differ diff --git a/data/state.gob.bak b/data/state.gob.bak index f590136..b481792 100644 Binary files a/data/state.gob.bak and b/data/state.gob.bak differ diff --git a/disk-storage.go b/disk-storage.go index ec26330..f1b411c 100644 --- a/disk-storage.go +++ b/disk-storage.go @@ -65,7 +65,7 @@ func loadMessages(grain Grain, id CartId) error { for err == nil { var msg Message - err = MessageFromReader(file, &msg) + err = ReadMessage(file, &msg) if err == nil { grain.HandleMessage(&msg, true) } diff --git a/grain-pool.go b/grain-pool.go index 6d43338..56f22d2 100644 --- a/grain-pool.go +++ b/grain-pool.go @@ -41,6 +41,18 @@ func NewGrainLocalPool(size int, ttl time.Duration, spawn func(id CartId) (*Cart return ret } +func (p *GrainLocalPool) SetAvailable(availableWithLastChangeUnix map[CartId]int64) { + for id := range availableWithLastChangeUnix { + if _, ok := p.grains[id]; !ok { + p.grains[id] = nil + p.expiry = append(p.expiry, Ttl{ + Expires: time.Now().Add(p.Ttl), + Grain: nil, + }) + } + } +} + func (p *GrainLocalPool) Purge() { lastChangeTime := time.Now().Add(-p.Ttl) keepChanged := lastChangeTime.Unix() @@ -69,7 +81,7 @@ func (p *GrainLocalPool) GetGrains() map[CartId]*CartGrain { func (p *GrainLocalPool) GetGrain(id CartId) (*CartGrain, error) { var err error grain, ok := p.grains[id] - if !ok { + if grain == nil || !ok { if len(p.grains) >= p.PoolSize { if p.expiry[0].Expires.Before(time.Now()) { delete(p.grains, p.expiry[0].Grain.GetId()) diff --git a/main.go b/main.go index 0bb6c56..50889b2 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,6 @@ package main import ( - "encoding/json" "log" "net/http" "os" @@ -30,38 +29,11 @@ type App struct { storage *DiskStorage } -func (a *App) HandleGet(w http.ResponseWriter, r *http.Request) { - id := r.PathValue("id") - grain, err := a.pool.Get(ToCartId(id)) - if err != nil { - w.WriteHeader(http.StatusNotFound) - return - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(grain) -} - -func (a *App) HandleAddSku(w http.ResponseWriter, r *http.Request) { - id := r.PathValue("id") - sku := r.PathValue("sku") - grain, err := a.pool.Process(ToCartId(id), Message{ - Type: AddRequestType, - Content: &messages.AddRequest{Sku: sku}, - }) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte(err.Error())) - return - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(grain) -} - func (a *App) Save() error { for id, grain := range a.pool.GetGrains() { + if grain == nil { + continue + } err := a.storage.Store(id, grain) if err != nil { log.Printf("Error saving grain %s: %v\n", id, err) @@ -80,6 +52,53 @@ func (a *App) HandleSave(w http.ResponseWriter, r *http.Request) { } } +type PoolServer struct { + pool GrainPool +} + +func NewPoolServer(pool GrainPool) *PoolServer { + return &PoolServer{ + pool: pool, + } +} + +func (s *PoolServer) HandleGet(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + data, err := s.pool.Get(ToCartId(id)) + if err != nil { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte(err.Error())) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(data) +} + +func (s *PoolServer) HandleAddSku(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + sku := r.PathValue("sku") + data, err := s.pool.Process(ToCartId(id), Message{ + Type: AddRequestType, + Content: &messages.AddRequest{Sku: sku}, + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(err.Error())) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(data) +} + +func (s *PoolServer) Serve() *http.ServeMux { + mux := http.NewServeMux() + mux.HandleFunc("GET /{id}", s.HandleGet) + mux.HandleFunc("GET /{id}/add/{sku}", s.HandleAddSku) + return mux +} + func main() { // Create a new instance of the server storage, err := NewDiskStorage("data/state.gob") @@ -98,39 +117,11 @@ func main() { go rpcHandler.Serve() remotePool := NewRemoteGrainPool("localhost:1337") - + remoteServer := NewPoolServer(remotePool) + localServer := NewPoolServer(app.pool) mux := http.NewServeMux() - mux.HandleFunc("GET /api/{id}", app.HandleGet) - mux.HandleFunc("GET /api/{id}/add/{sku}", app.HandleAddSku) - mux.HandleFunc("GET /remote/{id}/add", func(w http.ResponseWriter, r *http.Request) { - id := r.PathValue("id") - ts := time.Now().Unix() - data, err := remotePool.Process(ToCartId(id), Message{ - Type: AddRequestType, - TimeStamp: &ts, - Content: &messages.AddRequest{Sku: "49565"}, - }) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte(err.Error())) - return - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write(data) - }) - mux.HandleFunc("GET /remote/{id}", func(w http.ResponseWriter, r *http.Request) { - id := r.PathValue("id") - data, err := remotePool.Get(ToCartId(id)) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte(err.Error())) - return - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write(data) - }) + mux.Handle("/remote/", http.StripPrefix("/remote", remoteServer.Serve())) + mux.Handle("/local/", http.StripPrefix("/local", localServer.Serve())) mux.HandleFunc("GET /save", app.HandleSave) http.ListenAndServe(":8080", mux) diff --git a/message.go b/message.go index fc40015..c143f11 100644 --- a/message.go +++ b/message.go @@ -78,7 +78,7 @@ func (m Message) Write(w io.Writer) error { return err } -func MessageFromReader(reader io.Reader, m *Message) error { +func ReadMessage(reader io.Reader, m *Message) error { header := StorableMessageHeader{} err := binary.Read(reader, binary.LittleEndian, &header) if err != nil { diff --git a/packet.go b/packet.go index b539821..fbb78ea 100644 --- a/packet.go +++ b/packet.go @@ -54,6 +54,16 @@ func SendPacket(conn io.Writer, messageType uint16, datafn func(w io.Writer) err return err } +func SendRawResponse(conn io.Writer, data []byte) error { + binary.Write(conn, binary.LittleEndian, ResponsePacket{ + Version: 1, + MessageType: ResponseBody, + DataLength: uint16(len(data)), + }) + _, err := conn.Write(data) + return err +} + func SendProxyResponse(conn io.Writer, data any) error { return SendPacket(conn, ResponseBody, func(w io.Writer) error { data, err := json.Marshal(data) @@ -65,17 +75,6 @@ func SendProxyResponse(conn io.Writer, data any) error { }) } -// func ReceiveCartPacket(conn io.Reader) (CartPacket, []byte, error) { -// var packet CartPacket -// err := binary.Read(conn, binary.LittleEndian, &packet) -// if err != nil { -// return packet, nil, err -// } -// data := make([]byte, packet.DataLength) -// _, err = conn.Read(data) -// return packet, data, err -// } - func ReceivePacket(conn io.Reader) (uint16, []byte, error) { var packet ResponsePacket err := binary.Read(conn, binary.LittleEndian, &packet) diff --git a/rpc-pool.go b/rpc-pool.go index fd00d89..eb61bf2 100644 --- a/rpc-pool.go +++ b/rpc-pool.go @@ -1,13 +1,14 @@ package main import ( + "fmt" "io" "net" "strings" ) type RemoteGrainPool struct { - Hosts []string + Host string grains map[CartId]RemoteGrain } @@ -46,7 +47,6 @@ func (g *RemoteGrain) Connect() error { } func (g *RemoteGrain) HandleMessage(message *Message, isReplay bool) ([]byte, error) { - err := SendCartPacket(g.client, g.Id, RemoteHandleMessage, message.Write) if err != nil { return nil, err @@ -60,7 +60,6 @@ func (g *RemoteGrain) GetId() CartId { } func (g *RemoteGrain) GetCurrentState() ([]byte, error) { - err := SendCartPacket(g.client, g.Id, RemoteGetState, func(w io.Writer) error { return nil }) @@ -71,9 +70,9 @@ func (g *RemoteGrain) GetCurrentState() ([]byte, error) { return data, err } -func NewRemoteGrainPool(addr ...string) *RemoteGrainPool { +func NewRemoteGrainPool(addr string) *RemoteGrainPool { return &RemoteGrainPool{ - Hosts: addr, + Host: addr, grains: make(map[CartId]RemoteGrain), } } @@ -83,17 +82,26 @@ func (p *RemoteGrainPool) findRemoteGrain(id CartId) *RemoteGrain { if !ok { return nil } + grain.Connect() return &grain } +func (p *RemoteGrainPool) findOrCreateGrain(id CartId) *RemoteGrain { + grain := p.findRemoteGrain(id) + if grain == nil { + grain = NewRemoteGrain(id, p.Host) + p.grains[id] = *grain + grain.Connect() + } + return grain +} + func (p *RemoteGrainPool) Process(id CartId, messages ...Message) ([]byte, error) { var result []byte var err error - grain := p.findRemoteGrain(id) + grain := p.findOrCreateGrain(id) if grain == nil { - grain = NewRemoteGrain(id, p.Hosts[0]) - grain.Connect() - p.grains[id] = *grain + return nil, fmt.Errorf("grain not found") } for _, message := range messages { result, err = grain.HandleMessage(&message, false) @@ -102,9 +110,9 @@ func (p *RemoteGrainPool) Process(id CartId, messages ...Message) ([]byte, error } func (p *RemoteGrainPool) Get(id CartId) ([]byte, error) { - grain := p.findRemoteGrain(id) + grain := p.findOrCreateGrain(id) if grain == nil { - return nil, nil + return nil, fmt.Errorf("grain not found") } return grain.GetCurrentState() } diff --git a/rpc-server.go b/rpc-server.go index dfa0899..5ea341c 100644 --- a/rpc-server.go +++ b/rpc-server.go @@ -64,26 +64,24 @@ func (h *GrainHandler) handleClient(conn net.Conn) { switch packet.MessageType { case RemoteHandleMessage: - fmt.Printf("Handling message\n") var msg Message - err = MessageFromReader(conn, &msg) + err = ReadMessage(conn, &msg) if err != nil { fmt.Println("Error reading message:", err) } - fmt.Printf("Message: %s, %v\n", packet.Id.String(), msg) - grain, err := h.pool.Process(packet.Id, msg) + + data, err := h.pool.Process(packet.Id, msg) if err != nil { fmt.Println("Error handling message:", err) } - SendProxyResponse(conn, grain) + SendRawResponse(conn, data) case RemoteGetState: - fmt.Printf("Package: %s %v\n", packet.Id.String(), packet) - grain, err := h.pool.Get(packet.Id) + data, err := h.pool.Get(packet.Id) if err != nil { fmt.Println("Error getting grain:", err) } - SendProxyResponse(conn, grain) + SendRawResponse(conn, data) } } diff --git a/synced-pool.go b/synced-pool.go new file mode 100644 index 0000000..fccbc50 --- /dev/null +++ b/synced-pool.go @@ -0,0 +1,46 @@ +package main + +type SyncedPool struct { + local *GrainLocalPool + remotes []RemoteGrainPool + remoteIndex map[CartId]*RemoteGrainPool +} + +func NewSyncedPool(local *GrainLocalPool) *SyncedPool { + return &SyncedPool{ + local: local, + remotes: make([]RemoteGrainPool, 0), + remoteIndex: make(map[CartId]*RemoteGrainPool), + } +} + +func (p *SyncedPool) AddRemote(remote RemoteGrainPool) { + p.remotes = append(p.remotes, remote) + // get all available grains from remote, and start syncing +} + +func (p *SyncedPool) Process(id CartId, messages ...Message) ([]byte, error) { + // check if local grain exists + _, ok := p.local.grains[id] + if !ok { + // check if remote grain exists + remoteGrain, ok := p.remoteIndex[id] + if ok { + return remoteGrain.Process(id, messages...) + } + } + return p.local.Process(id, messages...) +} + +func (p *SyncedPool) Get(id CartId) ([]byte, error) { + // check if local grain exists + _, ok := p.local.grains[id] + if !ok { + // check if remote grain exists + remoteGrain, ok := p.remoteIndex[id] + if ok { + return remoteGrain.Get(id) + } + } + return p.local.Get(id) +}