252 lines
6.1 KiB
Go
252 lines
6.1 KiB
Go
package inventory
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"strconv"
|
|
|
|
"github.com/redis/go-redis/v9"
|
|
)
|
|
|
|
type RedisInventoryService struct {
|
|
client *redis.Client
|
|
ctx context.Context
|
|
luaScripts map[string]*redis.Script
|
|
}
|
|
|
|
func NewRedisInventoryService(client *redis.Client, ctx context.Context) (*RedisInventoryService, error) {
|
|
rdb := client
|
|
|
|
// Ping Redis to check connection
|
|
_, err := rdb.Ping(ctx).Result()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &RedisInventoryService{
|
|
client: rdb,
|
|
ctx: ctx,
|
|
luaScripts: make(map[string]*redis.Script),
|
|
}, nil
|
|
}
|
|
|
|
func (s *RedisInventoryService) LoadLuaScript(key string) error {
|
|
// Get the script from Redis
|
|
script, err := s.client.Get(s.ctx, key).Result()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Load the script into the luaScripts cache
|
|
s.luaScripts[key] = redis.NewScript(script)
|
|
return nil
|
|
}
|
|
|
|
func (s *RedisInventoryService) AddWarehouse(warehouse *Warehouse) error {
|
|
// Convert warehouse to Redis-friendly format
|
|
data := map[string]interface{}{
|
|
"id": string(warehouse.ID),
|
|
"name": warehouse.Name,
|
|
"inventory": warehouse.Inventory,
|
|
}
|
|
|
|
// Store in Redis with a key pattern like "warehouse:<ID>"
|
|
key := "warehouse:" + string(warehouse.ID)
|
|
_, err := s.client.HMSet(s.ctx, key, data).Result()
|
|
return err
|
|
}
|
|
|
|
func (s *RedisInventoryService) GetInventory(sku SKU, locationID LocationID) (int64, error) {
|
|
|
|
cmd := s.client.Get(s.ctx, getInventoryKey(sku, locationID))
|
|
if err := cmd.Err(); err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
i, err := cmd.Int64()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return i, nil
|
|
}
|
|
|
|
func getInventoryKey(sku SKU, locationID LocationID) string {
|
|
return fmt.Sprintf("inventory:%s:%s", sku, locationID)
|
|
}
|
|
|
|
func (s *RedisInventoryService) UpdateInventory(rdb redis.Pipeliner, sku SKU, locationID LocationID, quantity int64) error {
|
|
key := getInventoryKey(sku, locationID)
|
|
cmd := rdb.Set(s.ctx, key, quantity, 0)
|
|
return cmd.Err()
|
|
}
|
|
|
|
var (
|
|
ErrInsufficientInventory = errors.New("insufficient inventory")
|
|
ErrInvalidQuantity = errors.New("invalid quantity")
|
|
ErrMissingReservation = errors.New("missing reservation")
|
|
)
|
|
|
|
func makeKeysAndArgs(req ...ReserveRequest) ([]string, []string) {
|
|
keys := make([]string, len(req))
|
|
args := make([]string, len(req))
|
|
for i, r := range req {
|
|
if r.Quantity <= 0 {
|
|
return nil, nil
|
|
}
|
|
keys[i] = getInventoryKey(r.SKU, r.LocationID)
|
|
args[i] = strconv.Itoa(int(r.Quantity))
|
|
}
|
|
return keys, args
|
|
}
|
|
|
|
func (s *RedisInventoryService) ReservationCheck(req ...ReserveRequest) error {
|
|
if len(req) == 0 {
|
|
return ErrMissingReservation
|
|
}
|
|
|
|
keys, args := makeKeysAndArgs(req...)
|
|
if keys == nil || args == nil {
|
|
return ErrInvalidQuantity
|
|
}
|
|
|
|
cmd := reserveScript.Run(s.ctx, s.client, keys, args)
|
|
if err := cmd.Err(); err != nil {
|
|
return err
|
|
}
|
|
if val, err := cmd.Int(); err != nil {
|
|
return err
|
|
} else if val != 1 {
|
|
return ErrInsufficientInventory
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *RedisInventoryService) ReserveInventory(req ...ReserveRequest) error {
|
|
if len(req) == 0 {
|
|
return ErrMissingReservation
|
|
}
|
|
|
|
keys, args := makeKeysAndArgs(req...)
|
|
if keys == nil || args == nil {
|
|
return ErrInvalidQuantity
|
|
}
|
|
cmd := reserveScript.Run(s.ctx, s.client, keys, args)
|
|
if err := cmd.Err(); err != nil {
|
|
return err
|
|
}
|
|
if val, err := cmd.Int(); err != nil {
|
|
return err
|
|
} else if val != 1 {
|
|
return ErrInsufficientInventory
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// reservationCheck is a read-only Lua script: for KEYS[i] / ARGV[i] pairs it
// returns 1 when every key holds strictly more than the quantity requested
// from it, and 0 otherwise. Quantities for a key that appears more than once
// are accumulated, so the combined demand on that key is what gets checked.
var reservationCheck = redis.NewScript(`
-- KEYS[i] is an inventory counter and ARGV[i] the quantity requested from it.
local num_keys = #KEYS

-- Ensure the number of keys matches the number of quantities
if num_keys ~= #ARGV then
  return {err = "Script requires the same number of keys and quantities."}
end

-- Stock still available per key after the lines seen so far. Tracking it per
-- key (instead of re-reading GET) makes repeated keys share one budget.
local remaining = {}

for i = 1, num_keys do
  local key = KEYS[i]
  local quantity_to_check = tonumber(ARGV[i])

  -- Fail if the quantity is not a valid number
  if not quantity_to_check then
    return {err = "Invalid quantity provided for key: " .. key}
  end

  local current_val = remaining[key]
  if current_val == nil then
    current_val = tonumber(redis.call('GET', key))
  end

  -- Fail if the key is missing / non-numeric, or does not hold strictly more
  -- than the requested quantity.
  -- NOTE(review): strict ">" means the last unit(s) can never be reserved
  -- (current == requested fails); confirm this is intended.
  if not current_val or current_val <= quantity_to_check then
    return 0
  end

  remaining[key] = current_val - quantity_to_check
end

return 1
`)
|
|
|
|
// reserveScript atomically decrements KEYS[i] by ARGV[i] for every i, but only
// if each key holds strictly more than the total quantity requested from it.
// On success it SETs the new values, publishes them as a JSON object
// {key: new_value} on "inventory_changed" and returns 1. If any check fails it
// returns 0 having written nothing. A key listed more than once is decremented
// by the sum of its quantities.
var reserveScript = redis.NewScript(`
-- KEYS[i] is an inventory counter and ARGV[i] the quantity to take from it.
local num_keys = #KEYS

-- Ensure the number of keys matches the number of quantities
if num_keys ~= #ARGV then
  return {err = "Script requires the same number of keys and quantities."}
end

-- remaining[key] = stock left for key after all lines processed so far.
-- Accumulating per key ensures a repeated key is checked against, and
-- decremented by, the sum of its quantities; previously each entry was checked
-- against the original value and the last SET silently overwrote the others.
local remaining = {}

-- ---
-- 1. CHECK PHASE: validate every line before writing anything, so a failure
--    leaves all keys untouched.
-- ---
for i = 1, num_keys do
  local key = KEYS[i]
  local quantity_to_check = tonumber(ARGV[i])

  -- Fail if the quantity is not a valid number
  if not quantity_to_check then
    return {err = "Invalid quantity provided for key: " .. key}
  end

  local current_val = remaining[key]
  if current_val == nil then
    current_val = tonumber(redis.call('GET', key))
  end

  -- Fail if the key is missing / non-numeric, or does not hold strictly more
  -- than the requested quantity.
  -- NOTE(review): strict ">" means stock can never reach 0 through a
  -- reservation; confirm this is intended.
  if not current_val or current_val <= quantity_to_check then
    return 0
  end

  remaining[key] = current_val - quantity_to_check
end

-- ---
-- 2. UPDATE PHASE: all checks passed; write each key's final value.
--    Iterating KEYS keeps the write order deterministic; a repeated key is
--    simply written again with the same value.
-- ---
for i = 1, num_keys do
  local key = KEYS[i]
  redis.call('SET', key, remaining[key])
end

-- Publish the JSON-encoded {key: new_value} map to subscribers
redis.call('PUBLISH', "inventory_changed", cjson.encode(remaining))

-- Return 1 to indicate the operation was successful
return 1
`)
|