diff --git a/pkg/actor/mutation_registry.go b/pkg/actor/mutation_registry.go index 7a059a4..c5f69bc 100644 --- a/pkg/actor/mutation_registry.go +++ b/pkg/actor/mutation_registry.go @@ -41,14 +41,16 @@ type MutationRegistry interface { Create(typeName string) (proto.Message, bool) GetTypeName(msg proto.Message) (string, bool) RegisterProcessor(processor ...MutationProcessor) - //GetStorageEvent(msg proto.Message) StorageEvent - //FromStorageEvent(event StorageEvent) (proto.Message, error) + RegisterTrigger(trigger ...TriggerHandler) + SetEventChannel(ch chan<- ApplyResult) } type ProtoMutationRegistry struct { mutationRegistryMu sync.RWMutex mutationRegistry map[reflect.Type]MutationHandler + triggers map[reflect.Type][]TriggerHandler processors []MutationProcessor + eventChannel chan<- ApplyResult } var ( @@ -84,6 +86,26 @@ func WithTotals() MutationOption { } } +type TriggerHandler interface { + Handle(state any, msg proto.Message) []proto.Message + Name() string + Type() reflect.Type +} + +type RegisteredTrigger[V any, I proto.Message] struct { + name string + handler func(state any, msg proto.Message) []proto.Message + msgType reflect.Type +} + +func NewTrigger[V any, I proto.Message](name string, handler func(state any, msg proto.Message) []proto.Message) *RegisteredTrigger[V, I] { + return &RegisteredTrigger[V, I]{ + name: name, + handler: handler, + msgType: reflect.TypeOf((*I)(nil)).Elem(), + } +} + type MutationHandler interface { Handle(state any, msg proto.Message) error Name() string @@ -145,6 +167,7 @@ func NewMutationRegistry() MutationRegistry { return &ProtoMutationRegistry{ mutationRegistry: make(map[reflect.Type]MutationHandler), mutationRegistryMu: sync.RWMutex{}, + triggers: make(map[reflect.Type][]TriggerHandler), processors: make([]MutationProcessor, 0), } } @@ -162,6 +185,24 @@ func (r *ProtoMutationRegistry) RegisterMutations(handlers ...MutationHandler) { } } +func (r *ProtoMutationRegistry) RegisterTrigger(triggers ...TriggerHandler) { + r.mutationRegistryMu.Lock() + defer r.mutationRegistryMu.Unlock() + + for _, trigger := range triggers { + existingTriggers, ok := r.triggers[trigger.Type()] + if !ok { + r.triggers[trigger.Type()] = []TriggerHandler{trigger} + } else { + r.triggers[trigger.Type()] = append(existingTriggers, trigger) + } + } +} + +func (r *ProtoMutationRegistry) SetEventChannel(ch chan<- ApplyResult) { + r.eventChannel = ch +} + func (r *ProtoMutationRegistry) GetTypeName(msg proto.Message) (string, bool) { r.mutationRegistryMu.RLock() defer r.mutationRegistryMu.RUnlock() @@ -244,6 +285,25 @@ func (r *ProtoMutationRegistry) Apply(ctx context.Context, grain any, msg ...pro if err != nil { msgSpan.RecordError(err) } + if r.eventChannel != nil { + go func() { + defer func() { + if r := recover(); r != nil { + // Handle panic from sending to closed channel + log.Printf("event channel closed: %v", r) + } + }() + for _, tr := range r.triggers[rt] { + for _, msg := range tr.Handle(grain, m) { + select { + case r.eventChannel <- msg: + default: + // Channel full or no receiver, skip to avoid blocking + } + } + } + }() + } results = append(results, ApplyResult{Error: err, Type: rt.Name(), Mutation: m}) } msgSpan.End() @@ -266,6 +326,7 @@ func (r *ProtoMutationRegistry) Apply(ctx context.Context, grain any, msg ...pro return results, res.Error } } + return results, nil } diff --git a/pkg/actor/mutation_registry_test.go b/pkg/actor/mutation_registry_test.go index 9d9f7b7..e48a77a 100644 --- a/pkg/actor/mutation_registry_test.go +++ b/pkg/actor/mutation_registry_test.go @@ -5,6 +5,7 @@ import ( "reflect" "slices" "testing" + "time" cart_messages "git.k6n.net/go-cart-actor/proto/cart" ) @@ -104,31 +105,103 @@ func TestRegisteredMutationBasics(t *testing.T) { } } -// func TestConcurrentSafeRegistrationLookup(t *testing.T) { -// // This test is light-weight; it ensures locks don't deadlock under simple concurrent access. -// reg := NewMutationRegistry().(*ProtoMutationRegistry) -// mut := NewMutation[cartState, *messages.Noop]( -// func(state *cartState, msg *messages.Noop) error { state.calls++; return nil }, -// func() *messages.Noop { return &messages.Noop{} }, -// ) -// reg.RegisterMutations(mut) +func TestEventChannel(t *testing.T) { + reg := NewMutationRegistry().(*ProtoMutationRegistry) -// done := make(chan struct{}) -// const workers = 25 -// for i := 0; i < workers; i++ { -// go func() { -// for j := 0; j < 100; j++ { -// _, _ = reg.Create("Noop") -// _, _ = reg.GetTypeName(&messages.Noop{}) -// _ = reg.Apply(&cartState{}, &messages.Noop{}) -// } -// done <- struct{}{} -// }() -// } + addItemMutation := NewMutation( + func(state *cartState, msg *cart_messages.AddItem) error { + state.calls++ + return nil + }, + ) -// for i := 0; i < workers; i++ { -// <-done -// } -// } + reg.RegisterMutations(addItemMutation) + + eventCh := make(chan ApplyResult, 10) + reg.SetEventChannel(eventCh) + + state := &cartState{} + add := &cart_messages.AddItem{ItemId: 42, Quantity: 3, Sku: "ABC"} + results, err := reg.Apply(context.Background(), state, add) + if err != nil { + t.Fatalf("Apply returned error: %v", err) + } + if len(results) != 1 { + t.Fatalf("expected 1 result, got %d", len(results)) + } + + // Receive from channel with timeout + select { + case res := <-eventCh: + if res.Type != "AddItem" { + t.Fatalf("expected type AddItem, got %s", res.Type) + } + if res.Error != nil { + t.Fatalf("expected no error, got %v", res.Error) + } + case <-time.After(time.Second): + t.Fatalf("expected to receive event on channel within timeout") + } +} + +func TestEventChannelClosed(t *testing.T) { + reg := NewMutationRegistry().(*ProtoMutationRegistry) + + addItemMutation := NewMutation( + func(state *cartState, msg *cart_messages.AddItem) error { + state.calls++ + return nil + }, + ) + + reg.RegisterMutations(addItemMutation) + + eventCh := make(chan ApplyResult, 10) + reg.SetEventChannel(eventCh) + + close(eventCh) // Close the channel to simulate external close + + state := &cartState{} + add := &cart_messages.AddItem{ItemId: 42, Quantity: 3, Sku: "ABC"} + // This should not panic due to recover in goroutine + results, err := reg.Apply(context.Background(), state, add) + if err != nil { + t.Fatalf("Apply returned error: %v", err) + } + if len(results) != 1 { + t.Fatalf("expected 1 result, got %d", len(results)) + } + // Test passes if no panic occurs +} + +func TestEventChannelUnbufferedNoListener(t *testing.T) { + reg := NewMutationRegistry().(*ProtoMutationRegistry) + + addItemMutation := NewMutation( + func(state *cartState, msg *cart_messages.AddItem) error { + state.calls++ + return nil + }, + ) + + reg.RegisterMutations(addItemMutation) + + eventCh := make(chan ApplyResult) // unbuffered + reg.SetEventChannel(eventCh) + + // No goroutine reading from eventCh + + state := &cartState{} + add := &cart_messages.AddItem{ItemId: 42, Quantity: 3, Sku: "ABC"} + results, err := reg.Apply(context.Background(), state, add) + if err != nil { + t.Fatalf("Apply returned error: %v", err) + } + if len(results) != 1 { + t.Fatalf("expected 1 result, got %d", len(results)) + } + // Since no listener, the send should go to default and not block + // Test passes if Apply completes without hanging +} // Helpers