diff --git a/cmd/inventory/main.go b/cmd/inventory/main.go index ad97b91..2747974 100644 --- a/cmd/inventory/main.go +++ b/cmd/inventory/main.go @@ -2,8 +2,9 @@ package main import ( "context" - "fmt" + "log" + "git.tornberg.me/go-cart-actor/pkg/inventory" "github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9/maintnotifications" ) @@ -18,24 +19,36 @@ func main() { Mode: maintnotifications.ModeDisabled, }, }) - - err := rdb.Set(ctx, "key", "value", 0).Err() + s, err := inventory.NewRedisInventoryService(rdb, ctx) if err != nil { - panic(err) + log.Fatalf("Unable to connect to inventory redis", err) + return } - - val, err := rdb.Get(ctx, "key").Result() + rdb.Pipelined(ctx, func(p redis.Pipeliner) error { + s.UpdateInventory(p, "1", "1", 10) + s.UpdateInventory(p, "2", "2", 20) + s.UpdateInventory(p, "3", "3", 30) + s.UpdateInventory(p, "4", "4", 40) + return nil + }) + err = s.ReserveInventory(inventory.ReserveRequest{ + SKU: "1", + LocationID: "1", + Quantity: 3, + }, inventory.ReserveRequest{ + SKU: "2", + LocationID: "2", + Quantity: 15, + }) if err != nil { - panic(err) + log.Printf("Unable to reserve inventory: %v", err) + return } - fmt.Println("key", val) + v, err := s.GetInventory("1", "1") + if err != nil { + log.Printf("Unable to get inventory: %v", err) + return + } + log.Printf("Inventory after reservation: %v", v) - val2, err := rdb.Get(ctx, "key2").Result() - if err == redis.Nil { - fmt.Println("key2 does not exist") - } else if err != nil { - panic(err) - } else { - fmt.Println("key2", val2) - } } diff --git a/pkg/inventory/redis_service.go b/pkg/inventory/redis_service.go index 0d8e944..24874e9 100644 --- a/pkg/inventory/redis_service.go +++ b/pkg/inventory/redis_service.go @@ -2,8 +2,9 @@ package inventory import ( "context" - "encoding/json" "errors" + "fmt" + "strconv" "github.com/redis/go-redis/v9" ) @@ -56,52 +57,124 @@ func (s *RedisInventoryService) AddWarehouse(warehouse *Warehouse) error { return err } -func (s *RedisInventoryService) GetInventory(sku SKU, locationID LocationID) (*InventoryItem, error) { - // Get the warehouse from Redis - key := "warehouse:" + string(locationID) - result, err := s.client.HGetAll(s.ctx, key).Result() +func (s *RedisInventoryService) GetInventory(sku SKU, locationID LocationID) (int64, error) { + + cmd := s.client.Get(s.ctx, getInventoryKey(sku, locationID)) + if err := cmd.Err(); err != nil { + return 0, err + } + + i, err := cmd.Int64() if err != nil { - return nil, err + return 0, err } - // Parse the inventory items - var inventoryItems []InventoryItem - for _, itemData := range result { - var item InventoryItem - if err := json.Unmarshal([]byte(itemData), &item); err == nil { - inventoryItems = append(inventoryItems, item) - } - } - - // Find the requested SKU - for _, item := range inventoryItems { - if item.SKU == sku { - return &item, nil - } - } - - return nil, errors.New("sku not found in warehouse") + return i, nil } -func (s *RedisInventoryService) ReserveInventory(req ReserveRequest) error { +func getInventoryKey(sku SKU, locationID LocationID) string { + return fmt.Sprintf("inventory:%s:%s", sku, locationID) +} + +func (s *RedisInventoryService) UpdateInventory(rdb redis.Pipeliner, sku SKU, locationID LocationID, quantity int64) error { + key := getInventoryKey(sku, locationID) + cmd := rdb.Set(s.ctx, key, quantity, 0) + return cmd.Err() +} + +var ( + ErrInsufficientInventory = errors.New("insufficient inventory") + ErrInvalidQuantity = errors.New("invalid quantity") + ErrMissingReservation = errors.New("missing reservation") +) + +func (s *RedisInventoryService) ReserveInventory(req ...ReserveRequest) error { + if len(req) == 0 { + return ErrMissingReservation + } + + keys := make([]string, len(req)) + args := make([]string, len(req)) + for i, r := range req { + if r.Quantity <= 0 { + return ErrInvalidQuantity + } + keys[i] = getInventoryKey(r.SKU, r.LocationID) + args[i] = strconv.Itoa(int(r.Quantity)) + } + cmd := reserveScript.Run(s.ctx, s.client, keys, args) + if err := cmd.Err(); err != nil { + return err + } + if val, err := cmd.Int(); err != nil { + return err + } else if val != 1 { + return ErrInsufficientInventory + } return nil - // Get the Lua script from Redis - // key := "lua:reserve_inventory" - // script, err := s.client.Get(s.ctx, key).Result() - // if err != nil { - // return err - // } - - // luaScript := redis.NewScript(script) - - // // Prepare arguments for the Lua script - // args := []interface{}{ - // string(req.LocationID), - // string(req.SKU), - // req.Quantity, - // } - - // // Execute the Lua script - // cmd := s.client.Eval(s.ctx, luaScript, len(args), args...) - //return err } + +var reserveScript = redis.NewScript(` +-- Get the number of keys passed +local num_keys = #KEYS + +-- Ensure the number of keys matches the number of quantities +if num_keys ~= #ARGV then + return {err = "Script requires the same number of keys and quantities."} +end + +local new_values = {} +local payload = {} + +-- --- +-- 1. CHECK PHASE +-- --- +-- Loop through all keys to check their values first +for i = 1, num_keys do + local key = KEYS[i] + local quantity_to_check = tonumber(ARGV[i]) + + -- Fail if the quantity is not a valid number + if not quantity_to_check then + return {err = "Invalid quantity provided for key: " .. key} + end + + -- Get the current value stored at the key + local current_val = tonumber(redis.call('GET', key)) + + -- Check the condition + -- Fail if: + -- 1. The key doesn't exist (current_val is nil) + -- 2. The value is not > the required quantity + if not current_val or current_val <= quantity_to_check then + -- Return 0 to indicate the operation failed and no changes were made + return 0 + end + + -- If the check passes, store the new value + local new_val = current_val - quantity_to_check + table.insert(new_values, new_val) + + -- Add this key and its *new* value to our payload map + payload[key] = new_val +end + +-- --- +-- 2. UPDATE PHASE +-- --- +-- If the script reaches this point, all checks passed. +-- Now, loop again and apply all the updates. +for i = 1, num_keys do + local key = KEYS[i] + local new_val = new_values[i] + + -- Set the key to its new calculated value + redis.call('SET', key, new_val) +end +local message_payload = cjson.encode(payload) + +-- Publish the JSON-encoded message to the specified channel +redis.call('PUBLISH', "inventory_changed", message_payload) +-- Return 1 to indicate the operation was successful +return 1 +`) diff --git a/pkg/inventory/types.go b/pkg/inventory/types.go index 80ea8be..4ae6be5 100644 --- a/pkg/inventory/types.go +++ b/pkg/inventory/types.go @@ -19,7 +19,7 @@ type Warehouse struct { } type InventoryService interface { - GetInventory(sku SKU, locationID LocationID) (uint32, error) + GetInventory(sku SKU, locationID LocationID) (int64, error) ReserveInventory(req ...ReserveRequest) error }