Skip to content

Commit

Permalink
Merge pull request #13 from filecoin-project/fix/in-order-notificatio…
Browse files Browse the repository at this point in the history
…ns-cleanup

Fix/in order notifications cleanup
  • Loading branch information
ingar authored Jun 11, 2020
2 parents aaeab1d + fb55e0c commit c290494
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 22 deletions.
45 changes: 38 additions & 7 deletions fsm/fsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ func NewFSMHandler(parameters Parameters) (statemachine.StateHandler, error) {
d.finalityStates[finalityState] = struct{}{}
}

d.initNotifier()
if d.notifier != nil {
d.notifications = make(chan notification)
}

return d, nil
}
Expand Down Expand Up @@ -117,15 +119,44 @@ func (d fsmHandler) reachedFinalityState(user interface{}) bool {
return final
}

// initNotifier will start up a goroutine which processes the notification queue
// Init will start up a goroutine which processes the notification queue
// in order
func (d *fsmHandler) initNotifier() {
func (d fsmHandler) Init(closing <-chan struct{}) {
if d.notifier != nil {
d.notifications = make(chan notification, NotificationQueueSize)

queue := make([]notification, 0, NotificationQueueSize)
toProcess := make(chan notification)
go func() {
for {
select {
case n := <-toProcess:
d.notifier(n.eventName, n.state)
case <-closing:
return
}
}
}()
go func() {
for n := range d.notifications {
d.notifier(n.eventName, n.state)
outgoing := func() chan<- notification {
if len(queue) == 0 {
return nil
}
return toProcess
}
nextNofication := func() notification {
if len(queue) == 0 {
return notification{}
}
return queue[0]
}
for {
select {
case n := <-d.notifications:
queue = append(queue, n)
case outgoing() <- nextNofication():
queue = queue[1:]
case <-closing:
return
}
}
}()
}
Expand Down
33 changes: 19 additions & 14 deletions fsm/fsm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/ipfs/go-datastore"
dss "github.com/ipfs/go-datastore/sync"
logging "github.com/ipfs/go-log"
"github.com/stretchr/testify/require"
"gotest.tools/assert"
Expand Down Expand Up @@ -57,7 +59,7 @@ var stateEntryFuncs = fsm.StateEntryFuncs{
}

func TestTypeCheckingOnSetup(t *testing.T) {
ds := datastore.NewMapDatastore()
ds := dss.MutexWrap(datastore.NewMapDatastore())
te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{})}
t.Run("Bad state field", func(t *testing.T) {
smm, err := fsm.New(ds, fsm.Parameters{
Expand Down Expand Up @@ -350,7 +352,7 @@ func newFsm(ds datastore.Datastore, te *testEnvironment) (fsm.Group, error) {
}

func TestArgumentChecks(t *testing.T) {
ds := datastore.NewMapDatastore()
ds := dss.MutexWrap(datastore.NewMapDatastore())
te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{})}
smm, err := newFsm(ds, te)
close(te.proceed)
Expand All @@ -372,7 +374,7 @@ func TestArgumentChecks(t *testing.T) {

func TestBasic(t *testing.T) {
for i := 0; i < 1000; i++ { // run a few times to expose any races
ds := datastore.NewMapDatastore()
ds := dss.MutexWrap(datastore.NewMapDatastore())

te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{})}
close(te.proceed)
Expand All @@ -389,7 +391,7 @@ func TestBasic(t *testing.T) {

func TestPersist(t *testing.T) {
for i := 0; i < 1000; i++ { // run a few times to expose any races
ds := datastore.NewMapDatastore()
ds := dss.MutexWrap(datastore.NewMapDatastore())

te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{})}
smm, err := newFsm(ds, te)
Expand All @@ -416,7 +418,7 @@ func TestPersist(t *testing.T) {

func TestSyncEventHandling(t *testing.T) {
ctx := context.Background()
ds := datastore.NewMapDatastore()
ds := dss.MutexWrap(datastore.NewMapDatastore())

te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{})}
smm, err := newFsm(ds, te)
Expand All @@ -442,13 +444,13 @@ func TestSyncEventHandling(t *testing.T) {
}

func TestNotification(t *testing.T) {
notifications := 0
notifications := new(uint64)

var notifier fsm.Notifier = func(eventName fsm.EventName, state fsm.StateType) {
notifications++
atomic.AddUint64(notifications, 1)
}

ds := datastore.NewMapDatastore()
ds := dss.MutexWrap(datastore.NewMapDatastore())

te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{}), universalCalls: 0}
close(te.proceed)
Expand All @@ -467,7 +469,8 @@ func TestNotification(t *testing.T) {
require.NoError(t, err)
<-te.done

require.Equal(t, notifications, 2)
total := atomic.LoadUint64(notifications)
require.Equal(t, total, uint64(2))
}

func TestSerialNotification(t *testing.T) {
Expand All @@ -488,14 +491,16 @@ func TestSerialNotification(t *testing.T) {
var notifications []string

wg := sync.WaitGroup{}
handleNotifications := make(chan struct{})
wg.Add(len(events))

var notifier fsm.Notifier = func(eventName fsm.EventName, state fsm.StateType) {
<-handleNotifications
notifications = append(notifications, eventName.(string))
wg.Done()
}

ds := datastore.NewMapDatastore()
ds := dss.MutexWrap(datastore.NewMapDatastore())
params := fsm.Parameters{
Environment: te,
StateType: statemachine.TestState{},
Expand All @@ -512,15 +517,15 @@ func TestSerialNotification(t *testing.T) {
err = smm.Send(uint64(2), eventName)
require.NoError(t, err)
}

close(handleNotifications)
wg.Wait()

// Expect that notifications happened in the order that the events happened
require.Equal(t, eventNames, notifications)
}

func TestNoChangeHandler(t *testing.T) {
ds := datastore.NewMapDatastore()
ds := dss.MutexWrap(datastore.NewMapDatastore())

te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{}), universalCalls: 0}
close(te.proceed)
Expand All @@ -539,7 +544,7 @@ func TestNoChangeHandler(t *testing.T) {
}

func TestAllStateEvent(t *testing.T) {
ds := datastore.NewMapDatastore()
ds := dss.MutexWrap(datastore.NewMapDatastore())

te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{}), universalCalls: 0}
close(te.proceed)
Expand All @@ -563,7 +568,7 @@ func TestFinalityStates(t *testing.T) {
ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
defer cancel()

ds := datastore.NewMapDatastore()
ds := dss.MutexWrap(datastore.NewMapDatastore())

te := &testEnvironment{t: t, done: make(chan struct{}), proceed: make(chan struct{})}
smm, err := newFsm(ds, te)
Expand Down
19 changes: 18 additions & 1 deletion group.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,20 @@ type StateHandler interface {
Plan(events []Event, user interface{}) (interface{}, uint64, error)
}

type StateHandlerWithInit interface {
StateHandler
Init(<-chan struct{})
}

// StateGroup manages a group of state machines sharing the same logic
type StateGroup struct {
sts *statestore.StateStore
hnd StateHandler
stateType reflect.Type

closing chan struct{}
initNotifier sync.Once

lk sync.Mutex
sms map[datastore.Key]*StateMachine
}
Expand All @@ -31,8 +39,15 @@ func New(ds datastore.Datastore, hnd StateHandler, stateType interface{}) *State
sts: statestore.New(ds),
hnd: hnd,
stateType: reflect.TypeOf(stateType),
closing: make(chan struct{}),
sms: map[datastore.Key]*StateMachine{},
}
}

sms: map[datastore.Key]*StateMachine{},
func (s *StateGroup) init() {
initter, ok := s.hnd.(StateHandlerWithInit)
if ok {
initter.Init(s.closing)
}
}

Expand Down Expand Up @@ -85,6 +100,7 @@ func (s *StateGroup) Send(id interface{}, evt interface{}) (err error) {
}

func (s *StateGroup) loadOrCreate(name interface{}, userState interface{}) (*StateMachine, error) {
s.initNotifier.Do(s.init)
exists, err := s.sts.Has(name)
if err != nil {
return nil, xerrors.Errorf("failed to check if state for %v exists: %w", name, err)
Expand Down Expand Up @@ -130,6 +146,7 @@ func (s *StateGroup) Stop(ctx context.Context) error {
}
}

close(s.closing)
return nil
}

Expand Down
90 changes: 90 additions & 0 deletions machine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,96 @@ func (t *testHandlerNoStateCB) step1(ctx Context, st TestState) error {
return nil
}

type testHandlerWithGoRoutine struct {
t *testing.T
event chan struct{}
proceed chan struct{}
done chan struct{}
notifDone chan struct{}
count uint64
}

func (t *testHandlerWithGoRoutine) Plan(events []Event, state interface{}) (interface{}, uint64, error) {
return t.plan(events, state.(*TestState))
}

func (t *testHandlerWithGoRoutine) Init(onClose <-chan struct{}) {
go func() {
for {
select {
case <-t.event:
t.count++
case <-onClose:
close(t.notifDone)
return
}
}
}()
}

func (t *testHandlerWithGoRoutine) plan(events []Event, state *TestState) (func(Context, TestState) error, uint64, error) {
for _, event := range events {
e := event.User.(*TestEvent)
switch e.A {
case "restart":
case "start":
state.A = 1
case "b":
state.A = 2
state.B = e.Val
}
}

t.event <- struct{}{}
switch state.A {
case 1:
return t.step0, uint64(len(events)), nil
case 2:
return t.step1, uint64(len(events)), nil
default:
t.t.Fatal(state.A)
}
panic("how?")
}

func (t *testHandlerWithGoRoutine) step0(ctx Context, st TestState) error {
ctx.Send(&TestEvent{A: "b", Val: 55}) // nolint: errcheck
<-t.proceed
return nil
}

func (t *testHandlerWithGoRoutine) step1(ctx Context, st TestState) error {
assert.Equal(t.t, uint64(2), st.A)

close(t.done)
return nil
}

func TestInit(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
for i := 0; i < 1000; i++ { // run a few times to expose any races
ds := datastore.NewMapDatastore()

th := &testHandlerWithGoRoutine{t: t, event: make(chan struct{}), notifDone: make(chan struct{}), done: make(chan struct{}), proceed: make(chan struct{})}
close(th.proceed)
smm := New(ds, th, TestState{})

if err := smm.Send(uint64(2), &TestEvent{A: "start"}); err != nil {
t.Fatalf("%+v", err)
}

<-th.done
err := smm.Stop(ctx)
assert.NilError(t, err)
<-th.notifDone
assert.Equal(t, uint64(2), th.count)
}

}

var _ StateHandler = &testHandler{}
var _ StateHandler = &testHandlerPartial{}
var _ StateHandler = &testHandlerNoStateCB{}
var _ StateHandler = &testHandlerWithGoRoutine{}

0 comments on commit c290494

Please sign in to comment.