Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/in order notifications cleanup #13

Merged
merged 3 commits into from
Jun 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 38 additions & 7 deletions fsm/fsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ func NewFSMHandler(parameters Parameters) (statemachine.StateHandler, error) {
d.finalityStates[finalityState] = struct{}{}
}

d.initNotifier()
if d.notifier != nil {
d.notifications = make(chan notification)
}

return d, nil
}
Expand Down Expand Up @@ -117,15 +119,44 @@ func (d fsmHandler) reachedFinalityState(user interface{}) bool {
return final
}

// initNotifier will start up a goroutine which processes the notification queue
// Init will start up a goroutine which processes the notification queue
// in order
func (d *fsmHandler) initNotifier() {
func (d fsmHandler) Init(closing <-chan struct{}) {
if d.notifier != nil {
d.notifications = make(chan notification, NotificationQueueSize)

queue := make([]notification, 0, NotificationQueueSize)
toProcess := make(chan notification)
go func() {
for {
select {
case n := <-toProcess:
d.notifier(n.eventName, n.state)
case <-closing:
return
}
}
}()
go func() {
for n := range d.notifications {
d.notifier(n.eventName, n.state)
outgoing := func() chan<- notification {
if len(queue) == 0 {
return nil
}
return toProcess
}
nextNofication := func() notification {
if len(queue) == 0 {
return notification{}
}
return queue[0]
}
for {
select {
case n := <-d.notifications:
queue = append(queue, n)
case outgoing() <- nextNofication():
queue = queue[1:]
case <-closing:
return
}
}
}()
}
Expand Down
33 changes: 19 additions & 14 deletions fsm/fsm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/ipfs/go-datastore"
dss "github.com/ipfs/go-datastore/sync"
logging "github.com/ipfs/go-log"
"github.com/stretchr/testify/require"
"gotest.tools/assert"
Expand Down Expand Up @@ -57,7 +59,7 @@ var stateEntryFuncs = fsm.StateEntryFuncs{
}

func TestTypeCheckingOnSetup(t *testing.T) {
ds := datastore.NewMapDatastore()
ds := dss.MutexWrap(datastore.NewMapDatastore())
te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{})}
t.Run("Bad state field", func(t *testing.T) {
smm, err := fsm.New(ds, fsm.Parameters{
Expand Down Expand Up @@ -350,7 +352,7 @@ func newFsm(ds datastore.Datastore, te *testEnvironment) (fsm.Group, error) {
}

func TestArgumentChecks(t *testing.T) {
ds := datastore.NewMapDatastore()
ds := dss.MutexWrap(datastore.NewMapDatastore())
te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{})}
smm, err := newFsm(ds, te)
close(te.proceed)
Expand All @@ -372,7 +374,7 @@ func TestArgumentChecks(t *testing.T) {

func TestBasic(t *testing.T) {
for i := 0; i < 1000; i++ { // run a few times to expose any races
ds := datastore.NewMapDatastore()
ds := dss.MutexWrap(datastore.NewMapDatastore())

te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{})}
close(te.proceed)
Expand All @@ -389,7 +391,7 @@ func TestBasic(t *testing.T) {

func TestPersist(t *testing.T) {
for i := 0; i < 1000; i++ { // run a few times to expose any races
ds := datastore.NewMapDatastore()
ds := dss.MutexWrap(datastore.NewMapDatastore())

te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{})}
smm, err := newFsm(ds, te)
Expand All @@ -416,7 +418,7 @@ func TestPersist(t *testing.T) {

func TestSyncEventHandling(t *testing.T) {
ctx := context.Background()
ds := datastore.NewMapDatastore()
ds := dss.MutexWrap(datastore.NewMapDatastore())

te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{})}
smm, err := newFsm(ds, te)
Expand All @@ -442,13 +444,13 @@ func TestSyncEventHandling(t *testing.T) {
}

func TestNotification(t *testing.T) {
notifications := 0
notifications := new(uint64)

var notifier fsm.Notifier = func(eventName fsm.EventName, state fsm.StateType) {
notifications++
atomic.AddUint64(notifications, 1)
}

ds := datastore.NewMapDatastore()
ds := dss.MutexWrap(datastore.NewMapDatastore())

te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{}), universalCalls: 0}
close(te.proceed)
Expand All @@ -467,7 +469,8 @@ func TestNotification(t *testing.T) {
require.NoError(t, err)
<-te.done

require.Equal(t, notifications, 2)
total := atomic.LoadUint64(notifications)
require.Equal(t, total, uint64(2))
}

func TestSerialNotification(t *testing.T) {
Expand All @@ -488,14 +491,16 @@ func TestSerialNotification(t *testing.T) {
var notifications []string

wg := sync.WaitGroup{}
handleNotifications := make(chan struct{})
wg.Add(len(events))

var notifier fsm.Notifier = func(eventName fsm.EventName, state fsm.StateType) {
<-handleNotifications
notifications = append(notifications, eventName.(string))
wg.Done()
}

ds := datastore.NewMapDatastore()
ds := dss.MutexWrap(datastore.NewMapDatastore())
params := fsm.Parameters{
Environment: te,
StateType: statemachine.TestState{},
Expand All @@ -512,15 +517,15 @@ func TestSerialNotification(t *testing.T) {
err = smm.Send(uint64(2), eventName)
require.NoError(t, err)
}

close(handleNotifications)
wg.Wait()

// Expect that notifications happened in the order that the events happened
require.Equal(t, eventNames, notifications)
}

func TestNoChangeHandler(t *testing.T) {
ds := datastore.NewMapDatastore()
ds := dss.MutexWrap(datastore.NewMapDatastore())

te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{}), universalCalls: 0}
close(te.proceed)
Expand All @@ -539,7 +544,7 @@ func TestNoChangeHandler(t *testing.T) {
}

func TestAllStateEvent(t *testing.T) {
ds := datastore.NewMapDatastore()
ds := dss.MutexWrap(datastore.NewMapDatastore())

te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{}), universalCalls: 0}
close(te.proceed)
Expand All @@ -563,7 +568,7 @@ func TestFinalityStates(t *testing.T) {
ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
defer cancel()

ds := datastore.NewMapDatastore()
ds := dss.MutexWrap(datastore.NewMapDatastore())

te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{})}
smm, err := newFsm(ds, te)
Expand Down
19 changes: 18 additions & 1 deletion group.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,20 @@ type StateHandler interface {
Plan(events []Event, user interface{}) (interface{}, uint64, error)
}

type StateHandlerWithInit interface {
StateHandler
Init(<-chan struct{})
}

// StateGroup manages a group of state machines sharing the same logic
type StateGroup struct {
sts *statestore.StateStore
hnd StateHandler
stateType reflect.Type

closing chan struct{}
initNotifier sync.Once

lk sync.Mutex
sms map[datastore.Key]*StateMachine
}
Expand All @@ -31,8 +39,15 @@ func New(ds datastore.Datastore, hnd StateHandler, stateType interface{}) *State
sts: statestore.New(ds),
hnd: hnd,
stateType: reflect.TypeOf(stateType),
closing: make(chan struct{}),
sms: map[datastore.Key]*StateMachine{},
}
}

sms: map[datastore.Key]*StateMachine{},
func (s *StateGroup) init() {
initter, ok := s.hnd.(StateHandlerWithInit)
if ok {
initter.Init(s.closing)
}
}

Expand Down Expand Up @@ -85,6 +100,7 @@ func (s *StateGroup) Send(id interface{}, evt interface{}) (err error) {
}

func (s *StateGroup) loadOrCreate(name interface{}, userState interface{}) (*StateMachine, error) {
s.initNotifier.Do(s.init)
exists, err := s.sts.Has(name)
if err != nil {
return nil, xerrors.Errorf("failed to check if state for %v exists: %w", name, err)
Expand Down Expand Up @@ -130,6 +146,7 @@ func (s *StateGroup) Stop(ctx context.Context) error {
}
}

close(s.closing)
return nil
}

Expand Down
90 changes: 90 additions & 0 deletions machine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,96 @@ func (t *testHandlerNoStateCB) step1(ctx Context, st TestState) error {
return nil
}

type testHandlerWithGoRoutine struct {
t *testing.T
event chan struct{}
proceed chan struct{}
done chan struct{}
notifDone chan struct{}
count uint64
}

func (t *testHandlerWithGoRoutine) Plan(events []Event, state interface{}) (interface{}, uint64, error) {
return t.plan(events, state.(*TestState))
}

func (t *testHandlerWithGoRoutine) Init(onClose <-chan struct{}) {
go func() {
for {
select {
case <-t.event:
t.count++
case <-onClose:
close(t.notifDone)
return
}
}
}()
}

func (t *testHandlerWithGoRoutine) plan(events []Event, state *TestState) (func(Context, TestState) error, uint64, error) {
for _, event := range events {
e := event.User.(*TestEvent)
switch e.A {
case "restart":
case "start":
state.A = 1
case "b":
state.A = 2
state.B = e.Val
}
}

t.event <- struct{}{}
switch state.A {
case 1:
return t.step0, uint64(len(events)), nil
case 2:
return t.step1, uint64(len(events)), nil
default:
t.t.Fatal(state.A)
}
panic("how?")
}

func (t *testHandlerWithGoRoutine) step0(ctx Context, st TestState) error {
ctx.Send(&TestEvent{A: "b", Val: 55}) // nolint: errcheck
<-t.proceed
return nil
}

func (t *testHandlerWithGoRoutine) step1(ctx Context, st TestState) error {
assert.Equal(t.t, uint64(2), st.A)

close(t.done)
return nil
}

func TestInit(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
for i := 0; i < 1000; i++ { // run a few times to expose any races
ds := datastore.NewMapDatastore()

th := &testHandlerWithGoRoutine{t: t, event: make(chan struct{}), notifDone: make(chan struct{}), done: make(chan struct{}), proceed: make(chan struct{})}
close(th.proceed)
smm := New(ds, th, TestState{})

if err := smm.Send(uint64(2), &TestEvent{A: "start"}); err != nil {
t.Fatalf("%+v", err)
}

<-th.done
err := smm.Stop(ctx)
assert.NilError(t, err)
<-th.notifDone
assert.Equal(t, uint64(2), th.count)
}

}

var _ StateHandler = &testHandler{}
var _ StateHandler = &testHandlerPartial{}
var _ StateHandler = &testHandlerNoStateCB{}
var _ StateHandler = &testHandlerWithGoRoutine{}