refactor once again
This commit is contained in:
315
ownership_middleware.go
Normal file
315
ownership_middleware.go
Normal file
@@ -0,0 +1,315 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OwnershipProxyMiddleware provides HTTP-layer routing to the primary owner
|
||||
// of a cart before the request hits local handlers.
|
||||
//
|
||||
// Motivation:
|
||||
//
|
||||
// In the current system SyncedPool can proxy cart mutations to remote owners
|
||||
// via remote grains (gRPC). For a simpler deployment you can instead forward
|
||||
// the incoming HTTP request directly to the owning host and let only the
|
||||
// owner execute the standard handlers (which apply mutations locally).
|
||||
//
|
||||
// Behavior:
|
||||
// 1. Attempts to extract a cart id from (in priority order):
|
||||
// - Cookie "cartid"
|
||||
// - Path segment after "/byid/{id}" (e.g. /cart/byid/abc123/add/sku)
|
||||
// 2. Resolves the primary owner host using the consistent hashing ring
|
||||
// maintained by SyncedPool.
|
||||
// 3. If the owner is the local host (or no id found), the request proceeds.
|
||||
// 4. If the owner is a different host, the middleware performs an in-cluster
|
||||
// HTTP proxy (single-hop) to http://<owner>:<port><original-path>?<query>
|
||||
// and streams the response back to the client.
|
||||
// 5. Adds headers:
|
||||
// X-Cart-Owner: <resolved-owner>
|
||||
// X-Cart-Owner-Routed: "true" (only when proxied)
|
||||
// X-Cart-Id: <cart-id> (when available)
|
||||
// On local handling (not proxied) X-Cart-Owner-Routed is "false".
|
||||
//
|
||||
// Configuration:
|
||||
//
|
||||
// CART_SERVICE_PORT (env) - target port for proxying (default: 8080)
|
||||
// CART_PROXY_TIMEOUT_MS (env) - timeout for outbound proxy calls (default: 800)
|
||||
//
|
||||
// Integration:
|
||||
//
|
||||
// Wrap just the cart mux:
|
||||
//
|
||||
// cartMux := syncedServer.Serve() // existing cart handlers
|
||||
// wrapped := OwnershipProxyMiddleware(syncedPool)(cartMux)
|
||||
// mux.Handle("/cart/", http.StripPrefix("/cart", wrapped))
|
||||
//
|
||||
// Fallbacks:
|
||||
//
|
||||
// If extraction or proxying fails, a 502 is returned (except missing cart id
|
||||
// which simply skips routing). Timeouts produce 504.
|
||||
//
|
||||
// NOTE:
|
||||
// - This does NOT (yet) support sticky upgrade / websockets.
|
||||
// - Only primary ownership is considered (replicas ignored).
|
||||
// - This keeps control plane & ring logic unmodified.
|
||||
//
|
||||
// You can gradually phase out remote grain logic by placing this middleware
|
||||
// in front while leaving the rest of the code untouched.
|
||||
func OwnershipProxyMiddleware(pool *SyncedPool) func(http.Handler) http.Handler {
|
||||
localHost := pool.Hostname()
|
||||
targetPort := envOr("CART_SERVICE_PORT", "8080")
|
||||
timeout := envDurationOr("CART_PROXY_TIMEOUT_MS", 800*time.Millisecond)
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConnsPerHost: 32,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
// Dialer with small timeouts to fail fast inside cluster
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 300 * time.Millisecond,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
},
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// CORS preflight / safe methods that don't need routing without id.
|
||||
if r.Method == http.MethodOptions {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
cartId, ok := extractCartIdFromRequest(r)
|
||||
if !ok || cartId.String() == "" {
|
||||
// No cart id available -> cannot determine ownership; proceed locally.
|
||||
w.Header().Set("X-Cart-Owner-Routed", "false")
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
owner := pool.OwnerHost(cartId)
|
||||
w.Header().Set("X-Cart-Id", cartId.String())
|
||||
w.Header().Set("X-Cart-Owner", owner)
|
||||
|
||||
// Route locally if we're the owner or owner resolution empty.
|
||||
if owner == "" || owner == localHost {
|
||||
w.Header().Set("X-Cart-Owner-Routed", "false")
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Proxy to remote owner
|
||||
proxyURL := buildProxyURL(r, owner, targetPort)
|
||||
bodyBuf, err := readBodyDuplicate(r)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to read request body", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(r.Context(), r.Method, proxyURL, bodyBuf)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to create proxy request", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
copyHeaders(req.Header, r.Header)
|
||||
// Ensure we don't forward hop-by-hop headers
|
||||
cleanHopHeaders(req.Header)
|
||||
req.Header.Set("X-Forwarded-For", appendForwardedFor(r))
|
||||
req.Header.Set("X-Forwarded-Host", r.Host)
|
||||
req.Header.Set("X-Forwarded-Proto", schemeFromRequest(r))
|
||||
req.Header.Set("X-Cart-Forwarded", "true")
|
||||
|
||||
start := time.Now()
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
if os.IsTimeout(err) || strings.Contains(err.Error(), "timeout") {
|
||||
http.Error(w, "gateway timeout contacting owner", http.StatusGatewayTimeout)
|
||||
return
|
||||
}
|
||||
http.Error(w, "upstream owner error", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Copy status + headers
|
||||
copyHeaders(w.Header(), resp.Header)
|
||||
w.Header().Set("X-Cart-Owner-Routed", "true")
|
||||
w.Header().Set("X-Cart-Owner-Latency-Ms", durationMs(time.Since(start)))
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
io.Copy(w, resp.Body)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// (Removed duplicate OwnerHost method; single implementation now lives in synced-pool.go)
|
||||
|
||||
// extractCartIdFromRequest tries cookie first, then path form /byid/{id}/...
|
||||
func extractCartIdFromRequest(r *http.Request) (CartId, bool) {
|
||||
// Cookie
|
||||
if c, err := r.Cookie("cartid"); err == nil && c.Value != "" {
|
||||
if cid, _, _, err2 := CanonicalizeOrLegacy(c.Value); err2 == nil {
|
||||
return CartIDToLegacy(cid), true
|
||||
}
|
||||
}
|
||||
// Path-based: locate "byid" segment
|
||||
parts := splitPath(r.URL.Path)
|
||||
for i := 0; i < len(parts); i++ {
|
||||
if parts[i] == "byid" && i+1 < len(parts) {
|
||||
raw := parts[i+1]
|
||||
if raw != "" {
|
||||
if cid, _, _, err := CanonicalizeOrLegacy(raw); err == nil {
|
||||
return CartIDToLegacy(cid), true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
var zero CartId
|
||||
return zero, false
|
||||
}
|
||||
|
||||
// Helpers
|
||||
|
||||
func envOr(key, def string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func envDurationOr(key string, def time.Duration) time.Duration {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
if d, err := time.ParseDuration(v); err == nil {
|
||||
return d
|
||||
}
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func buildProxyURL(r *http.Request, host, port string) string {
|
||||
sb := &strings.Builder{}
|
||||
sb.WriteString("http://")
|
||||
sb.WriteString(host)
|
||||
if port != "" {
|
||||
sb.WriteString(":")
|
||||
sb.WriteString(port)
|
||||
}
|
||||
// Preserve original path & query (already includes /cart prefix stripped? depends on where middleware placed)
|
||||
sb.WriteString(r.URL.Path)
|
||||
if rq := r.URL.RawQuery; rq != "" {
|
||||
sb.WriteString("?")
|
||||
sb.WriteString(rq)
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func readBodyDuplicate(r *http.Request) (io.ReadCloser, error) {
|
||||
if r.Body == nil {
|
||||
return http.NoBody, nil
|
||||
}
|
||||
defer r.Body.Close()
|
||||
buf, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Restore original for downstream if local (we only call when proxying, but safe)
|
||||
r.Body = io.NopCloser(bytes.NewReader(buf))
|
||||
return io.NopCloser(bytes.NewReader(buf)), nil
|
||||
}
|
||||
|
||||
func copyHeaders(dst, src http.Header) {
|
||||
for k, vv := range src {
|
||||
// Skip hop-by-hop; they'll be cleaned anyway
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var hopHeaders = map[string]struct{}{
|
||||
"Connection": {},
|
||||
"Proxy-Connection": {},
|
||||
"Keep-Alive": {},
|
||||
"Proxy-Authenticate": {},
|
||||
"Proxy-Authorization": {},
|
||||
"Te": {},
|
||||
"Trailer": {},
|
||||
"Transfer-Encoding": {},
|
||||
"Upgrade": {},
|
||||
}
|
||||
|
||||
func cleanHopHeaders(h http.Header) {
|
||||
for k := range hopHeaders {
|
||||
h.Del(k)
|
||||
}
|
||||
}
|
||||
|
||||
func appendForwardedFor(r *http.Request) string {
|
||||
host, _, _ := net.SplitHostPort(r.RemoteAddr)
|
||||
if host == "" {
|
||||
host = r.RemoteAddr
|
||||
}
|
||||
prior := r.Header.Get("X-Forwarded-For")
|
||||
if prior == "" {
|
||||
return host
|
||||
}
|
||||
return prior + ", " + host
|
||||
}
|
||||
|
||||
func schemeFromRequest(r *http.Request) string {
|
||||
if r.TLS != nil {
|
||||
return "https"
|
||||
}
|
||||
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
|
||||
return proto
|
||||
}
|
||||
return "http"
|
||||
}
|
||||
|
||||
func splitPath(p string) []string {
|
||||
if p == "" || p == "/" {
|
||||
return nil
|
||||
}
|
||||
trimmed := strings.TrimPrefix(p, "/")
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.Split(trimmed, "/")
|
||||
return parts
|
||||
}
|
||||
|
||||
func durationMs(d time.Duration) string {
|
||||
return strconvFormatInt(int64(d / time.Millisecond))
|
||||
}
|
||||
|
||||
// strconvFormatInt is a tiny helper to avoid importing strconv for one use.
|
||||
func strconvFormatInt(i int64) string {
|
||||
// Fast int64 -> string (base 10) without strconv for small dependency surface.
|
||||
if i == 0 {
|
||||
return "0"
|
||||
}
|
||||
neg := i < 0
|
||||
if neg {
|
||||
i = -i
|
||||
}
|
||||
var buf [20]byte
|
||||
pos := len(buf)
|
||||
for i > 0 {
|
||||
pos--
|
||||
buf[pos] = byte('0' + (i % 10))
|
||||
i /= 10
|
||||
}
|
||||
if neg {
|
||||
pos--
|
||||
buf[pos] = '-'
|
||||
}
|
||||
return string(buf[pos:])
|
||||
}
|
||||
Reference in New Issue
Block a user