refactor/serializing #1

Merged
mats merged 4 commits from refactor/serializing into main 2024-11-08 23:02:12 +01:00
12 changed files with 164 additions and 87 deletions
Showing only changes of commit 06ee7b1a27 - Show all commits

View File

@@ -13,6 +13,20 @@ import (
type CartId [16]byte type CartId [16]byte
func (id CartId) MarshalJSON() ([]byte, error) {
return json.Marshal(id.String())
}
func (id *CartId) UnmarshalJSON(data []byte) error {
var str string
err := json.Unmarshal(data, &str)
if err != nil {
return err
}
copy(id[:], []byte(str))
return nil
}
type CartItem struct { type CartItem struct {
Sku string `json:"sku"` Sku string `json:"sku"`
Name string `json:"name"` Name string `json:"name"`

BIN
data/1.prot Normal file

Binary file not shown.

BIN
data/4.prot Normal file

Binary file not shown.

BIN
data/state.gob Normal file

Binary file not shown.

BIN
data/state.gob.bak Normal file

Binary file not shown.

View File

@@ -1,7 +1,7 @@
package main package main
import ( import (
"encoding/json" "encoding/gob"
"errors" "errors"
"fmt" "fmt"
"log" "log"
@@ -84,7 +84,7 @@ func (s *DiskStorage) saveState() error {
return err return err
} }
defer file.Close() defer file.Close()
err = json.NewEncoder(file).Encode(s.LastSaves) err = gob.NewEncoder(file).Encode(s.LastSaves)
if err != nil { if err != nil {
return err return err
} }
@@ -94,12 +94,12 @@ func (s *DiskStorage) saveState() error {
} }
func (s *DiskStorage) loadState() error { func (s *DiskStorage) loadState() error {
file, err := os.Open("data/state.json") file, err := os.Open(s.stateFile)
if err != nil { if err != nil {
return err return err
} }
defer file.Close() defer file.Close()
return json.NewDecoder(file).Decode(&s.LastSaves) return gob.NewDecoder(file).Decode(&s.LastSaves)
} }
func (s *DiskStorage) Store(id CartId, grain *CartGrain) error { func (s *DiskStorage) Store(id CartId, grain *CartGrain) error {

View File

@@ -65,7 +65,7 @@ func (p *GrainLocalPool) GetGrains() map[CartId]*CartGrain {
return p.grains return p.grains
} }
func (p *GrainLocalPool) Process(id CartId, messages ...Message) (interface{}, error) { func (p *GrainLocalPool) GetGrain(id CartId) (*CartGrain, error) {
var err error var err error
grain, ok := p.grains[id] grain, ok := p.grains[id]
if !ok { if !ok {
@@ -81,6 +81,11 @@ func (p *GrainLocalPool) Process(id CartId, messages ...Message) (interface{}, e
p.grains[id] = grain p.grains[id] = grain
} }
return grain, err
}
func (p *GrainLocalPool) Process(id CartId, messages ...Message) (interface{}, error) {
grain, err := p.GetGrain(id)
if err == nil && grain != nil { if err == nil && grain != nil {
for _, message := range messages { for _, message := range messages {
_, err = grain.HandleMessage(&message, false) _, err = grain.HandleMessage(&message, false)
@@ -90,9 +95,5 @@ func (p *GrainLocalPool) Process(id CartId, messages ...Message) (interface{}, e
} }
func (p *GrainLocalPool) Get(id CartId) (Grain, error) { func (p *GrainLocalPool) Get(id CartId) (Grain, error) {
grain, ok := p.grains[id] return p.GetGrain(id)
if !ok {
return nil, fmt.Errorf("grain not found")
}
return grain, nil
} }

18
main.go
View File

@@ -82,7 +82,7 @@ func (a *App) HandleSave(w http.ResponseWriter, r *http.Request) {
func main() { func main() {
// Create a new instance of the server // Create a new instance of the server
storage, err := NewDiskStorage("data/state.json") storage, err := NewDiskStorage("data/state.gob")
if err != nil { if err != nil {
log.Printf("Error loading state: %v\n", err) log.Printf("Error loading state: %v\n", err)
} }
@@ -102,7 +102,7 @@ func main() {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("GET /api/{id}", app.HandleGet) mux.HandleFunc("GET /api/{id}", app.HandleGet)
mux.HandleFunc("GET /api/{id}/add/{sku}", app.HandleAddSku) mux.HandleFunc("GET /api/{id}/add/{sku}", app.HandleAddSku)
mux.HandleFunc("GET /remote/{id}", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("GET /remote/{id}/add", func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id") id := r.PathValue("id")
ts := time.Now().Unix() ts := time.Now().Unix()
data, err := remotePool.Process(ToCartId(id), Message{ data, err := remotePool.Process(ToCartId(id), Message{
@@ -117,7 +117,19 @@ func main() {
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(data) w.Write(data)
})
mux.HandleFunc("GET /remote/{id}", func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
data, err := remotePool.Get(ToCartId(id))
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(err.Error()))
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(data)
}) })
mux.HandleFunc("GET /save", app.HandleSave) mux.HandleFunc("GET /save", app.HandleSave)
http.ListenAndServe(":8080", mux) http.ListenAndServe(":8080", mux)

View File

@@ -42,42 +42,6 @@ func GetData(fn func(w io.Writer) error) ([]byte, error) {
return b, nil return b, nil
} }
// func (w *MessageWriter) WriteUint64(value uint64) (int, error) {
// bytes := make([]byte, 8)
// binary.LittleEndian.PutUint64(bytes, value)
// return w.Write(bytes)
// }
// func (w *MessageWriter) WriteInt64(value int64) (int, error) {
// return w.WriteUint64(uint64(value))
// }
// func (w *MessageWriter) WriteMessage(m *Message) (int, error) {
// var i, l int
// var err error
// i, err = w.WriteUint64(m.Type)
// l += i
// i, err = w.WriteInt64(*m.TimeStamp)
// l += i
// var messageBytes []byte
// var err error
// if m.Type == AddRequestType {
// messageBytes, err = proto.Marshal(m.Content.(*messages.AddRequest))
// } else if m.Type == AddItemType {
// messageBytes, err = proto.Marshal(m.Content.(*messages.AddItem))
// } else {
// return fmt.Errorf("unknown message type")
// }
// if err != nil {
// return err
// }
// if err := w.WriteUint64(uint64(len(messageBytes))); err != nil {
// return err
// }
// _, err = w.Write(messageBytes)
// return err
// }
func (m Message) Write(w io.Writer) error { func (m Message) Write(w io.Writer) error {
data, err := GetData(func(wr io.Writer) error { data, err := GetData(func(wr io.Writer) error {
if m.Type == AddRequestType { if m.Type == AddRequestType {
@@ -140,6 +104,8 @@ func MessageFromReader(reader io.Reader, m *Message) error {
if err != nil { if err != nil {
return err return err
} }
m.Type = header.Type
m.TimeStamp = &header.TimeStamp
return nil return nil
} }

76
packet.go Normal file
View File

@@ -0,0 +1,76 @@
package main
import (
"encoding/binary"
"io"
)
const (
RemoteGetState = uint16(0x01)
RemoteHandleMessage = uint16(0x02)
ResponseBody = uint16(0x03)
)
type CartPacket struct {
Version uint16
MessageType uint16
Id CartId
DataLength uint16
}
type ResponsePacket struct {
Version uint16
MessageType uint16
DataLength uint16
}
func SendCartPacket(conn io.Writer, id CartId, messageType uint16, datafn func(w io.Writer) error) error {
data, err := GetData(datafn)
if err != nil {
return err
}
binary.Write(conn, binary.LittleEndian, CartPacket{
Version: 2,
MessageType: messageType,
Id: id,
DataLength: uint16(len(data)),
})
_, err = conn.Write(data)
return err
}
func SendPacket(conn io.Writer, messageType uint16, datafn func(w io.Writer) error) error {
data, err := GetData(datafn)
if err != nil {
return err
}
binary.Write(conn, binary.LittleEndian, ResponsePacket{
Version: 1,
MessageType: messageType,
DataLength: uint16(len(data)),
})
_, err = conn.Write(data)
return err
}
// func ReceiveCartPacket(conn io.Reader) (CartPacket, []byte, error) {
// var packet CartPacket
// err := binary.Read(conn, binary.LittleEndian, &packet)
// if err != nil {
// return packet, nil, err
// }
// data := make([]byte, packet.DataLength)
// _, err = conn.Read(data)
// return packet, data, err
// }
func ReceivePacket(conn io.Reader) (uint16, []byte, error) {
var packet ResponsePacket
err := binary.Read(conn, binary.LittleEndian, &packet)
if err != nil {
return packet.MessageType, nil, err
}
data := make([]byte, packet.DataLength)
_, err = conn.Read(data)
return packet.MessageType, data, err
}

View File

@@ -1,13 +1,9 @@
package main package main
import ( import (
"encoding/binary" "io"
"net" "net"
) "strings"
const (
RemoteGetState = uint16(0x01)
RemoteHandleMessage = uint16(0x02)
) )
type RemoteGrainPool struct { type RemoteGrainPool struct {
@@ -16,7 +12,7 @@ type RemoteGrainPool struct {
} }
func (id CartId) String() string { func (id CartId) String() string {
return string(id[:]) return strings.Trim(string(id[:]), "\x00")
} }
func ToCartId(id string) CartId { func ToCartId(id string) CartId {
@@ -49,46 +45,30 @@ func (g *RemoteGrain) Connect() error {
return nil return nil
} }
type Packet struct {
Version uint16
MessageType uint16
Id CartId
DataLength uint16
}
func (g *RemoteGrain) SendPacket(messageType uint16, data []byte) error {
binary.Write(g.client, binary.LittleEndian, Packet{
Version: 2,
MessageType: messageType,
Id: g.Id,
DataLength: uint16(len(data)),
})
return binary.Write(g.client, binary.LittleEndian, data)
}
func (g *RemoteGrain) HandleMessage(message *Message, isReplay bool) ([]byte, error) { func (g *RemoteGrain) HandleMessage(message *Message, isReplay bool) ([]byte, error) {
data, err := GetData(message.Write)
err := SendCartPacket(g.client, g.Id, RemoteHandleMessage, message.Write)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = g.SendPacket(RemoteHandleMessage, data) _, data, err := ReceivePacket(g.client)
result := make([]byte, 65535) return data, err
g.client.Read(result)
return result, err
} }
func (g *RemoteGrain) GetId() CartId { func (g *RemoteGrain) GetId() CartId {
return g.Id return g.Id
} }
func (g *RemoteGrain) GetCurrentState() (Grain, error) { func (g *RemoteGrain) GetCurrentState() ([]byte, error) {
var reply CartGrain err := SendCartPacket(g.client, g.Id, RemoteGetState, func(w io.Writer) error {
err := g.SendPacket(RemoteGetState, nil) return nil
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &reply, err _, data, err := ReceivePacket(g.client)
return data, err
} }
func NewRemoteGrainPool(addr ...string) *RemoteGrainPool { func NewRemoteGrainPool(addr ...string) *RemoteGrainPool {
@@ -106,8 +86,8 @@ func (p *RemoteGrainPool) findRemoteGrain(id CartId) *RemoteGrain {
return &grain return &grain
} }
func (p *RemoteGrainPool) Process(id CartId, messages ...Message) (interface{}, error) { func (p *RemoteGrainPool) Process(id CartId, messages ...Message) ([]byte, error) {
var result interface{} var result []byte
var err error var err error
grain := p.findRemoteGrain(id) grain := p.findRemoteGrain(id)
if grain == nil { if grain == nil {
@@ -121,7 +101,7 @@ func (p *RemoteGrainPool) Process(id CartId, messages ...Message) (interface{},
return result, err return result, err
} }
func (p *RemoteGrainPool) Get(id CartId) (Grain, error) { func (p *RemoteGrainPool) Get(id CartId) ([]byte, error) {
grain := p.findRemoteGrain(id) grain := p.findRemoteGrain(id)
if grain == nil { if grain == nil {
return nil, nil return nil, nil

View File

@@ -2,6 +2,7 @@ package main
import ( import (
"encoding/binary" "encoding/binary"
"encoding/json"
"fmt" "fmt"
"io" "io"
"net" "net"
@@ -49,7 +50,7 @@ func (h *GrainHandler) handleClient(conn net.Conn) {
fmt.Println("Handling client connection") fmt.Println("Handling client connection")
defer conn.Close() defer conn.Close()
var packet Packet var packet CartPacket
for { for {
for { for {
@@ -74,10 +75,37 @@ func (h *GrainHandler) handleClient(conn net.Conn) {
fmt.Println("Error reading message:", err) fmt.Println("Error reading message:", err)
} }
fmt.Printf("Message: %s, %v\n", packet.Id.String(), msg) fmt.Printf("Message: %s, %v\n", packet.Id.String(), msg)
grain, err := h.pool.Get(packet.Id)
if err != nil {
fmt.Println("Error getting grain:", err)
}
_, err = grain.HandleMessage(&msg, false)
if err != nil {
fmt.Println("Error handling message:", err)
}
SendPacket(conn, ResponseBody, func(w io.Writer) error {
data, err := json.Marshal(grain)
if err != nil {
return err
}
w.Write(data)
return nil
})
case RemoteGetState: case RemoteGetState:
fmt.Printf("Package: %s %v\n", packet.Id.String(), packet) fmt.Printf("Package: %s %v\n", packet.Id.String(), packet)
grain, err := h.pool.Get(packet.Id)
if err != nil {
fmt.Println("Error getting grain:", err)
}
SendPacket(conn, ResponseBody, func(w io.Writer) error {
data, err := json.Marshal(grain)
if err != nil {
return err
}
w.Write(data)
return nil
})
} }
} }