213 lines
5.5 KiB
Go
213 lines
5.5 KiB
Go
package actor
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
"reflect"
|
|
"sync"
|
|
|
|
"github.com/gogo/protobuf/proto"
|
|
)
|
|
|
|
type ApplyResult struct {
|
|
Type string `json:"type"`
|
|
Mutation proto.Message `json:"mutation"`
|
|
Error error `json:"error,omitempty"`
|
|
}
|
|
|
|
type MutationRegistry interface {
|
|
Apply(grain any, msg ...proto.Message) ([]ApplyResult, error)
|
|
RegisterMutations(handlers ...MutationHandler)
|
|
Create(typeName string) (proto.Message, bool)
|
|
GetTypeName(msg proto.Message) (string, bool)
|
|
//GetStorageEvent(msg proto.Message) StorageEvent
|
|
//FromStorageEvent(event StorageEvent) (proto.Message, error)
|
|
}
|
|
|
|
type ProtoMutationRegistry struct {
|
|
mutationRegistryMu sync.RWMutex
|
|
mutationRegistry map[reflect.Type]MutationHandler
|
|
}
|
|
|
|
// ErrMutationNotRegistered is reported when a message's type has no handler
// in the registry.
var ErrMutationNotRegistered = fmt.Errorf("mutation not registered")
|
|
|
|
// MutationOption is a functional option that adjusts the per-registration
// settings held in mutationOptions.
type MutationOption func(*mutationOptions)

// mutationOptions collects the flags a MutationOption may toggle.
type mutationOptions struct {
	// updateTotals requests a totals update after a successful handler.
	// NOTE(review): no code visible in this file reads the flag — confirm
	// where it is consumed.
	updateTotals bool
}

// WithTotals returns an option that switches updateTotals on. Per the
// original note, the intent is that CartGrain.UpdateTotals() runs after a
// successful handler.
func WithTotals() MutationOption {
	enable := func(opts *mutationOptions) {
		opts.updateTotals = true
	}
	return enable
}
|
|
|
|
type MutationHandler interface {
|
|
Handle(state any, msg proto.Message) error
|
|
Name() string
|
|
Type() reflect.Type
|
|
Create() proto.Message
|
|
}
|
|
|
|
// RegisteredMutation stores metadata + the execution closure.
|
|
type RegisteredMutation[V any, T proto.Message] struct {
|
|
name string
|
|
handler func(*V, T) error
|
|
create func() T
|
|
msgType reflect.Type
|
|
}
|
|
|
|
func NewMutation[V any, T proto.Message](handler func(*V, T) error, create func() T) *RegisteredMutation[V, T] {
|
|
// Derive the name and message type from a concrete instance produced by create().
|
|
// This avoids relying on reflect.TypeFor (which can yield unexpected results in some toolchains)
|
|
// and ensures we always peel off the pointer layer for proto messages.
|
|
instance := create()
|
|
rt := reflect.TypeOf(instance)
|
|
if rt.Kind() == reflect.Ptr {
|
|
rt = rt.Elem()
|
|
}
|
|
return &RegisteredMutation[V, T]{
|
|
name: rt.Name(),
|
|
handler: handler,
|
|
create: create,
|
|
msgType: rt,
|
|
}
|
|
}
|
|
|
|
func (m *RegisteredMutation[V, T]) Handle(state any, msg proto.Message) error {
|
|
return m.handler(state.(*V), msg.(T))
|
|
}
|
|
|
|
func (m *RegisteredMutation[V, T]) Name() string {
|
|
return m.name
|
|
}
|
|
|
|
func (m *RegisteredMutation[V, T]) Create() proto.Message {
|
|
return m.create()
|
|
}
|
|
|
|
func (m *RegisteredMutation[V, T]) Type() reflect.Type {
|
|
return m.msgType
|
|
}
|
|
|
|
func NewMutationRegistry() MutationRegistry {
|
|
return &ProtoMutationRegistry{
|
|
mutationRegistry: make(map[reflect.Type]MutationHandler),
|
|
mutationRegistryMu: sync.RWMutex{},
|
|
}
|
|
}
|
|
|
|
func (r *ProtoMutationRegistry) RegisterMutations(handlers ...MutationHandler) {
|
|
r.mutationRegistryMu.Lock()
|
|
defer r.mutationRegistryMu.Unlock()
|
|
|
|
for _, handler := range handlers {
|
|
r.mutationRegistry[handler.Type()] = handler
|
|
}
|
|
}
|
|
|
|
func (r *ProtoMutationRegistry) GetTypeName(msg proto.Message) (string, bool) {
|
|
r.mutationRegistryMu.RLock()
|
|
defer r.mutationRegistryMu.RUnlock()
|
|
|
|
rt := indirectType(reflect.TypeOf(msg))
|
|
if handler, ok := r.mutationRegistry[rt]; ok {
|
|
return handler.Name(), true
|
|
}
|
|
return "", false
|
|
}
|
|
|
|
func (r *ProtoMutationRegistry) getHandler(typeName string) MutationHandler {
|
|
r.mutationRegistryMu.Lock()
|
|
defer r.mutationRegistryMu.Unlock()
|
|
|
|
for _, handler := range r.mutationRegistry {
|
|
if handler.Name() == typeName {
|
|
return handler
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *ProtoMutationRegistry) Create(typeName string) (proto.Message, bool) {
|
|
|
|
handler := r.getHandler(typeName)
|
|
if handler == nil {
|
|
log.Printf("missing handler for %s", typeName)
|
|
return nil, false
|
|
}
|
|
|
|
return handler.Create(), true
|
|
}
|
|
|
|
// ApplyRegistered attempts to apply a registered mutation.
|
|
// Returns updated grain if successful.
|
|
//
|
|
// If the mutation is not registered, returns (nil, ErrMutationNotRegistered).
|
|
func (r *ProtoMutationRegistry) Apply(grain any, msg ...proto.Message) ([]ApplyResult, error) {
|
|
results := make([]ApplyResult, 0, len(msg))
|
|
|
|
if grain == nil {
|
|
return results, fmt.Errorf("nil grain")
|
|
}
|
|
if msg == nil {
|
|
return results, fmt.Errorf("nil mutation message")
|
|
}
|
|
|
|
for _, m := range msg {
|
|
rt := indirectType(reflect.TypeOf(m))
|
|
r.mutationRegistryMu.RLock()
|
|
entry, ok := r.mutationRegistry[rt]
|
|
r.mutationRegistryMu.RUnlock()
|
|
if !ok {
|
|
results = append(results, ApplyResult{Error: ErrMutationNotRegistered, Type: rt.Name(), Mutation: m})
|
|
continue
|
|
}
|
|
err := entry.Handle(grain, m)
|
|
results = append(results, ApplyResult{Error: err, Type: rt.Name(), Mutation: m})
|
|
}
|
|
|
|
// if entry.updateTotals {
|
|
// grain.UpdateTotals()
|
|
// }
|
|
|
|
return results, nil
|
|
}
|
|
|
|
// RegisteredMutations returns metadata for all registered mutations (snapshot).
|
|
func (r *ProtoMutationRegistry) RegisteredMutations() []string {
|
|
r.mutationRegistryMu.RLock()
|
|
defer r.mutationRegistryMu.RUnlock()
|
|
out := make([]string, 0, len(r.mutationRegistry))
|
|
for _, entry := range r.mutationRegistry {
|
|
out = append(out, entry.Name())
|
|
}
|
|
return out
|
|
}
|
|
|
|
// RegisteredMutationTypes returns the reflect.Type list of all registered messages.
|
|
// Useful for coverage tests ensuring expected set matches actual set.
|
|
func (r *ProtoMutationRegistry) RegisteredMutationTypes() []reflect.Type {
|
|
r.mutationRegistryMu.RLock()
|
|
defer r.mutationRegistryMu.RUnlock()
|
|
out := make([]reflect.Type, 0, len(r.mutationRegistry))
|
|
for _, entry := range r.mutationRegistry {
|
|
out = append(out, entry.Type())
|
|
}
|
|
return out
|
|
}
|
|
|
|
// indirectType strips every pointer level from t and returns the underlying
// type. A nil t (what reflect.TypeOf returns for a nil interface value) is
// returned unchanged rather than causing a nil-interface method call panic.
func indirectType(t reflect.Type) reflect.Type {
	for t != nil && t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	return t
}
|