diff --git a/fsm/fsm.go b/fsm/fsm.go index 5fb8bdb..b721003 100644 --- a/fsm/fsm.go +++ b/fsm/fsm.go @@ -73,7 +73,9 @@ func NewFSMHandler(parameters Parameters) (statemachine.StateHandler, error) { d.finalityStates[finalityState] = struct{}{} } - d.initNotifier() + if d.notifier != nil { + d.notifications = make(chan notification) + } return d, nil } @@ -117,15 +119,44 @@ func (d fsmHandler) reachedFinalityState(user interface{}) bool { return final } -// initNotifier will start up a goroutine which processes the notification queue +// Init will start up a goroutine which processes the notification queue // in order -func (d *fsmHandler) initNotifier() { +func (d fsmHandler) Init(closing <-chan struct{}) { if d.notifier != nil { - d.notifications = make(chan notification, NotificationQueueSize) - + queue := make([]notification, 0, NotificationQueueSize) + toProcess := make(chan notification) + go func() { + for { + select { + case n := <-toProcess: + d.notifier(n.eventName, n.state) + case <-closing: + return + } + } + }() go func() { - for n := range d.notifications { - d.notifier(n.eventName, n.state) + outgoing := func() chan<- notification { + if len(queue) == 0 { + return nil + } + return toProcess + } + nextNofication := func() notification { + if len(queue) == 0 { + return notification{} + } + return queue[0] + } + for { + select { + case n := <-d.notifications: + queue = append(queue, n) + case outgoing() <- nextNofication(): + queue = queue[1:] + case <-closing: + return + } } }() } diff --git a/fsm/fsm_test.go b/fsm/fsm_test.go index ad63e32..fae0704 100644 --- a/fsm/fsm_test.go +++ b/fsm/fsm_test.go @@ -4,10 +4,12 @@ import ( "context" "fmt" "sync" + "sync/atomic" "testing" "time" "github.com/ipfs/go-datastore" + dss "github.com/ipfs/go-datastore/sync" logging "github.com/ipfs/go-log" "github.com/stretchr/testify/require" "gotest.tools/assert" @@ -57,7 +59,7 @@ var stateEntryFuncs = fsm.StateEntryFuncs{ } func TestTypeCheckingOnSetup(t *testing.T) { - ds := datastore.NewMapDatastore() + ds := dss.MutexWrap(datastore.NewMapDatastore()) te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{})} t.Run("Bad state field", func(t *testing.T) { smm, err := fsm.New(ds, fsm.Parameters{ @@ -350,7 +352,7 @@ func newFsm(ds datastore.Datastore, te *testEnvironment) (fsm.Group, error) { } func TestArgumentChecks(t *testing.T) { - ds := datastore.NewMapDatastore() + ds := dss.MutexWrap(datastore.NewMapDatastore()) te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{})} smm, err := newFsm(ds, te) close(te.proceed) @@ -372,7 +374,7 @@ func TestArgumentChecks(t *testing.T) { func TestBasic(t *testing.T) { for i := 0; i < 1000; i++ { // run a few times to expose any races - ds := datastore.NewMapDatastore() + ds := dss.MutexWrap(datastore.NewMapDatastore()) te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{})} close(te.proceed) @@ -389,7 +391,7 @@ func TestBasic(t *testing.T) { func TestPersist(t *testing.T) { for i := 0; i < 1000; i++ { // run a few times to expose any races - ds := datastore.NewMapDatastore() + ds := dss.MutexWrap(datastore.NewMapDatastore()) te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{})} smm, err := newFsm(ds, te) @@ -416,7 +418,7 @@ func TestPersist(t *testing.T) { func TestSyncEventHandling(t *testing.T) { ctx := context.Background() - ds := datastore.NewMapDatastore() + ds := dss.MutexWrap(datastore.NewMapDatastore()) te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{})} smm, err := newFsm(ds, te) @@ -442,13 +444,13 @@ func TestSyncEventHandling(t *testing.T) { } func TestNotification(t *testing.T) { - notifications := 0 + notifications := new(uint64) var notifier fsm.Notifier = func(eventName fsm.EventName, state fsm.StateType) { - notifications++ + atomic.AddUint64(notifications, 1) } - ds := datastore.NewMapDatastore() + ds := dss.MutexWrap(datastore.NewMapDatastore()) te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{}), universalCalls: 0} close(te.proceed) @@ -467,7 +469,8 @@ func TestNotification(t *testing.T) { require.NoError(t, err) <-te.done - require.Equal(t, notifications, 2) + total := atomic.LoadUint64(notifications) + require.Equal(t, total, uint64(2)) } func TestSerialNotification(t *testing.T) { @@ -488,14 +491,16 @@ func TestSerialNotification(t *testing.T) { var notifications []string wg := sync.WaitGroup{} + handleNotifications := make(chan struct{}) wg.Add(len(events)) var notifier fsm.Notifier = func(eventName fsm.EventName, state fsm.StateType) { + <-handleNotifications notifications = append(notifications, eventName.(string)) wg.Done() } - ds := datastore.NewMapDatastore() + ds := dss.MutexWrap(datastore.NewMapDatastore()) params := fsm.Parameters{ Environment: te, StateType: statemachine.TestState{}, @@ -512,7 +517,7 @@ func TestSerialNotification(t *testing.T) { err = smm.Send(uint64(2), eventName) require.NoError(t, err) } - + close(handleNotifications) wg.Wait() // Expect that notifications happened in the order that the events happened @@ -520,7 +525,7 @@ func TestSerialNotification(t *testing.T) { } func TestNoChangeHandler(t *testing.T) { - ds := datastore.NewMapDatastore() + ds := dss.MutexWrap(datastore.NewMapDatastore()) te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{}), universalCalls: 0} close(te.proceed) @@ -539,7 +544,7 @@ func TestNoChangeHandler(t *testing.T) { } func TestAllStateEvent(t *testing.T) { - ds := datastore.NewMapDatastore() + ds := dss.MutexWrap(datastore.NewMapDatastore()) te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{}), universalCalls: 0} close(te.proceed) @@ -563,7 +568,7 @@ func TestFinalityStates(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() - ds := datastore.NewMapDatastore() + ds := dss.MutexWrap(datastore.NewMapDatastore()) te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{})} smm, err := newFsm(ds, te) diff --git a/group.go b/group.go index 04a5a1c..80a4a0e 100644 --- a/group.go +++ b/group.go @@ -15,12 +15,20 @@ type StateHandler interface { Plan(events []Event, user interface{}) (interface{}, uint64, error) } +type StateHandlerWithInit interface { + StateHandler + Init(<-chan struct{}) +} + // StateGroup manages a group of state machines sharing the same logic type StateGroup struct { sts *statestore.StateStore hnd StateHandler stateType reflect.Type + closing chan struct{} + initNotifier sync.Once + lk sync.Mutex sms map[datastore.Key]*StateMachine } @@ -31,8 +39,15 @@ func New(ds datastore.Datastore, hnd StateHandler, stateType interface{}) *State sts: statestore.New(ds), hnd: hnd, stateType: reflect.TypeOf(stateType), + closing: make(chan struct{}), + sms: map[datastore.Key]*StateMachine{}, + } +} - sms: map[datastore.Key]*StateMachine{}, +func (s *StateGroup) init() { + initter, ok := s.hnd.(StateHandlerWithInit) + if ok { + initter.Init(s.closing) } } @@ -85,6 +100,7 @@ func (s *StateGroup) Send(id interface{}, evt interface{}) (err error) { } func (s *StateGroup) loadOrCreate(name interface{}, userState interface{}) (*StateMachine, error) { + s.initNotifier.Do(s.init) exists, err := s.sts.Has(name) if err != nil { return nil, xerrors.Errorf("failed to check if state for %v exists: %w", name, err) @@ -130,6 +146,7 @@ func (s *StateGroup) Stop(ctx context.Context) error { } } + close(s.closing) return nil } diff --git a/machine_test.go b/machine_test.go index 689154c..ebb8dc1 100644 --- a/machine_test.go +++ b/machine_test.go @@ -352,6 +352,96 @@ func (t *testHandlerNoStateCB) step1(ctx Context, st TestState) error { return nil } +type testHandlerWithGoRoutine struct { + t *testing.T + event chan struct{} + proceed chan struct{} + done chan struct{} + notifDone chan struct{} + count uint64 +} + +func (t *testHandlerWithGoRoutine) Plan(events []Event, state interface{}) (interface{}, uint64, error) { + return t.plan(events, state.(*TestState)) +} + +func (t *testHandlerWithGoRoutine) Init(onClose <-chan struct{}) { + go func() { + for { + select { + case <-t.event: + t.count++ + case <-onClose: + close(t.notifDone) + return + } + } + }() +} + +func (t *testHandlerWithGoRoutine) plan(events []Event, state *TestState) (func(Context, TestState) error, uint64, error) { + for _, event := range events { + e := event.User.(*TestEvent) + switch e.A { + case "restart": + case "start": + state.A = 1 + case "b": + state.A = 2 + state.B = e.Val + } + } + + t.event <- struct{}{} + switch state.A { + case 1: + return t.step0, uint64(len(events)), nil + case 2: + return t.step1, uint64(len(events)), nil + default: + t.t.Fatal(state.A) + } + panic("how?") +} + +func (t *testHandlerWithGoRoutine) step0(ctx Context, st TestState) error { + ctx.Send(&TestEvent{A: "b", Val: 55}) // nolint: errcheck + <-t.proceed + return nil +} + +func (t *testHandlerWithGoRoutine) step1(ctx Context, st TestState) error { + assert.Equal(t.t, uint64(2), st.A) + + close(t.done) + return nil +} + +func TestInit(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + for i := 0; i < 1000; i++ { // run a few times to expose any races + ds := datastore.NewMapDatastore() + + th := &testHandlerWithGoRoutine{t: t, event: make(chan struct{}), notifDone: make(chan struct{}), done: make(chan struct{}), proceed: make(chan struct{})} + close(th.proceed) + smm := New(ds, th, TestState{}) + + if err := smm.Send(uint64(2), &TestEvent{A: "start"}); err != nil { + t.Fatalf("%+v", err) + } + + <-th.done + err := smm.Stop(ctx) + assert.NilError(t, err) + <-th.notifDone + assert.Equal(t, uint64(2), th.count) + } + +} + var _ StateHandler = &testHandler{} var _ StateHandler = &testHandlerPartial{} var _ StateHandler = &testHandlerNoStateCB{} +var _ StateHandler = &testHandlerWithGoRoutine{}