Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Easy FSM Add-on: Makes making Finite State Machines with go-statemachine easier #4

Merged
merged 20 commits into from
Feb 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ jobs:
default: golang
golangci-lint-version:
type: string
default: 1.17.1
default: 1.23.6
executor: << parameters.executor >>
steps:
- install-deps
Expand Down
130 changes: 130 additions & 0 deletions fsm/eventbuilder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package fsm

import "golang.org/x/xerrors"

type transitionToBuilder struct {
name EventName
action ActionFunc
transitionsSoFar map[StateKey]StateKey
nextFrom []StateKey
}

// To means the transition ends in the given state
func (t transitionToBuilder) To(to StateKey) EventBuilder {
transitions := t.transitionsSoFar
for _, from := range t.nextFrom {
transitions[from] = to
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, what do you think about emitting a warning or error even if we are clobbering previously-set callbacks or transitions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hate to say it thought I'd rather not mess up the fluidity of the interface I might store the error and fail when you try to construct the event machine which is delayed but preserves the DSL.

}
return eventBuilder{t.name, t.action, transitions}
}

// ToNoChange means a transition ends in the same state it started in (just retriggers state cb)
func (t transitionToBuilder) ToNoChange() EventBuilder {
transitions := t.transitionsSoFar
for _, from := range t.nextFrom {
transitions[from] = nil
}
return eventBuilder{t.name, t.action, transitions}
}

type eventBuilder struct {
name EventName
action ActionFunc
transitionsSoFar map[StateKey]StateKey
}

// From begins describing a transition from a specific state
func (t eventBuilder) From(s StateKey) TransitionToBuilder {
_, ok := t.transitionsSoFar[s]
if ok {
return errBuilder{t.name, xerrors.Errorf("duplicate transition source `%v` for event `%v`", s, t.name)}
}
return transitionToBuilder{
t.name,
t.action,
t.transitionsSoFar,
[]StateKey{s},
}
}

// FromAny begins describing a transition from any state
func (t eventBuilder) FromAny() TransitionToBuilder {
_, ok := t.transitionsSoFar[nil]
if ok {
return errBuilder{t.name, xerrors.Errorf("duplicate all-sources destination for event `%v`", t.name)}
}
return transitionToBuilder{
t.name,
t.action,
t.transitionsSoFar,
[]StateKey{nil},
}
}

// FromMany begins describing a transition from many states
func (t eventBuilder) FromMany(sources ...StateKey) TransitionToBuilder {
for _, source := range sources {
_, ok := t.transitionsSoFar[source]
if ok {
return errBuilder{t.name, xerrors.Errorf("duplicate transition source `%v` for event `%v`", source, t.name)}
}
}
return transitionToBuilder{
t.name,
t.action,
t.transitionsSoFar,
sources,
}
}

// Action describes actions taken on the state for this event
func (t eventBuilder) Action(action ActionFunc) EventBuilder {
if t.action != nil {
return errBuilder{t.name, xerrors.Errorf("duplicate action for event `%v`", t.name)}
}
return eventBuilder{
t.name,
action,
t.transitionsSoFar,
}
}

type errBuilder struct {
name EventName
err error
}

// From passes on the error
func (e errBuilder) From(s StateKey) TransitionToBuilder {
return e
}

// FromAny passes on the error
func (e errBuilder) FromAny() TransitionToBuilder {
return e
}

// FromMany passes on the error
func (e errBuilder) FromMany(sources ...StateKey) TransitionToBuilder {
return e
}

// Action passes on the error
func (e errBuilder) Action(action ActionFunc) EventBuilder {
return e
}

// To passes on the error
func (e errBuilder) To(_ StateKey) EventBuilder {
return e
}

// ToNoChange passes on the error
func (e errBuilder) ToNoChange() EventBuilder {
return e
}

// Event starts building a new event
func Event(name EventName) EventBuilder {
return eventBuilder{name, nil, map[StateKey]StateKey{}}
}
198 changes: 198 additions & 0 deletions fsm/eventprocessor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
package fsm

import (
"context"
"reflect"

"github.com/filecoin-project/go-statemachine"
"golang.org/x/xerrors"
)

// EventProcessor creates and applies events for go-statemachine based on the given event list
type EventProcessor interface {
// Event generates an event that can be dispatched to go-statemachine from the given event name and context args
Generate(ctx context.Context, event EventName, returnChannel chan error, args ...interface{}) (interface{}, error)
// Apply applies the given event from go-statemachine to the given state, based on transition rules
Apply(evt statemachine.Event, user interface{}) (EventName, error)
}

type eventProcessor struct {
stateType reflect.Type
stateKeyField StateKeyField
callbacks map[EventName]callback
transitions map[eKey]StateKey
}

// eKey is a struct key used for storing the transition map.
type eKey struct {
// event is the name of the event that the keys refers to.
event EventName

// src is the source from where the event can transition.
src interface{}
}

// callback stores a transition function and its argument types
type callback struct {
argumentTypes []reflect.Type
action ActionFunc
}

// fsmEvent is the internal event type
type fsmEvent struct {
name EventName
args []interface{}
ctx context.Context
returnChannel chan error
}

// NewEventProcessor returns a new event machine for the given state and event list
func NewEventProcessor(state StateType, stateKeyField StateKeyField, events []EventBuilder) (EventProcessor, error) {
stateType := reflect.TypeOf(state)
stateFieldType, ok := stateType.FieldByName(string(stateKeyField))
if !ok {
return nil, xerrors.Errorf("state type has no field `%s`", stateKeyField)
}
if !stateFieldType.Type.Comparable() {
return nil, xerrors.Errorf("state field `%s` is not comparable", stateKeyField)
}

em := eventProcessor{
stateType: stateType,
stateKeyField: stateKeyField,
callbacks: make(map[EventName]callback),
transitions: make(map[eKey]StateKey),
}

// Build transition map and store sets of all events and states.
for _, evtIface := range events {
evt, ok := evtIface.(eventBuilder)
if !ok {
errEvt := evtIface.(errBuilder)
return nil, errEvt.err
}

name := evt.name

_, exists := em.callbacks[name]
if exists {
return nil, xerrors.Errorf("Duplicate event name `%+v`", name)
}

argumentTypes, err := inspectActionFunc(name, evt.action, stateType)
if err != nil {
return nil, err
}
em.callbacks[name] = callback{
argumentTypes: argumentTypes,
action: evt.action,
}
for src, dst := range evt.transitionsSoFar {
if dst != nil && !reflect.TypeOf(dst).AssignableTo(stateFieldType.Type) {
return nil, xerrors.Errorf("event `%+v` destination type is not assignable to: %s", name, stateFieldType.Type.Name())
}
if src != nil && !reflect.TypeOf(src).AssignableTo(stateFieldType.Type) {
return nil, xerrors.Errorf("event `%+v` source type is not assignable to: %s", name, stateFieldType.Type.Name())
}
em.transitions[eKey{name, src}] = dst
}
}
return em, nil
}

// Event generates an event that can be dispatched to go-statemachine from the given event name and context args
func (em eventProcessor) Generate(ctx context.Context, event EventName, returnChannel chan error, args ...interface{}) (interface{}, error) {
cb, ok := em.callbacks[event]
if !ok {
return fsmEvent{}, xerrors.Errorf("Unknown event `%+v`", event)
}
if len(args) != len(cb.argumentTypes) {
return fsmEvent{}, xerrors.Errorf("Wrong number of arguments for event `%+v`", event)
}
for i, arg := range args {
if !reflect.TypeOf(arg).AssignableTo(cb.argumentTypes[i]) {
return fsmEvent{}, xerrors.Errorf("Incorrect argument type at index `%d` for event `%+v`", i, event)
}
}
return fsmEvent{event, args, ctx, returnChannel}, nil
}

func (em eventProcessor) Apply(evt statemachine.Event, user interface{}) (EventName, error) {
userValue := reflect.ValueOf(user)
currentState := userValue.Elem().FieldByName(string(em.stateKeyField)).Interface()
e, ok := evt.User.(fsmEvent)
if !ok {
return nil, xerrors.New("Not an fsm event")
}
destination, ok := em.transitions[eKey{e.name, currentState}]
// check for fallback transition for any source state
if !ok {
destination, ok = em.transitions[eKey{e.name, nil}]
}
if !ok {
return nil, completeEvent(e, xerrors.Errorf("Invalid transition in queue, state `%+v`, event `%+v`", currentState, e.name))
}
cb := em.callbacks[e.name]
err := applyAction(userValue, e, cb)
if err != nil {
return nil, completeEvent(e, err)
}
if destination != nil {
userValue.Elem().FieldByName(string(em.stateKeyField)).Set(reflect.ValueOf(destination))
}

return e.name, completeEvent(e, nil)
}

// Apply applies the given event from go-statemachine to the given state, based on transition rules
func applyAction(userValue reflect.Value, e fsmEvent, cb callback) error {
if cb.action == nil {
return nil
}
values := make([]reflect.Value, 0, len(e.args)+1)
values = append(values, userValue)
for _, arg := range e.args {
values = append(values, reflect.ValueOf(arg))
}
res := reflect.ValueOf(cb.action).Call(values)

if res[0].Interface() != nil {
return xerrors.Errorf("Error applying event transition `%+v`: %w", e.name, res[0].Interface().(error))
}
return nil
}

func completeEvent(event fsmEvent, err error) error {
if event.returnChannel != nil {
select {
case <-event.ctx.Done():
case event.returnChannel <- err:
}
}
return err
}

func inspectActionFunc(name EventName, action ActionFunc, stateType reflect.Type) ([]reflect.Type, error) {
if action == nil {
return nil, nil
}

atType := reflect.TypeOf(action)
if atType.Kind() != reflect.Func {
return nil, xerrors.Errorf("event `%+v` has a callback that is not a function", name)
}
if atType.NumIn() < 1 {
return nil, xerrors.Errorf("event `%+v` has a callback that does not take the state", name)
}
if !reflect.PtrTo(stateType).AssignableTo(atType.In(0)) {
return nil, xerrors.Errorf("event `%+v` has a callback that does not take the state", name)
}
if atType.NumOut() != 1 || atType.Out(0).AssignableTo(reflect.TypeOf(new(error))) {
return nil, xerrors.Errorf("event `%+v` callback should return exactly one param that is an error", name)
}
argumentTypes := make([]reflect.Type, atType.NumIn()-1)
for i := range argumentTypes {
argumentTypes[i] = atType.In(i + 1)
}
return argumentTypes, nil
}
Loading