diff --git a/cmd/cart/pool-server.go b/cmd/cart/pool-server.go index 36639b7..04ca98a 100644 --- a/cmd/cart/pool-server.go +++ b/cmd/cart/pool-server.go @@ -6,7 +6,6 @@ import ( "io" "log" "net/http" - "os" "strconv" "sync" "time" @@ -205,9 +204,9 @@ type AddRequest struct { } func (s *PoolServer) GetReservationTime(item *messages.AddItem) time.Duration { - + // TODO: Implement reservation time calculation, nil don't require reservation return time.Minute * 15 - //return nil + } func (s *PoolServer) AddSkuRequestHandler(w http.ResponseWriter, r *http.Request, id cart.CartId) error { @@ -248,60 +247,6 @@ func (s *PoolServer) AddSkuRequestHandler(w http.ResponseWriter, r *http.Request // return json.NewEncoder(w).Encode(order) // } -func getCurrency(country string) string { - if country == "no" { - return "NOK" - } - return "SEK" -} - -func getLocale(country string) string { - if country == "no" { - return "nb-no" - } - return "sv-se" -} - -func getLocationId(item *cart.CartItem) inventory.LocationID { - if item.StoreId == nil || *item.StoreId == "" { - return "se" - } - return inventory.LocationID(*item.StoreId) -} - -func getInventoryRequests(items []*cart.CartItem) []inventory.ReserveRequest { - var requests []inventory.ReserveRequest - for _, item := range items { - if item == nil { - continue - } - requests = append(requests, inventory.ReserveRequest{ - InventoryReference: &inventory.InventoryReference{ - SKU: inventory.SKU(item.Sku), - LocationID: getLocationId(item), - }, - Quantity: uint32(item.Quantity), - }) - } - return requests -} - -func getOriginalHost(r *http.Request) string { - proxyHost := r.Header.Get("X-Forwarded-Host") - if proxyHost != "" { - return proxyHost - } - return r.Host -} - -func getClientIp(r *http.Request) string { - ip := r.Header.Get("X-Forwarded-For") - if ip == "" { - ip = r.RemoteAddr - } - return ip -} - // func (s *PoolServer) HandleCheckout(w http.ResponseWriter, r *http.Request, id CartId) error { // klarnaOrder, err := s.CreateOrUpdateCheckout(r.Host, id) // if err != nil { @@ -315,52 +260,6 @@ func getClientIp(r *http.Request) string { // } // -func CookieCartIdHandler(fn func(cartId cart.CartId, w http.ResponseWriter, r *http.Request) error) func(w http.ResponseWriter, r *http.Request) { - return func(w http.ResponseWriter, r *http.Request) { - - var id cart.CartId - cookie, err := r.Cookie("cartid") - if err != nil || cookie.Value == "" { - id = cart.MustNewCartId() - http.SetCookie(w, &http.Cookie{ - Name: "cartid", - Value: id.String(), - Secure: r.TLS != nil, - HttpOnly: true, - Path: "/", - Expires: time.Now().AddDate(0, 0, 14), - SameSite: http.SameSiteLaxMode, - }) - w.Header().Set("Set-Cart-Id", id.String()) - } else { - parsed, ok := cart.ParseCartId(cookie.Value) - if !ok { - id = cart.MustNewCartId() - http.SetCookie(w, &http.Cookie{ - Name: "cartid", - Value: id.String(), - Secure: r.TLS != nil, - HttpOnly: true, - Path: "/", - Expires: time.Now().AddDate(0, 0, 14), - SameSite: http.SameSiteLaxMode, - }) - w.Header().Set("Set-Cart-Id", id.String()) - } else { - id = parsed - } - } - - err = fn(id, w, r) - if err != nil { - log.Printf("Server error, not remote error: %v\n", err) - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte(err.Error())) - } - - } -} - // Removed leftover legacy block after CookieCartIdHandler (obsolete code referencing cid/legacy) func (s *PoolServer) RemoveCartCookie(w http.ResponseWriter, r *http.Request, cartId cart.CartId) error { @@ -378,34 +277,6 @@ func (s *PoolServer) RemoveCartCookie(w http.ResponseWriter, r *http.Request, ca return nil } -func CartIdHandler(fn func(cartId cart.CartId, w http.ResponseWriter, r *http.Request) error) func(w http.ResponseWriter, r *http.Request) { - return func(w http.ResponseWriter, r *http.Request) { - var id cart.CartId - raw := r.PathValue("id") - // If no id supplied, generate a new one - if raw == "" { - id := cart.MustNewCartId() - w.Header().Set("Set-Cart-Id", id.String()) - } else { - // Parse base62 cart id - if parsedId, ok := cart.ParseCartId(raw); !ok { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("cart id is invalid")) - return - } else { - id = parsedId - } - } - - err := fn(id, w, r) - if err != nil { - log.Printf("Server error, not remote error: %v\n", err) - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte(err.Error())) - } - } -} - func (s *PoolServer) ProxyHandler(fn func(w http.ResponseWriter, r *http.Request, cartId cart.CartId) error) func(cartId cart.CartId, w http.ResponseWriter, r *http.Request) error { return func(cartId cart.CartId, w http.ResponseWriter, r *http.Request) error { if ownerHost, ok := s.OwnerHost(uint64(cartId)); ok { @@ -433,12 +304,9 @@ func (s *PoolServer) ProxyHandler(fn func(w http.ResponseWriter, r *http.Request var ( tracer = otel.Tracer(name) - hmacKey = os.Getenv("ADYEN_HMAC") meter = otel.Meter(name) logger = otelslog.NewLogger(name) proxyCalls metric.Int64Counter - -// rollCnt metric.Int64Counter ) func init() { diff --git a/cmd/cart/utils.go b/cmd/cart/utils.go new file mode 100644 index 0000000..f8cf223 --- /dev/null +++ b/cmd/cart/utils.go @@ -0,0 +1,138 @@ +package main + +import ( + "log" + "net/http" + "time" + + "git.k6n.net/go-cart-actor/pkg/cart" + "github.com/matst80/go-redis-inventory/pkg/inventory" +) + +func getCurrency(country string) string { + if country == "no" { + return "NOK" + } + return "SEK" +} + +func getLocale(country string) string { + if country == "no" { + return "nb-no" + } + return "sv-se" +} + +func getLocationId(item *cart.CartItem) inventory.LocationID { + if item.StoreId == nil || *item.StoreId == "" { + return "se" + } + return inventory.LocationID(*item.StoreId) +} + +func getInventoryRequests(items []*cart.CartItem) []inventory.ReserveRequest { + var requests []inventory.ReserveRequest + for _, item := range items { + if item == nil { + continue + } + requests = append(requests, inventory.ReserveRequest{ + InventoryReference: &inventory.InventoryReference{ + SKU: inventory.SKU(item.Sku), + LocationID: getLocationId(item), + }, + Quantity: uint32(item.Quantity), + }) + } + return requests +} + +func getOriginalHost(r *http.Request) string { + proxyHost := r.Header.Get("X-Forwarded-Host") + if proxyHost != "" { + return proxyHost + } + return r.Host +} + +func getClientIp(r *http.Request) string { + ip := r.Header.Get("X-Forwarded-For") + if ip == "" { + ip = r.RemoteAddr + } + return ip +} + +func CookieCartIdHandler(fn func(cartId cart.CartId, w http.ResponseWriter, r *http.Request) error) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + + var id cart.CartId + cookie, err := r.Cookie("cartid") + if err != nil || cookie.Value == "" { + id = cart.MustNewCartId() + http.SetCookie(w, &http.Cookie{ + Name: "cartid", + Value: id.String(), + Secure: r.TLS != nil, + HttpOnly: true, + Path: "/", + Expires: time.Now().AddDate(0, 0, 14), + SameSite: http.SameSiteLaxMode, + }) + w.Header().Set("Set-Cart-Id", id.String()) + } else { + parsed, ok := cart.ParseCartId(cookie.Value) + if !ok { + id = cart.MustNewCartId() + http.SetCookie(w, &http.Cookie{ + Name: "cartid", + Value: id.String(), + Secure: r.TLS != nil, + HttpOnly: true, + Path: "/", + Expires: time.Now().AddDate(0, 0, 14), + SameSite: http.SameSiteLaxMode, + }) + w.Header().Set("Set-Cart-Id", id.String()) + } else { + id = parsed + } + } + + err = fn(id, w, r) + if err != nil { + log.Printf("Server error, not remote error: %v\n", err) + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(err.Error())) + } + + } +} + +func CartIdHandler(fn func(cartId cart.CartId, w http.ResponseWriter, r *http.Request) error) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + var id cart.CartId + raw := r.PathValue("id") + // If no id supplied, generate a new one + if raw == "" { + id := cart.MustNewCartId() + w.Header().Set("Set-Cart-Id", id.String()) + } else { + // Parse base62 cart id + if parsedId, ok := cart.ParseCartId(raw); !ok { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("cart id is invalid")) + return + } else { + id = parsedId + } + } + + err := fn(id, w, r) + if err != nil { + log.Printf("Server error, not remote error: %v\n", err) + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(err.Error())) + } + } +} diff --git a/cmd/checkout/adyen-handlers.go b/cmd/checkout/adyen-handlers.go new file mode 100644 index 0000000..3369138 --- /dev/null +++ b/cmd/checkout/adyen-handlers.go @@ -0,0 +1,220 @@ +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "log" + "net/http" + "net/url" + + "git.k6n.net/go-cart-actor/pkg/actor" + "git.k6n.net/go-cart-actor/pkg/cart" + "git.k6n.net/go-cart-actor/pkg/proxy" + messages "git.k6n.net/go-cart-actor/proto/checkout" + adyenCheckout "github.com/adyen/adyen-go-api-library/v21/src/checkout" + "github.com/adyen/adyen-go-api-library/v21/src/common" + "github.com/adyen/adyen-go-api-library/v21/src/hmacvalidator" + "github.com/adyen/adyen-go-api-library/v21/src/webhook" + "github.com/google/uuid" +) + +type SessionRequest struct { + SessionId string `json:"sessionId"` + SessionResult string `json:"sessionResult"` + SessionData string `json:"sessionData,omitempty"` +} + +func (s *CheckoutPoolServer) AdyenSessionHandler(w http.ResponseWriter, r *http.Request, cartId cart.CartId) error { + + grain, err := s.Get(r.Context(), uint64(cartId)) + if err != nil { + return err + } + if r.Method == http.MethodGet { + meta := GetCheckoutMetaFromRequest(r) + sessionData, err := BuildAdyenCheckoutSession(grain, meta) + if err != nil { + return err + } + service := s.adyenClient.Checkout() + req := service.PaymentsApi.SessionsInput().CreateCheckoutSessionRequest(*sessionData) + res, _, err := service.PaymentsApi.Sessions(r.Context(), req) + // apply checkout started + if err != nil { + return err + } + return s.WriteResult(w, res) + } else { + payload := &SessionRequest{} + if err := json.NewDecoder(r.Body).Decode(payload); err != nil { + return err + } + service := s.adyenClient.Checkout() + req := service.PaymentsApi.GetResultOfPaymentSessionInput(payload.SessionId).SessionResult(payload.SessionResult) + res, _, err := service.PaymentsApi.GetResultOfPaymentSession(r.Context(), req) + if err != nil { + return err + } + return s.WriteResult(w, res) + } + +} + +func (s *CheckoutPoolServer) AdyenHookHandler(w http.ResponseWriter, r *http.Request) { + var notificationRequest webhook.Webhook + service := s.adyenClient.Checkout() + if err := json.NewDecoder(r.Body).Decode(¬ificationRequest); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + cartHostMap := make(map[actor.Host][]webhook.NotificationItem) + for _, notificationItem := range *notificationRequest.NotificationItems { + item := notificationItem.NotificationRequestItem + log.Printf("Recieved notification event code: %s, %+v", item.EventCode, item) + + isValid := hmacvalidator.ValidateHmac(item, hmacKey) + if !isValid { + log.Printf("notification hmac not valid %s, %v", item.EventCode, item) + http.Error(w, "Invalid HMAC", http.StatusUnauthorized) + return + } else { + switch item.EventCode { + case "CAPTURE": + log.Printf("Capture status: %v", item.Success) + // dataBytes, err := json.Marshal(item) + // if err != nil { + // log.Printf("error marshaling item: %v", err) + // http.Error(w, "Error marshaling item", http.StatusInternalServerError) + // return + // } + //s.ApplyAnywhere(r.Context(),0, &messages.PaymentEvent{PaymentId: item.PspReference, Success: item.Success, Name: item.EventCode, Data: &pbany.Any{Value: dataBytes}}) + case "AUTHORISATION": + + cartId, ok := cart.ParseCartId(item.MerchantReference) + if !ok { + log.Printf("invalid cart id %s", item.MerchantReference) + http.Error(w, "Invalid cart id", http.StatusBadRequest) + return + } + //s.Apply() + + if host, ok := s.OwnerHost(uint64(cartId)); ok { + cartHostMap[host] = append(cartHostMap[host], notificationItem) + continue + } + + grain, err := s.Get(r.Context(), uint64(cartId)) + if err != nil { + log.Printf("Error getting cart: %v", err) + http.Error(w, "Cart not found", http.StatusBadRequest) + return + } + meta := GetCheckoutMetaFromRequest(r) + pspReference := item.PspReference + uid := uuid.New().String() + ref := uuid.New().String() + req := service.ModificationsApi.CaptureAuthorisedPaymentInput(pspReference).IdempotencyKey(uid).PaymentCaptureRequest(adyenCheckout.PaymentCaptureRequest{ + Amount: adyenCheckout.Amount{ + Currency: meta.Currency, + Value: grain.CartTotalPrice.IncVat, + }, + MerchantAccount: "ElgigantenECOM", + Reference: &ref, + }) + res, _, err := service.ModificationsApi.CaptureAuthorisedPayment(r.Context(), req) + if err != nil { + log.Printf("Error capturing payment: %v", err) + } else { + log.Printf("Payment captured successfully: %+v", res) + s.Apply(r.Context(), uint64(cartId), &messages.OrderCreated{ + OrderId: res.PaymentPspReference, + Status: item.EventCode, + }) + } + default: + log.Printf("Unknown event code: %s", item.EventCode) + } + } + } + var failed bool = false + var lastMock *proxy.MockResponseWriter + for host, items := range cartHostMap { + notificationRequest.NotificationItems = &items + bodyBytes, err := json.Marshal(notificationRequest) + if err != nil { + log.Printf("error marshaling notification: %+v", err) + continue + } + customBody := bytes.NewReader(bodyBytes) + mockW := proxy.NewMockResponseWriter() + handled, err := host.Proxy(0, mockW, r, customBody) + if err != nil { + log.Printf("proxy failed for %s: %+v", host.Name(), err) + failed = true + lastMock = mockW + } else if handled { + log.Printf("notification proxied to %s", host.Name()) + } + } + if failed { + w.WriteHeader(lastMock.StatusCode) + w.Write(lastMock.Body.Bytes()) + } else { + w.WriteHeader(http.StatusAccepted) + } +} + +func (s *CheckoutPoolServer) AdyenReturnHandler(w http.ResponseWriter, r *http.Request) { + log.Println("Redirect received") + + service := s.adyenClient.Checkout() + + req := service.PaymentsApi.GetResultOfPaymentSessionInput(r.URL.Query().Get("sessionId")) + + res, httpRes, err := service.PaymentsApi.GetResultOfPaymentSession(r.Context(), req) + log.Printf("got payment session %+v", res) + + dreq := service.PaymentsApi.PaymentsDetailsInput() + dreq = dreq.PaymentDetailsRequest(adyenCheckout.PaymentDetailsRequest{ + Details: adyenCheckout.PaymentCompletionDetails{ + RedirectResult: common.PtrString(r.URL.Query().Get("redirectResult")), + Payload: common.PtrString(r.URL.Query().Get("payload")), + }, + }) + + dres, httpRes, err := service.PaymentsApi.PaymentsDetails(r.Context(), dreq) + + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + log.Printf("Payment details response: %+v", dres) + + if !common.IsNil(dres.PspReference) && *dres.PspReference != "" { + var redirectURL string + // Conditionally handle different result codes for the shopper + switch *dres.ResultCode { + case "Authorised": + redirectURL = "/result/success" + case "Pending", "Received": + redirectURL = "/result/pending" + case "Refused": + redirectURL = "/result/failed" + default: + reason := "" + if dres.RefusalReason != nil { + reason = *dres.RefusalReason + } else { + reason = *dres.ResultCode + } + log.Printf("Payment failed: %s", reason) + redirectURL = fmt.Sprintf("/result/error?reason=%s", url.QueryEscape(reason)) + } + http.Redirect(w, r, redirectURL, http.StatusFound) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(httpRes.StatusCode) + json.NewEncoder(w).Encode(httpRes.Status) +} diff --git a/cmd/checkout/checkout_server.go b/cmd/checkout/checkout_server.go deleted file mode 100644 index 0a7e258..0000000 --- a/cmd/checkout/checkout_server.go +++ /dev/null @@ -1,226 +0,0 @@ -package main - -import ( - "context" - "encoding/json" - "fmt" - "log" - "net/http" - "time" - - "git.k6n.net/go-cart-actor/pkg/cart" - "git.k6n.net/go-cart-actor/pkg/checkout" - messages "git.k6n.net/go-cart-actor/proto/checkout" - - "github.com/matst80/go-redis-inventory/pkg/inventory" - amqp "github.com/rabbitmq/amqp091-go" -) - -var tpl = ` - - - - - s10r testing - checkout - - - - %s - - -` - -func getLocationId(item *cart.CartItem) inventory.LocationID { - if item.StoreId == nil || *item.StoreId == "" { - return "se" - } - return inventory.LocationID(*item.StoreId) -} - -func getInventoryRequests(items []*cart.CartItem) []inventory.ReserveRequest { - var requests []inventory.ReserveRequest - for _, item := range items { - if item == nil { - continue - } - requests = append(requests, inventory.ReserveRequest{ - InventoryReference: &inventory.InventoryReference{ - SKU: inventory.SKU(item.Sku), - LocationID: getLocationId(item), - }, - Quantity: uint32(item.Quantity), - }) - } - return requests -} - -func (a *App) getGrainFromOrder(ctx context.Context, order *CheckoutOrder) (*checkout.CheckoutGrain, error) { - cartId, ok := cart.ParseCartId(order.MerchantReference1) - if !ok { - return nil, fmt.Errorf("invalid cart id in order reference: %s", order.MerchantReference1) - } - grain, err := a.pool.Get(ctx, uint64(cartId)) - if err != nil { - return nil, fmt.Errorf("failed to get cart grain: %w", err) - } - return grain, nil -} - -func (a *App) HandleCheckoutRequests(amqpUrl string, mux *http.ServeMux, inventoryService inventory.InventoryService) { - conn, err := amqp.Dial(amqpUrl) - if err != nil { - log.Fatalf("failed to connect to RabbitMQ: %v", err) - } - - orderHandler := NewAmqpOrderHandler(conn) - orderHandler.DefineQueue() - - mux.HandleFunc("/push", func(w http.ResponseWriter, r *http.Request) { - log.Printf("Klarna order confirmation push, method: %s", r.Method) - if r.Method != http.MethodPost { - w.WriteHeader(http.StatusMethodNotAllowed) - return - } - orderId := r.URL.Query().Get("order_id") - log.Printf("Order confirmation push: %s", orderId) - - order, err := a.klarnaClient.GetOrder(r.Context(), orderId) - - if err != nil { - log.Printf("Error creating request: %v\n", err) - w.WriteHeader(http.StatusInternalServerError) - return - } - - grain, err := a.getGrainFromOrder(r.Context(), order) - if err != nil { - logger.ErrorContext(r.Context(), "Unable to get grain from klarna order", "error", err.Error()) - w.WriteHeader(http.StatusInternalServerError) - return - } - - if inventoryService != nil { - inventoryRequests := getInventoryRequests(grain.CartState.Items) - err = inventoryService.ReserveInventory(r.Context(), inventoryRequests...) - - if err != nil { - logger.WarnContext(r.Context(), "placeorder inventory reservation failed") - w.WriteHeader(http.StatusNotAcceptable) - return - } - a.pool.Apply(r.Context(), uint64(grain.Id), &messages.InventoryReserved{ - Id: grain.Id.String(), - Status: "success", - }) - } - - // err = confirmOrder(r.Context(), order, orderHandler) - // if err != nil { - // log.Printf("Error confirming order: %v\n", err) - // w.WriteHeader(http.StatusInternalServerError) - // return - // } - - // err = triggerOrderCompleted(r.Context(), a.server, order) - // if err != nil { - // log.Printf("Error processing cart message: %v\n", err) - // w.WriteHeader(http.StatusInternalServerError) - // return - // } - err = a.klarnaClient.AcknowledgeOrder(r.Context(), orderId) - if err != nil { - log.Printf("Error acknowledging order: %v\n", err) - } - - w.WriteHeader(http.StatusOK) - }) - - mux.HandleFunc("GET /checkout", a.server.CheckoutHandler(func(order *CheckoutOrder, w http.ResponseWriter) error { - w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.Header().Set("Permissions-Policy", "payment=(self \"https://js.stripe.com\" \"https://m.stripe.network\" \"https://js.playground.kustom.co\")") - w.WriteHeader(http.StatusOK) - _, err := fmt.Fprintf(w, tpl, order.HTMLSnippet) - return err - })) - - mux.HandleFunc("GET /confirmation/{order_id}", func(w http.ResponseWriter, r *http.Request) { - - orderId := r.PathValue("order_id") - order, err := a.klarnaClient.GetOrder(r.Context(), orderId) - - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte(err.Error())) - return - } - - // Apply ConfirmationViewed mutation - cartId, ok := cart.ParseCartId(order.MerchantReference1) - if ok { - a.pool.Apply(r.Context(), uint64(cartId), &messages.ConfirmationViewed{}) - } - - w.Header().Set("Content-Type", "text/html; charset=utf-8") - if order.Status == "checkout_complete" { - http.SetCookie(w, &http.Cookie{ - Name: "cartid", - Value: "", - Path: "/", - Secure: true, - HttpOnly: true, - Expires: time.Unix(0, 0), - SameSite: http.SameSiteLaxMode, - }) - } - - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, tpl, order.HTMLSnippet) - }) - mux.HandleFunc("/notification", func(w http.ResponseWriter, r *http.Request) { - log.Printf("Klarna order notification, method: %s", r.Method) - logger.InfoContext(r.Context(), "Klarna order notification received", "method", r.Method) - if r.Method != "POST" { - w.WriteHeader(http.StatusMethodNotAllowed) - return - } - order := &CheckoutOrder{} - err := json.NewDecoder(r.Body).Decode(order) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - } - log.Printf("Klarna order notification: %s", order.ID) - logger.InfoContext(r.Context(), "Klarna order notification received", "order_id", order.ID) - - w.WriteHeader(http.StatusOK) - }) - mux.HandleFunc("POST /validate", func(w http.ResponseWriter, r *http.Request) { - log.Printf("Klarna order validation, method: %s", r.Method) - if r.Method != "POST" { - w.WriteHeader(http.StatusMethodNotAllowed) - return - } - order := &CheckoutOrder{} - err := json.NewDecoder(r.Body).Decode(order) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - } - logger.InfoContext(r.Context(), "Klarna order validation received", "order_id", order.ID, "cart_id", order.MerchantReference1) - grain, err := a.getGrainFromOrder(r.Context(), order) - if err != nil { - logger.ErrorContext(r.Context(), "Unable to get grain from klarna order", "error", err.Error()) - w.WriteHeader(http.StatusInternalServerError) - return - } - if inventoryService != nil { - inventoryRequests := getInventoryRequests(grain.CartState.Items) - _, err = inventoryService.ReservationCheck(r.Context(), inventoryRequests...) - if err != nil { - logger.WarnContext(r.Context(), "placeorder inventory check failed") - w.WriteHeader(http.StatusNotAcceptable) - return - } - } - - w.WriteHeader(http.StatusOK) - }) -} diff --git a/cmd/checkout/klarna-handlers.go b/cmd/checkout/klarna-handlers.go new file mode 100644 index 0000000..7c38916 --- /dev/null +++ b/cmd/checkout/klarna-handlers.go @@ -0,0 +1,168 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "log" + "net/http" + + "git.k6n.net/go-cart-actor/pkg/cart" + "git.k6n.net/go-cart-actor/pkg/checkout" + messages "git.k6n.net/go-cart-actor/proto/checkout" + "github.com/matst80/go-redis-inventory/pkg/inventory" +) + +func (s *CheckoutPoolServer) KlarnaValidationHandler(w http.ResponseWriter, r *http.Request) { + log.Printf("Klarna order validation, method: %s", r.Method) + if r.Method != "POST" { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + order := &CheckoutOrder{} + err := json.NewDecoder(r.Body).Decode(order) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + } + logger.InfoContext(r.Context(), "Klarna order validation received", "order_id", order.ID, "cart_id", order.MerchantReference1) + grain, err := s.getGrainFromKlarnaOrder(r.Context(), order) + if err != nil { + logger.ErrorContext(r.Context(), "Unable to get grain from klarna order", "error", err.Error()) + w.WriteHeader(http.StatusInternalServerError) + return + } + s.reserveInventory(r.Context(), grain) + + w.WriteHeader(http.StatusOK) + +} + +func (s *CheckoutPoolServer) KlarnaNotificationHandler(w http.ResponseWriter, r *http.Request) { + + log.Printf("Klarna order notification, method: %s", r.Method) + logger.InfoContext(r.Context(), "Klarna order notification received", "method", r.Method) + if r.Method != "POST" { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + order := &CheckoutOrder{} + err := json.NewDecoder(r.Body).Decode(order) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + } + log.Printf("Klarna order notification: %s", order.ID) + logger.InfoContext(r.Context(), "Klarna order notification received", "order_id", order.ID) + + w.WriteHeader(http.StatusOK) + +} + +func (s *CheckoutPoolServer) KlarnaPushHandler(w http.ResponseWriter, r *http.Request) { + log.Printf("Klarna order confirmation push, method: %s", r.Method) + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + orderId := r.URL.Query().Get("order_id") + log.Printf("Order confirmation push: %s", orderId) + + order, err := s.klarnaClient.GetOrder(r.Context(), orderId) + + if err != nil { + log.Printf("Error creating request: %v\n", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + grain, err := s.getGrainFromKlarnaOrder(r.Context(), order) + if err != nil { + logger.ErrorContext(r.Context(), "Unable to get grain from klarna order", "error", err.Error()) + w.WriteHeader(http.StatusInternalServerError) + return + } + + if s.inventoryService != nil { + inventoryRequests := getInventoryRequests(grain.CartState.Items) + err = s.inventoryService.ReserveInventory(r.Context(), inventoryRequests...) + + if err != nil { + logger.WarnContext(r.Context(), "placeorder inventory reservation failed") + w.WriteHeader(http.StatusNotAcceptable) + return + } + s.Apply(r.Context(), uint64(grain.Id), &messages.InventoryReserved{ + Id: grain.Id.String(), + Status: "success", + }) + } + + // err = confirmOrder(r.Context(), order, orderHandler) + // if err != nil { + // log.Printf("Error confirming order: %v\n", err) + // w.WriteHeader(http.StatusInternalServerError) + // return + // } + + // err = triggerOrderCompleted(r.Context(), a.server, order) + // if err != nil { + // log.Printf("Error processing cart message: %v\n", err) + // w.WriteHeader(http.StatusInternalServerError) + // return + // } + err = s.klarnaClient.AcknowledgeOrder(r.Context(), orderId) + if err != nil { + log.Printf("Error acknowledging order: %v\n", err) + } + + w.WriteHeader(http.StatusOK) +} + +var tpl = ` + + + + + s10r testing - checkout + + + + %s + + +` + +func getLocationId(item *cart.CartItem) inventory.LocationID { + if item.StoreId == nil || *item.StoreId == "" { + return "se" + } + return inventory.LocationID(*item.StoreId) +} + +func getInventoryRequests(items []*cart.CartItem) []inventory.ReserveRequest { + var requests []inventory.ReserveRequest + for _, item := range items { + if item == nil { + continue + } + requests = append(requests, inventory.ReserveRequest{ + InventoryReference: &inventory.InventoryReference{ + SKU: inventory.SKU(item.Sku), + LocationID: getLocationId(item), + }, + Quantity: uint32(item.Quantity), + }) + } + return requests +} + +func (a *CheckoutPoolServer) getGrainFromKlarnaOrder(ctx context.Context, order *CheckoutOrder) (*checkout.CheckoutGrain, error) { + cartId, ok := cart.ParseCartId(order.MerchantReference1) + if !ok { + return nil, fmt.Errorf("invalid cart id in order reference: %s", order.MerchantReference1) + } + grain, err := a.Get(ctx, uint64(cartId)) + if err != nil { + return nil, fmt.Errorf("failed to get cart grain: %w", err) + } + return grain, nil +} diff --git a/cmd/checkout/main.go b/cmd/checkout/main.go index 08fdf50..5522ada 100644 --- a/cmd/checkout/main.go +++ b/cmd/checkout/main.go @@ -107,21 +107,13 @@ func main() { cartClient := NewCartClient(cartInternalUrl) syncedServer := NewCheckoutPoolServer(pool, fmt.Sprintf("%s, %s", name, podIp), klarnaClient, cartClient, adyenClient) - - app := &App{ - pool: pool, - server: syncedServer, - klarnaClient: klarnaClient, - cartClient: cartClient, - } + syncedServer.inventoryService = inventoryService mux := http.NewServeMux() debugMux := http.NewServeMux() if amqpUrl == "" { - log.Printf("no connection to amqp defined") - } else { - app.HandleCheckoutRequests(amqpUrl, mux, inventoryService) + log.Fatalf("no connection to amqp defined") } grpcSrv, err := actor.NewControlServer[*checkout.CheckoutGrain](controlPlaneConfig, pool) @@ -143,7 +135,7 @@ func main() { syncedServer.Serve(mux) mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) { - grainCount, capacity := app.pool.LocalUsage() + grainCount, capacity := pool.LocalUsage() if grainCount >= capacity { w.WriteHeader(http.StatusInternalServerError) w.Write([]byte("grain pool at capacity")) diff --git a/cmd/checkout/pool-server.go b/cmd/checkout/pool-server.go index f93d159..eb824d4 100644 --- a/cmd/checkout/pool-server.go +++ b/cmd/checkout/pool-server.go @@ -7,22 +7,17 @@ import ( "fmt" "log" "net/http" - "net/url" "os" "time" "git.k6n.net/go-cart-actor/pkg/actor" "git.k6n.net/go-cart-actor/pkg/cart" "git.k6n.net/go-cart-actor/pkg/checkout" - "git.k6n.net/go-cart-actor/pkg/proxy" messages "git.k6n.net/go-cart-actor/proto/checkout" adyen "github.com/adyen/adyen-go-api-library/v21/src/adyen" - adyenCheckout "github.com/adyen/adyen-go-api-library/v21/src/checkout" - "github.com/adyen/adyen-go-api-library/v21/src/common" - "github.com/adyen/adyen-go-api-library/v21/src/hmacvalidator" - "github.com/adyen/adyen-go-api-library/v21/src/webhook" - "github.com/google/uuid" + "github.com/matst80/go-redis-inventory/pkg/inventory" + amqp "github.com/rabbitmq/amqp091-go" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -50,10 +45,11 @@ var ( type CheckoutPoolServer struct { actor.GrainPool[*checkout.CheckoutGrain] - pod_name string - klarnaClient *KlarnaClient - adyenClient *adyen.APIClient - cartClient *CartClient + pod_name string + klarnaClient *KlarnaClient + adyenClient *adyen.APIClient + cartClient *CartClient + inventoryService *inventory.RedisInventoryService } func NewCheckoutPoolServer(pool actor.GrainPool[*checkout.CheckoutGrain], pod_name string, klarnaClient *KlarnaClient, cartClient *CartClient, adyenClient *adyen.APIClient) *CheckoutPoolServer { @@ -156,56 +152,6 @@ func (s *CheckoutPoolServer) CheckoutHandler(fn func(order *CheckoutOrder, w htt })) } -func CheckoutIdHandler(fn func(checkoutId checkout.CheckoutId, w http.ResponseWriter, r *http.Request) error) func(w http.ResponseWriter, r *http.Request) { - return func(w http.ResponseWriter, r *http.Request) { - var id checkout.CheckoutId - raw := r.PathValue("id") - if raw == "" { - id = checkout.CheckoutId(cart.MustNewCartId()) - w.Header().Set("Set-Checkout-Id", id.String()) - } else { - if parsedId, ok := cart.ParseCartId(raw); !ok { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("checkout id is invalid")) - return - } else { - id = checkout.CheckoutId(parsedId) - } - } - - err := fn(id, w, r) - if err != nil { - log.Printf("Server error, not remote error: %v\n", err) - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte(err.Error())) - } - } -} - -func (s *CheckoutPoolServer) ProxyHandler(fn func(w http.ResponseWriter, r *http.Request, checkoutId checkout.CheckoutId) error) func(checkoutId checkout.CheckoutId, w http.ResponseWriter, r *http.Request) error { - return func(checkoutId checkout.CheckoutId, w http.ResponseWriter, r *http.Request) error { - if ownerHost, ok := s.OwnerHost(uint64(checkoutId)); ok { - ctx, span := tracer.Start(r.Context(), "proxy") - defer span.End() - span.SetAttributes(attribute.String("checkoutid", checkoutId.String())) - hostAttr := attribute.String("other host", ownerHost.Name()) - span.SetAttributes(hostAttr) - logger.InfoContext(ctx, "checkout proxyed", "result", ownerHost.Name()) - proxyCalls.Add(ctx, 1, metric.WithAttributes(hostAttr)) - handled, err := ownerHost.Proxy(uint64(checkoutId), w, r, nil) - - grainLookups.Inc() - if err == nil && handled { - return nil - } - } - _, span := tracer.Start(r.Context(), "own") - span.SetAttributes(attribute.String("checkoutid", checkoutId.String())) - defer span.End() - return fn(w, r, checkoutId) - } -} - var ( tracer = otel.Tracer(name) hmacKey = os.Getenv("ADYEN_HMAC") @@ -224,206 +170,6 @@ func init() { } } -type SessionRequest struct { - SessionId string `json:"sessionId"` - SessionResult string `json:"sessionResult"` - SessionData string `json:"sessionData,omitempty"` -} - -func (s *CheckoutPoolServer) AdyenSessionHandler(w http.ResponseWriter, r *http.Request, cartId cart.CartId) error { - - grain, err := s.Get(r.Context(), uint64(cartId)) - if err != nil { - return err - } - if r.Method == http.MethodGet { - meta := GetCheckoutMetaFromRequest(r) - sessionData, err := BuildAdyenCheckoutSession(grain, meta) - if err != nil { - return err - } - service := s.adyenClient.Checkout() - req := service.PaymentsApi.SessionsInput().CreateCheckoutSessionRequest(*sessionData) - res, _, err := service.PaymentsApi.Sessions(r.Context(), req) - // apply checkout started - if err != nil { - return err - } - return s.WriteResult(w, res) - } else { - payload := &SessionRequest{} - if err := json.NewDecoder(r.Body).Decode(payload); err != nil { - return err - } - service := s.adyenClient.Checkout() - req := service.PaymentsApi.GetResultOfPaymentSessionInput(payload.SessionId).SessionResult(payload.SessionResult) - res, _, err := service.PaymentsApi.GetResultOfPaymentSession(r.Context(), req) - if err != nil { - return err - } - return s.WriteResult(w, res) - } - -} - -func (s *CheckoutPoolServer) AdyenHookHandler(w http.ResponseWriter, r *http.Request) { - var notificationRequest webhook.Webhook - service := s.adyenClient.Checkout() - if err := json.NewDecoder(r.Body).Decode(¬ificationRequest); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - cartHostMap := make(map[actor.Host][]webhook.NotificationItem) - for _, notificationItem := range *notificationRequest.NotificationItems { - item := notificationItem.NotificationRequestItem - log.Printf("Recieved notification event code: %s, %+v", item.EventCode, item) - - isValid := hmacvalidator.ValidateHmac(item, hmacKey) - if !isValid { - log.Printf("notification hmac not valid %s, %v", item.EventCode, item) - http.Error(w, "Invalid HMAC", http.StatusUnauthorized) - return - } else { - switch item.EventCode { - case "CAPTURE": - log.Printf("Capture status: %v", item.Success) - // dataBytes, err := json.Marshal(item) - // if err != nil { - // log.Printf("error marshaling item: %v", err) - // http.Error(w, "Error marshaling item", http.StatusInternalServerError) - // return - // } - //s.ApplyAnywhere(r.Context(),0, &messages.PaymentEvent{PaymentId: item.PspReference, Success: item.Success, Name: item.EventCode, Data: &pbany.Any{Value: dataBytes}}) - case "AUTHORISATION": - - cartId, ok := cart.ParseCartId(item.MerchantReference) - if !ok { - log.Printf("invalid cart id %s", item.MerchantReference) - http.Error(w, "Invalid cart id", http.StatusBadRequest) - return - } - //s.Apply() - - if host, ok := s.OwnerHost(uint64(cartId)); ok { - cartHostMap[host] = append(cartHostMap[host], notificationItem) - continue - } - - grain, err := s.Get(r.Context(), uint64(cartId)) - if err != nil { - log.Printf("Error getting cart: %v", err) - http.Error(w, "Cart not found", http.StatusBadRequest) - return - } - meta := GetCheckoutMetaFromRequest(r) - pspReference := item.PspReference - uid := uuid.New().String() - ref := uuid.New().String() - req := service.ModificationsApi.CaptureAuthorisedPaymentInput(pspReference).IdempotencyKey(uid).PaymentCaptureRequest(adyenCheckout.PaymentCaptureRequest{ - Amount: adyenCheckout.Amount{ - Currency: meta.Currency, - Value: grain.CartTotalPrice.IncVat, - }, - MerchantAccount: "ElgigantenECOM", - Reference: &ref, - }) - res, _, err := service.ModificationsApi.CaptureAuthorisedPayment(r.Context(), req) - if err != nil { - log.Printf("Error capturing payment: %v", err) - } else { - log.Printf("Payment captured successfully: %+v", res) - s.Apply(r.Context(), uint64(cartId), &messages.OrderCreated{ - OrderId: res.PaymentPspReference, - Status: item.EventCode, - }) - } - default: - log.Printf("Unknown event code: %s", item.EventCode) - } - } - } - var failed bool = false - var lastMock *proxy.MockResponseWriter - for host, items := range cartHostMap { - notificationRequest.NotificationItems = &items - bodyBytes, err := json.Marshal(notificationRequest) - if err != nil { - log.Printf("error marshaling notification: %+v", err) - continue - } - customBody := bytes.NewReader(bodyBytes) - mockW := proxy.NewMockResponseWriter() - handled, err := host.Proxy(0, mockW, r, customBody) - if err != nil { - log.Printf("proxy failed for %s: %+v", host.Name(), err) - failed = true - lastMock = mockW - } else if handled { - log.Printf("notification proxied to %s", host.Name()) - } - } - if failed { - w.WriteHeader(lastMock.StatusCode) - w.Write(lastMock.Body.Bytes()) - } else { - w.WriteHeader(http.StatusAccepted) - } -} - -func (s *CheckoutPoolServer) AdyenReturnHandler(w http.ResponseWriter, r *http.Request) { - log.Println("Redirect received") - - service := s.adyenClient.Checkout() - - req := service.PaymentsApi.GetResultOfPaymentSessionInput(r.URL.Query().Get("sessionId")) - - res, httpRes, err := service.PaymentsApi.GetResultOfPaymentSession(r.Context(), req) - log.Printf("got payment session %+v", res) - - dreq := service.PaymentsApi.PaymentsDetailsInput() - dreq = dreq.PaymentDetailsRequest(adyenCheckout.PaymentDetailsRequest{ - Details: adyenCheckout.PaymentCompletionDetails{ - RedirectResult: common.PtrString(r.URL.Query().Get("redirectResult")), - Payload: common.PtrString(r.URL.Query().Get("payload")), - }, - }) - - dres, httpRes, err := service.PaymentsApi.PaymentsDetails(r.Context(), dreq) - - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - log.Printf("Payment details response: %+v", dres) - - if !common.IsNil(dres.PspReference) && *dres.PspReference != "" { - var redirectURL string - // Conditionally handle different result codes for the shopper - switch *dres.ResultCode { - case "Authorised": - redirectURL = "/result/success" - case "Pending", "Received": - redirectURL = "/result/pending" - case "Refused": - redirectURL = "/result/failed" - default: - reason := "" - if dres.RefusalReason != nil { - reason = *dres.RefusalReason - } else { - reason = *dres.ResultCode - } - log.Printf("Payment failed: %s", reason) - redirectURL = fmt.Sprintf("/result/error?reason=%s", url.QueryEscape(reason)) - } - http.Redirect(w, r, redirectURL, http.StatusFound) - return - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(httpRes.StatusCode) - json.NewEncoder(w).Encode(httpRes.Status) -} - func (s *CheckoutPoolServer) Serve(mux *http.ServeMux) { handleFunc := func(pattern string, handlerFunc func(http.ResponseWriter, *http.Request)) { attr := attribute.String("http.route", pattern) @@ -438,9 +184,63 @@ func (s *CheckoutPoolServer) Serve(mux *http.ServeMux) { handlerFunc(w, r) })) } + handleFunc("/payment/adyen/session", s.AdyenSessionHandler) + handleFunc("/payment/adyen/push", s.AdyenHookHandler) + handleFunc("/payment/adyen/return", s.AdyenReturnHandler) + //handleFunc("/payment/adyen/cancel", s.AdyenCancelHandler) - handleFunc("/adyen_hook", s.AdyenHookHandler) - handleFunc("/adyen-return", s.AdyenReturnHandler) + handleFunc("/payment/klarna/validate", s.KlarnaValidationHandler) + handleFunc("/payment/klarna/notification", s.KlarnaNotificationHandler) + + conn, err := amqp.Dial(amqpUrl) + if err != nil { + log.Fatalf("failed to connect to RabbitMQ: %v", err) + } + + orderHandler := NewAmqpOrderHandler(conn) + orderHandler.DefineQueue() + + mux.HandleFunc("GET /checkout", s.CheckoutHandler(func(order *CheckoutOrder, w http.ResponseWriter) error { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.Header().Set("Permissions-Policy", "payment=(self \"https://js.stripe.com\" \"https://m.stripe.network\" \"https://js.playground.kustom.co\")") + w.WriteHeader(http.StatusOK) + _, err := fmt.Fprintf(w, tpl, order.HTMLSnippet) + return err + })) + + mux.HandleFunc("GET /confirmation/{order_id}", func(w http.ResponseWriter, r *http.Request) { + + orderId := r.PathValue("order_id") + order, err := s.klarnaClient.GetOrder(r.Context(), orderId) + + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(err.Error())) + return + } + + // Apply ConfirmationViewed mutation + cartId, ok := cart.ParseCartId(order.MerchantReference1) + if ok { + a.Apply(r.Context(), uint64(cartId), &messages.ConfirmationViewed{}) + } + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + if order.Status == "checkout_complete" { + http.SetCookie(w, &http.Cookie{ + Name: "cartid", + Value: "", + Path: "/", + Secure: true, + HttpOnly: true, + Expires: time.Unix(0, 0), + SameSite: http.SameSiteLaxMode, + }) + } + + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, tpl, order.HTMLSnippet) + }) handleFunc("GET /checkout", s.CheckoutHandler(func(order *CheckoutOrder, w http.ResponseWriter) error { w.Header().Set("Content-Type", "text/html; charset=utf-8") diff --git a/cmd/checkout/utils.go b/cmd/checkout/utils.go index 7e95b1b..c02428b 100644 --- a/cmd/checkout/utils.go +++ b/cmd/checkout/utils.go @@ -1,8 +1,15 @@ package main import ( + "context" + "log" "net/http" "strings" + + "git.k6n.net/go-cart-actor/pkg/cart" + "git.k6n.net/go-cart-actor/pkg/checkout" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" ) func getOriginalHost(r *http.Request) string { @@ -44,3 +51,65 @@ func getCountryFromHost(host string) string { } return "" } + +func (a *CheckoutPoolServer) reserveInventory(ctx context.Context, grain *checkout.CheckoutGrain) error { + if a.inventoryService != nil { + inventoryRequests := getInventoryRequests(grain.CartState.Items) + _, err := a.inventoryService.ReservationCheck(ctx, inventoryRequests...) + if err != nil { + logger.WarnContext(ctx, "placeorder inventory check failed") + return err + } + } + return nil +} + +func CheckoutIdHandler(fn func(checkoutId checkout.CheckoutId, w http.ResponseWriter, r *http.Request) error) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + var id checkout.CheckoutId + raw := r.PathValue("id") + if raw == "" { + id = checkout.CheckoutId(cart.MustNewCartId()) + w.Header().Set("Set-Checkout-Id", id.String()) + } else { + if parsedId, ok := cart.ParseCartId(raw); !ok { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("checkout id is invalid")) + return + } else { + id = checkout.CheckoutId(parsedId) + } + } + + err := fn(id, w, r) + if err != nil { + log.Printf("Server error, not remote error: %v\n", err) + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(err.Error())) + } + } +} + +func (s *CheckoutPoolServer) ProxyHandler(fn func(w http.ResponseWriter, r *http.Request, checkoutId checkout.CheckoutId) error) func(checkoutId checkout.CheckoutId, w http.ResponseWriter, r *http.Request) error { + return func(checkoutId checkout.CheckoutId, w http.ResponseWriter, r *http.Request) error { + if ownerHost, ok := s.OwnerHost(uint64(checkoutId)); ok { + ctx, span := tracer.Start(r.Context(), "proxy") + defer span.End() + span.SetAttributes(attribute.String("checkoutid", checkoutId.String())) + hostAttr := attribute.String("other host", ownerHost.Name()) + span.SetAttributes(hostAttr) + logger.InfoContext(ctx, "checkout proxyed", "result", ownerHost.Name()) + proxyCalls.Add(ctx, 1, metric.WithAttributes(hostAttr)) + handled, err := ownerHost.Proxy(uint64(checkoutId), w, r, nil) + + grainLookups.Inc() + if err == nil && handled { + return nil + } + } + _, span := tracer.Start(r.Context(), "own") + span.SetAttributes(attribute.String("checkoutid", checkoutId.String())) + defer span.End() + return fn(w, r, checkoutId) + } +} diff --git a/pkg/actor/mutation_registry.go b/pkg/actor/mutation_registry.go index 3b92cd2..6134f8a 100644 --- a/pkg/actor/mutation_registry.go +++ b/pkg/actor/mutation_registry.go @@ -95,14 +95,18 @@ type MutationHandler interface { type RegisteredMutation[V any, T proto.Message] struct { name string handler func(*V, T) error - create func() T + create func() proto.Message msgType reflect.Type } -func NewMutation[V any, T proto.Message](handler func(*V, T) error, create func() T) *RegisteredMutation[V, T] { +func NewMutation[V any, T proto.Message](handler func(*V, T) error) *RegisteredMutation[V, T] { // Derive the name and message type from a concrete instance produced by create(). // This avoids relying on reflect.TypeFor (which can yield unexpected results in some toolchains) // and ensures we always peel off the pointer layer for proto messages. + create := func() proto.Message { + m := new(T) + return *m + } instance := create() rt := reflect.TypeOf(instance) if rt.Kind() == reflect.Ptr { diff --git a/pkg/actor/mutation_registry_test.go b/pkg/actor/mutation_registry_test.go index 419b29c..9d9f7b7 100644 --- a/pkg/actor/mutation_registry_test.go +++ b/pkg/actor/mutation_registry_test.go @@ -6,33 +6,32 @@ import ( "slices" "testing" - "git.k6n.net/go-cart-actor/pkg/messages" + cart_messages "git.k6n.net/go-cart-actor/proto/cart" ) type cartState struct { calls int - lastAdded *messages.AddItem + lastAdded *cart_messages.AddItem } func TestRegisteredMutationBasics(t *testing.T) { reg := NewMutationRegistry().(*ProtoMutationRegistry) addItemMutation := NewMutation( - func(state *cartState, msg *messages.AddItem) error { + func(state *cartState, msg *cart_messages.AddItem) error { state.calls++ // copy to avoid external mutation side-effects (not strictly necessary for the test) cp := msg state.lastAdded = cp return nil }, - func() *messages.AddItem { return &messages.AddItem{} }, ) // Sanity check on mutation metadata if addItemMutation.Name() != "AddItem" { t.Fatalf("expected mutation Name() == AddItem, got %s", addItemMutation.Name()) } - if got, want := addItemMutation.Type(), reflect.TypeOf(messages.AddItem{}); got != want { + if got, want := addItemMutation.Type(), reflect.TypeOf(cart_messages.AddItem{}); got != want { t.Fatalf("expected Type() == %v, got %v", want, got) } @@ -46,18 +45,18 @@ func TestRegisteredMutationBasics(t *testing.T) { // RegisteredMutationTypes: membership (order not guaranteed) types := reg.RegisteredMutationTypes() - if !slices.Contains(types, reflect.TypeOf(messages.AddItem{})) { + if !slices.Contains(types, reflect.TypeOf(cart_messages.AddItem{})) { t.Fatalf("RegisteredMutationTypes missing AddItem type, got %v", types) } // GetTypeName should resolve for a pointer instance - name, ok := reg.GetTypeName(&messages.AddItem{}) + name, ok := reg.GetTypeName(&cart_messages.AddItem{}) if !ok || name != "AddItem" { t.Fatalf("GetTypeName returned (%q,%v), expected (AddItem,true)", name, ok) } // GetTypeName should fail for unregistered type - if name, ok := reg.GetTypeName(&messages.RemoveItem{}); ok || name != "" { + if name, ok := reg.GetTypeName(&cart_messages.RemoveItem{}); ok || name != "" { t.Fatalf("expected GetTypeName to fail for unregistered message, got (%q,%v)", name, ok) } @@ -66,7 +65,7 @@ func TestRegisteredMutationBasics(t *testing.T) { if !ok { t.Fatalf("Create failed for registered mutation") } - if _, isAddItem := msg.(*messages.AddItem); !isAddItem { + if _, isAddItem := msg.(*cart_messages.AddItem); !isAddItem { t.Fatalf("Create returned wrong concrete type: %T", msg) } @@ -77,7 +76,7 @@ func TestRegisteredMutationBasics(t *testing.T) { // Apply happy path state := &cartState{} - add := &messages.AddItem{ItemId: 42, Quantity: 3, Sku: "ABC"} + add := &cart_messages.AddItem{ItemId: 42, Quantity: 3, Sku: "ABC"} if _, err := reg.Apply(context.Background(), state, add); err != nil { t.Fatalf("Apply returned error: %v", err) } @@ -99,7 +98,7 @@ func TestRegisteredMutationBasics(t *testing.T) { } // Apply unregistered message - _, err := reg.Apply(context.Background(), state, &messages.RemoveItem{}) + _, err := reg.Apply(context.Background(), state, &cart_messages.RemoveItem{}) if err != ErrMutationNotRegistered { t.Fatalf("expected ErrMutationNotRegistered, got %v", err) } diff --git a/pkg/cart/cart-mutation-helper.go b/pkg/cart/cart-mutation-helper.go index ecdc397..341ae90 100644 --- a/pkg/cart/cart-mutation-helper.go +++ b/pkg/cart/cart-mutation-helper.go @@ -5,7 +5,6 @@ import ( "time" "git.k6n.net/go-cart-actor/pkg/actor" - messages "git.k6n.net/go-cart-actor/proto/cart" "github.com/matst80/go-redis-inventory/pkg/inventory" ) @@ -57,43 +56,28 @@ func (c *CartMutationContext) ReleaseItem(ctx context.Context, cartId CartId, sk return c.reservationService.ReleaseForCart(ctx, inventory.SKU(sku), l, inventory.CartID(cartId.String())) } +func Create[T any]() func() *T { + return func() *T { + return new(T) + } +} + func NewCartMultationRegistry(context *CartMutationContext) actor.MutationRegistry { reg := actor.NewMutationRegistry() reg.RegisterMutations( - actor.NewMutation(context.AddItem, func() *messages.AddItem { - return &messages.AddItem{} - }), - actor.NewMutation(context.ChangeQuantity, func() *messages.ChangeQuantity { - return &messages.ChangeQuantity{} - }), - actor.NewMutation(context.RemoveItem, func() *messages.RemoveItem { - return &messages.RemoveItem{} - }), - actor.NewMutation(ClearCart, func() *messages.ClearCartRequest { - return &messages.ClearCartRequest{} - }), - actor.NewMutation(AddVoucher, func() *messages.AddVoucher { - return &messages.AddVoucher{} - }), - actor.NewMutation(RemoveVoucher, func() *messages.RemoveVoucher { - return &messages.RemoveVoucher{} - }), - actor.NewMutation(UpsertSubscriptionDetails, func() *messages.UpsertSubscriptionDetails { - return &messages.UpsertSubscriptionDetails{} - }), - actor.NewMutation(SetUserId, func() *messages.SetUserId { - return &messages.SetUserId{} - }), - actor.NewMutation(LineItemMarking, func() *messages.LineItemMarking { - return &messages.LineItemMarking{} - }), - actor.NewMutation(RemoveLineItemMarking, func() *messages.RemoveLineItemMarking { - return &messages.RemoveLineItemMarking{} - }), - actor.NewMutation(SubscriptionAdded, func() *messages.SubscriptionAdded { - return &messages.SubscriptionAdded{} - }), + actor.NewMutation(context.AddItem), + actor.NewMutation(context.ChangeQuantity), + actor.NewMutation(context.RemoveItem), + actor.NewMutation(ClearCart), + actor.NewMutation(AddVoucher), + actor.NewMutation(RemoveVoucher), + actor.NewMutation(UpsertSubscriptionDetails), + actor.NewMutation(SetUserId), + actor.NewMutation(LineItemMarking), + actor.NewMutation(RemoveLineItemMarking), + actor.NewMutation(SubscriptionAdded), + // actor.NewMutation(SubscriptionRemoved), ) return reg diff --git a/pkg/checkout/mutation-context.go b/pkg/checkout/mutation-context.go index f4e71a4..e0e45b4 100644 --- a/pkg/checkout/mutation-context.go +++ b/pkg/checkout/mutation-context.go @@ -2,7 +2,6 @@ package checkout import ( "git.k6n.net/go-cart-actor/pkg/actor" - messages "git.k6n.net/go-cart-actor/proto/checkout" ) type CheckoutMutationContext struct { @@ -16,20 +15,17 @@ func NewCheckoutMutationContext() *CheckoutMutationContext { func NewCheckoutMutationRegistry(ctx *CheckoutMutationContext) actor.MutationRegistry { reg := actor.NewMutationRegistry() reg.RegisterMutations( - actor.NewMutation(HandleInitializeCheckout, func() *messages.InitializeCheckout { return &messages.InitializeCheckout{} }), - actor.NewMutation(HandlePaymentStarted, func() *messages.PaymentStarted { return &messages.PaymentStarted{} }), - actor.NewMutation(HandlePaymentCompleted, func() *messages.PaymentCompleted { return &messages.PaymentCompleted{} }), - actor.NewMutation(HandlePaymentDeclined, func() *messages.PaymentDeclined { return &messages.PaymentDeclined{} }), - actor.NewMutation(HandlePaymentEvent, func() *messages.PaymentEvent { return &messages.PaymentEvent{} }), - actor.NewMutation(HandleConfirmationViewed, func() *messages.ConfirmationViewed { return &messages.ConfirmationViewed{} }), - //actor.NewMutation(HandleCreateCheckoutOrder, func() *messages.CreateCheckoutOrder { return &messages.CreateCheckoutOrder{} }), - actor.NewMutation(HandleOrderCreated, func() *messages.OrderCreated { return &messages.OrderCreated{} }), - actor.NewMutation(HandleInventoryReserved, func() *messages.InventoryReserved { return &messages.InventoryReserved{} }), - actor.NewMutation(HandleSetDelivery, func() *messages.SetDelivery { return &messages.SetDelivery{} }), - actor.NewMutation(HandleSetPickupPoint, func() *messages.SetPickupPoint { return &messages.SetPickupPoint{} }), - actor.NewMutation(HandleRemoveDelivery, func() *messages.RemoveDelivery { return &messages.RemoveDelivery{} }), - // actor.NewMutation(HandleAddGiftcard, func() *messages.AddGiftcard { return &messages.AddGiftcard{} }), - // actor.NewMutation(HandleRemoveGiftcard, func() *messages.RemoveGiftcard { return &messages.RemoveGiftcard{} }), + actor.NewMutation(HandleInitializeCheckout), + actor.NewMutation(HandlePaymentStarted), + actor.NewMutation(HandlePaymentCompleted), + actor.NewMutation(HandlePaymentDeclined), + actor.NewMutation(HandlePaymentEvent), + actor.NewMutation(HandleConfirmationViewed), + actor.NewMutation(HandleOrderCreated), + actor.NewMutation(HandleInventoryReserved), + actor.NewMutation(HandleSetDelivery), + actor.NewMutation(HandleSetPickupPoint), + actor.NewMutation(HandleRemoveDelivery), ) return reg }