Files
go-cart-actor/pkg/actor/mutation_registry.go
matst80 5e36af2524 wip
2025-12-04 22:09:26 +01:00

362 lines
9.5 KiB
Go

package actor
import (
"context"
"fmt"
"log"
"reflect"
"sync"
"go.opentelemetry.io/otel/attribute"
"google.golang.org/protobuf/proto"
)
// ApplyResult records the outcome of applying a single mutation message.
// Apply appends one ApplyResult per processed message, and the same type is
// the element type of the registry's event channel.
type ApplyResult struct {
// Type is the Go type name of the mutation message (pointer layers removed).
Type string `json:"type"`
// Mutation is the message that was applied or attempted.
Mutation proto.Message `json:"mutation"`
// Error is the handler's error, ErrMutationNotRegistered when no handler
// exists for the message type, or nil on success.
Error error `json:"error,omitempty"`
}
// MutationProcessor is a hook that is run against a grain after mutations
// have been applied to it.
type MutationProcessor interface {
	// Process inspects or adjusts grain. A non-nil error aborts the
	// remaining processors of the calling Apply.
	Process(ctx context.Context, grain any) error
}

// BasicMutationProcessor adapts a strongly typed function to the untyped
// MutationProcessor interface.
type BasicMutationProcessor[V any] struct {
	processor func(ctx context.Context, grain V) error
}

// NewMutationProcessor wraps process in a MutationProcessor. The returned
// processor only accepts grains whose dynamic type is V.
func NewMutationProcessor[V any](process func(ctx context.Context, grain V) error) MutationProcessor {
	return &BasicMutationProcessor[V]{
		processor: process,
	}
}

// Process forwards grain to the wrapped function. If grain is not a V the
// wrapped function is not called and an error is returned instead of
// panicking on the type assertion.
func (p *BasicMutationProcessor[V]) Process(ctx context.Context, grain any) error {
	typed, ok := grain.(V)
	if !ok {
		return fmt.Errorf("mutation processor: unexpected grain type %T", grain)
	}
	return p.processor(ctx, typed)
}
// MutationRegistry maps protobuf mutation messages to handlers and applies
// them to a grain. ProtoMutationRegistry is the implementation in this file.
type MutationRegistry interface {
// Apply runs the given mutation messages against grain and returns one
// ApplyResult per processed message.
Apply(ctx context.Context, grain any, msg ...proto.Message) ([]ApplyResult, error)
// RegisterMutations adds handlers, keyed by the message type they report.
RegisterMutations(handlers ...MutationHandler)
// Create returns a new, empty message for the handler registered under
// typeName; the bool is false when no such handler exists.
Create(typeName string) (proto.Message, bool)
// GetTypeName returns the registered name for msg's type; the bool is
// false when the type has no handler.
GetTypeName(msg proto.Message) (string, bool)
// RegisterProcessor adds processors that run after mutations are applied.
RegisterProcessor(processor ...MutationProcessor)
// RegisterTrigger adds triggers, grouped by the message type they report.
RegisterTrigger(trigger ...TriggerHandler)
// SetEventChannel sets the channel on which trigger output is published.
SetEventChannel(ch chan<- ApplyResult)
}
// ProtoMutationRegistry is a MutationRegistry that indexes handlers and
// triggers by the reflect.Type of the mutation message.
type ProtoMutationRegistry struct {
// mutationRegistryMu is held by RegisterMutations and RegisterTrigger while
// writing, and by the lookup methods while reading.
mutationRegistryMu sync.RWMutex
// mutationRegistry maps a message type (as reported by handler.Type()) to
// its handler.
mutationRegistry map[reflect.Type]MutationHandler
// triggers maps a message type (as reported by trigger.Type()) to the
// triggers registered for it, in registration order.
triggers map[reflect.Type][]TriggerHandler
// processors run, in registration order, after Apply has handled its messages.
processors []MutationProcessor
// eventChannel receives trigger output; nil disables trigger dispatch.
eventChannel chan<- ApplyResult
}
var (
// ErrMutationNotRegistered is the sentinel stored in an ApplyResult and
// returned by Apply when a message type has no registered handler.
// Apply compares against it by pointer identity.
// NOTE(review): StatusCode 500 looks like an HTTP status and Code 255 an
// application-level code — confirm against the callers that serialize it.
ErrMutationNotRegistered = &MutationError{
Message: "mutation not registered",
Code: 255,
StatusCode: 500,
}
)
// MutationError is a structured error carrying a human-readable message and
// two numeric codes that are serialized alongside it.
type MutationError struct {
	// Message is the text returned by Error.
	Message string `json:"message"`
	// Code is a numeric error code.
	Code uint32 `json:"code"`
	// StatusCode is a numeric status, serialized as "status_code".
	StatusCode uint32 `json:"status_code"`
}

// Error implements the error interface by returning Message unchanged.
// The value receiver makes both MutationError and *MutationError errors.
func (e MutationError) Error() string {
	return e.Message
}
// MutationOption is a functional option that adjusts the settings recorded
// for a mutation registration.
type MutationOption func(*mutationOptions)

// mutationOptions collects the per-registration flags that MutationOption
// values can switch on.
type mutationOptions struct {
	// updateTotals is set by WithTotals.
	updateTotals bool
}

// WithTotals returns an option that sets the updateTotals flag.
// NOTE(review): the flag is presumably meant to request a totals
// recalculation after a successful handler, but nothing in the visible part
// of this file reads it — confirm where it is consumed.
func WithTotals() MutationOption {
	enable := func(opts *mutationOptions) {
		opts.updateTotals = true
	}
	return enable
}
// TriggerHandler reacts to an applied mutation by producing follow-up
// messages. Apply invokes the triggers registered for a message's type after
// that message's handler has run.
type TriggerHandler interface {
// Handle receives the grain state and the applied message and returns the
// messages to publish (possibly none).
Handle(state any, msg proto.Message) []proto.Message
// Name identifies the trigger.
Name() string
// Type is the message type the trigger is registered under; Apply looks
// triggers up by the message type with pointer layers removed.
Type() reflect.Type
}
// RegisteredTrigger holds a trigger's name, callback and message type as
// built by NewTrigger. The type parameters V and I are not used by any field.
// NOTE(review): no methods on this type are visible in this part of the file,
// so whether it satisfies TriggerHandler cannot be confirmed from here.
type RegisteredTrigger[V any, I proto.Message] struct {
// name is the label passed to NewTrigger.
name string
// handler is the callback passed to NewTrigger.
handler func(state any, msg proto.Message) []proto.Message
// msgType is derived from the type parameter I by NewTrigger.
msgType reflect.Type
}
// NewTrigger builds a RegisteredTrigger named name that runs handler for
// messages of type I.
//
// The stored message type has every pointer layer removed. The registry keys
// its trigger map by Type() and Apply looks triggers up by the
// pointer-stripped type of the incoming message, so storing *Msg here would
// never match a lookup for Msg.
func NewTrigger[V any, I proto.Message](name string, handler func(state any, msg proto.Message) []proto.Message) *RegisteredTrigger[V, I] {
	// (*I)(nil) yields a non-nil reflect.Type even when I is an interface;
	// Elem() recovers I itself, and indirectType then peels I's own pointers.
	msgType := indirectType(reflect.TypeOf((*I)(nil)).Elem())
	return &RegisteredTrigger[V, I]{
		name:    name,
		handler: handler,
		msgType: msgType,
	}
}
// MutationHandler applies one kind of mutation message to a grain's state.
type MutationHandler interface {
// Handle applies msg to state and reports any failure.
Handle(state any, msg proto.Message) error
// Name is the identifier used by Create and GetTypeName lookups.
Name() string
// Type is the key under which the handler is stored in the registry.
Type() reflect.Type
// Create returns a new, empty message of the handled type.
Create() proto.Message
}
// RegisteredMutation stores metadata + the execution closure for one
// mutation message type T operating on grain state *V. It implements
// MutationHandler.
type RegisteredMutation[V any, T proto.Message] struct {
// name is the Go type name of the message, pointer removed.
name string
// handler is the typed callback invoked by Handle.
handler func(*V, T) error
// create allocates a fresh message of the handled type.
create func() proto.Message
// msgType is the message type with the pointer layer removed.
msgType reflect.Type
}
// NewMutation builds the handler registration for mutation message type T
// acting on grain state *V.
//
// T must be a pointer type (as generated proto messages are). The name and
// registry key are taken from the pointed-to type, so both the registration
// and later lookups work on the type with the pointer layer peeled off.
// If T is not a pointer type the process is terminated via log.Fatalf, as
// before; that check now runs once here instead of inside every Create call.
func NewMutation[V any, T proto.Message](handler func(*V, T) error) *RegisteredMutation[V, T] {
	// reflect.TypeOf on the zero value is used rather than reflect.TypeFor;
	// it returns nil when T is an interface type, which is rejected below.
	var zero T
	ptrType := reflect.TypeOf(zero)
	if ptrType == nil || ptrType.Kind() != reflect.Pointer {
		log.Fatalf("expected to create proto message got %+v", ptrType)
	}
	elemType := ptrType.Elem()

	// The element type is fixed for a given T, so the closure only allocates.
	create := func() proto.Message {
		return reflect.New(elemType).Interface().(proto.Message)
	}

	return &RegisteredMutation[V, T]{
		name:    elemType.Name(),
		handler: handler,
		create:  create,
		msgType: elemType,
	}
}
// Handle applies msg to state through the typed handler.
//
// state must be a *V and msg must be a T. A mismatch is reported as an error
// rather than a panic from the type assertion, so one wrongly typed grain
// cannot crash the caller.
func (m *RegisteredMutation[V, T]) Handle(state any, msg proto.Message) error {
	grain, ok := state.(*V)
	if !ok {
		return fmt.Errorf("mutation %s: unexpected state type %T", m.name, state)
	}
	typed, ok := msg.(T)
	if !ok {
		return fmt.Errorf("mutation %s: unexpected message type %T", m.name, msg)
	}
	return m.handler(grain, typed)
}
// Name returns the message type name recorded when the mutation was built.
func (m *RegisteredMutation[V, T]) Name() string {
return m.name
}
// Create returns a newly allocated message by calling the stored create closure.
func (m *RegisteredMutation[V, T]) Create() proto.Message {
return m.create()
}
// Type returns the stored message type, which the registry uses as map key.
func (m *RegisteredMutation[V, T]) Type() reflect.Type {
return m.msgType
}
// NewMutationRegistry returns an empty ProtoMutationRegistry with its handler
// and trigger maps allocated and no event channel set.
func NewMutationRegistry() MutationRegistry {
	registry := new(ProtoMutationRegistry) // the zero-value RWMutex is ready to use
	registry.mutationRegistry = map[reflect.Type]MutationHandler{}
	registry.triggers = map[reflect.Type][]TriggerHandler{}
	registry.processors = []MutationProcessor{}
	return registry
}
// RegisterProcessor appends processors to the list that runs after mutations
// have been applied, preserving registration order.
//
// The registry lock is taken, as in RegisterMutations and RegisterTrigger, so
// concurrent registrations cannot lose entries through racing appends.
func (r *ProtoMutationRegistry) RegisterProcessor(processors ...MutationProcessor) {
	r.mutationRegistryMu.Lock()
	defer r.mutationRegistryMu.Unlock()
	r.processors = append(r.processors, processors...)
}
// RegisterMutations stores each handler under the type it reports via Type().
// A later handler for the same type replaces the earlier one. The write lock
// is held for the whole batch.
func (r *ProtoMutationRegistry) RegisterMutations(handlers ...MutationHandler) {
	r.mutationRegistryMu.Lock()
	defer r.mutationRegistryMu.Unlock()

	for i := range handlers {
		h := handlers[i]
		r.mutationRegistry[h.Type()] = h
	}
}
// RegisterTrigger adds each trigger to the list kept for the message type it
// reports via Type(). Several triggers may share a type; they are kept in
// registration order. The write lock is held for the whole batch.
func (r *ProtoMutationRegistry) RegisterTrigger(triggers ...TriggerHandler) {
	r.mutationRegistryMu.Lock()
	defer r.mutationRegistryMu.Unlock()

	for _, trigger := range triggers {
		key := trigger.Type()
		// append on the nil slice of a missing key allocates it, so no
		// separate "first trigger" branch is needed.
		r.triggers[key] = append(r.triggers[key], trigger)
	}
}
// SetEventChannel sets the channel on which trigger output is published.
// Passing nil disables trigger dispatch in Apply.
//
// The registry lock is taken, matching the other mutating methods, so the
// field is not written while another goroutine is registering handlers.
func (r *ProtoMutationRegistry) SetEventChannel(ch chan<- ApplyResult) {
	r.mutationRegistryMu.Lock()
	defer r.mutationRegistryMu.Unlock()
	r.eventChannel = ch
}
// GetTypeName returns the registered name for msg's type (pointer layers
// removed). The bool is false when msg is nil or its type has no handler.
func (r *ProtoMutationRegistry) GetTypeName(msg proto.Message) (string, bool) {
	// reflect.TypeOf(nil) is a nil Type; calling Kind on it would panic.
	if msg == nil {
		return "", false
	}
	rt := indirectType(reflect.TypeOf(msg))

	r.mutationRegistryMu.RLock()
	defer r.mutationRegistryMu.RUnlock()
	if handler, ok := r.mutationRegistry[rt]; ok {
		return handler.Name(), true
	}
	return "", false
}
// getHandler returns the handler whose Name() equals typeName, or nil when
// none matches. It scans all registered handlers; if two handlers share a
// name, which one is returned depends on map iteration order.
//
// Only a read lock is needed because the map is not modified here; taking the
// exclusive lock would serialize lookups against each other for no reason.
func (r *ProtoMutationRegistry) getHandler(typeName string) MutationHandler {
	r.mutationRegistryMu.RLock()
	defer r.mutationRegistryMu.RUnlock()
	for _, handler := range r.mutationRegistry {
		if handler.Name() == typeName {
			return handler
		}
	}
	return nil
}
// Create returns a new, empty message of the type registered under typeName.
// When no handler carries that name, the miss is logged and (nil, false) is
// returned.
func (r *ProtoMutationRegistry) Create(typeName string) (proto.Message, bool) {
	if h := r.getHandler(typeName); h != nil {
		return h.Create(), true
	}
	log.Printf("missing handler for %s", typeName)
	return nil, false
}
// Apply runs every mutation in msg against grain, in order, using the handler
// registered for each message's type, and then runs the registered
// processors once.
//
// Returns one ApplyResult per processed message. Typed-nil messages (a nil
// pointer inside a non-nil interface) are skipped without a result.
//
// Errors:
//   - a nil grain, a nil msg slice (including a call with no messages) or a
//     nil element aborts immediately with the results gathered so far;
//   - a processor error aborts the remaining processors and is returned;
//   - otherwise, if any message had no handler, ErrMutationNotRegistered is
//     returned after all messages and processors have run.
//
// Handler errors are recorded in the corresponding ApplyResult and on the
// message's span, but do not stop the batch and are not returned as the
// function's error.
func (r *ProtoMutationRegistry) Apply(ctx context.Context, grain any, msg ...proto.Message) ([]ApplyResult, error) {
	parentCtx, span := tracer.Start(ctx, "apply mutations")
	defer span.End()
	span.SetAttributes(
		attribute.String("component", "registry"),
		attribute.Int("mutations", len(msg)),
	)

	results := make([]ApplyResult, 0, len(msg))
	if grain == nil {
		return results, fmt.Errorf("nil grain")
	}
	// Nil slice of mutations still treated as an error (call contract violation).
	if msg == nil {
		return results, fmt.Errorf("nil mutation message")
	}

	for _, m := range msg {
		// Error if any mutation element is nil.
		if m == nil {
			return results, fmt.Errorf("nil mutation message")
		}
		// Typed nil: interface holds concrete proto message type whose pointer value is nil.
		if rv := reflect.ValueOf(m); rv.Kind() == reflect.Pointer && rv.IsNil() {
			continue
		}
		results = append(results, r.applyOne(parentCtx, grain, m))
	}

	if len(results) > 0 {
		// Started from parentCtx so the processors span nests under
		// "apply mutations" like the per-message spans do.
		processCtx, processSpan := tracer.Start(parentCtx, "after mutation processors")
		defer processSpan.End()

		r.mutationRegistryMu.RLock()
		processors := r.processors
		r.mutationRegistryMu.RUnlock()

		for _, processor := range processors {
			if err := processor.Process(processCtx, grain); err != nil {
				return results, err
			}
		}
	}

	// Return error for unregistered mutations
	for _, res := range results {
		if res.Error == ErrMutationNotRegistered {
			return results, res.Error
		}
	}
	return results, nil
}

// applyOne handles a single non-nil message inside its own span and returns
// its result. The span is ended on every path, including the unregistered one.
func (r *ProtoMutationRegistry) applyOne(ctx context.Context, grain any, m proto.Message) ApplyResult {
	rt := indirectType(reflect.TypeOf(m))
	name := rt.Name()
	_, msgSpan := tracer.Start(ctx, name)
	defer msgSpan.End()

	// Read everything that registration can modify under one read lock.
	r.mutationRegistryMu.RLock()
	entry, ok := r.mutationRegistry[rt]
	triggers := r.triggers[rt]
	events := r.eventChannel
	r.mutationRegistryMu.RUnlock()

	if !ok {
		return ApplyResult{Error: ErrMutationNotRegistered, Type: name, Mutation: m}
	}

	err := entry.Handle(grain, m)
	if err != nil {
		msgSpan.RecordError(err)
	}
	if events != nil && len(triggers) > 0 {
		// Arguments are evaluated here, so the goroutine never shares a loop
		// variable with the caller.
		// NOTE(review): triggers receive grain while later mutations in the
		// same batch may still be modifying it — confirm triggers only read
		// state that is safe to access concurrently.
		go emitTriggerEvents(events, triggers, grain, m)
	}
	return ApplyResult{Error: err, Type: name, Mutation: m}
}

// emitTriggerEvents runs triggers for an applied message and offers each
// produced message to events without blocking; events that cannot be
// delivered immediately are dropped.
func emitTriggerEvents(events chan<- ApplyResult, triggers []TriggerHandler, grain any, m proto.Message) {
	defer func() {
		// Covers a send on a closed channel; a panic inside a trigger's Handle
		// is also caught and logged here.
		if rec := recover(); rec != nil {
			log.Printf("event channel closed: %v", rec)
		}
	}()
	for _, tr := range triggers {
		for _, out := range tr.Handle(grain, m) {
			if out == nil {
				continue
			}
			// NOTE(review): trigger output is wrapped in an ApplyResult so it
			// matches the channel's element type — confirm this is the
			// intended event payload.
			event := ApplyResult{
				Type:     indirectType(reflect.TypeOf(out)).Name(),
				Mutation: out,
			}
			select {
			case events <- event:
			default:
				// Channel full or no receiver, skip to avoid blocking
			}
		}
	}
}
// RegisteredMutations returns a snapshot of the names of all registered
// mutation handlers. The order follows map iteration and is unspecified.
func (r *ProtoMutationRegistry) RegisteredMutations() []string {
	r.mutationRegistryMu.RLock()
	defer r.mutationRegistryMu.RUnlock()

	names := make([]string, len(r.mutationRegistry))
	next := 0
	for _, handler := range r.mutationRegistry {
		names[next] = handler.Name()
		next++
	}
	return names
}
// RegisteredMutationTypes returns a snapshot of the message types reported by
// all registered handlers, in unspecified order. Intended for coverage tests
// that compare the expected set of mutations with the registered set.
func (r *ProtoMutationRegistry) RegisteredMutationTypes() []reflect.Type {
	r.mutationRegistryMu.RLock()
	defer r.mutationRegistryMu.RUnlock()

	types := make([]reflect.Type, len(r.mutationRegistry))
	next := 0
	for _, handler := range r.mutationRegistry {
		types[next] = handler.Type()
		next++
	}
	return types
}
// indirectType strips every pointer layer from t and returns the underlying
// type. A nil t (what reflect.TypeOf returns for a nil interface) is returned
// unchanged instead of panicking on the Kind call.
func indirectType(t reflect.Type) reflect.Type {
	for t != nil && t.Kind() == reflect.Pointer {
		t = t.Elem()
	}
	return t
}