From 683371f7791792c7b4e74d18321e8d08039780d4 Mon Sep 17 00:00:00 2001 From: Aaron Craelius Date: Thu, 25 Jul 2024 15:53:26 +0200 Subject: [PATCH] feat(schema/appdata): async listener mux'ing (#20879) Co-authored-by: cool-developer <51834436+cool-develope@users.noreply.github.com> --- schema/appdata/async.go | 162 +++++++++++++++++++++++++++++++++++ schema/appdata/async_test.go | 148 ++++++++++++++++++++++++++++++++ schema/appdata/mux.go | 128 +++++++++++++++++++++++++++ schema/appdata/mux_test.go | 131 ++++++++++++++++++++++++++++ 4 files changed, 569 insertions(+) create mode 100644 schema/appdata/async.go create mode 100644 schema/appdata/async_test.go create mode 100644 schema/appdata/mux.go create mode 100644 schema/appdata/mux_test.go diff --git a/schema/appdata/async.go b/schema/appdata/async.go new file mode 100644 index 000000000000..0d4126ed2a03 --- /dev/null +++ b/schema/appdata/async.go @@ -0,0 +1,162 @@ +package appdata + +import ( + "context" + "sync" +) + +// AsyncListenerOptions are options for async listeners and listener mux's. +type AsyncListenerOptions struct { + // Context is the context whose Done() channel listeners use will use to listen for completion to close their + // goroutine. If it is nil, then context.Background() will be used and goroutines may be leaked. + Context context.Context + + // BufferSize is the buffer size of the channels to use. It defaults to 0. + BufferSize int + + // DoneWaitGroup is an optional wait-group that listener goroutines will notify via Add(1) when they are started + // and Done() after they are cancelled and completed. + DoneWaitGroup *sync.WaitGroup +} + +// AsyncListenerMux returns a listener that forwards received events to all the provided listeners asynchronously +// with each listener processing in a separate go routine. All callbacks in the returned listener will return nil +// except for Commit which will return an error or nil once all listeners have processed the commit. The context +// is used to signal that the listeners should stop listening and return. bufferSize is the size of the buffer for the +// channels used to send events to the listeners. +func AsyncListenerMux(opts AsyncListenerOptions, listeners ...Listener) Listener { + asyncListeners := make([]Listener, len(listeners)) + commitChans := make([]chan error, len(listeners)) + for i, l := range listeners { + commitChan := make(chan error) + commitChans[i] = commitChan + asyncListeners[i] = AsyncListener(opts, commitChan, l) + } + mux := ListenerMux(asyncListeners...) + muxCommit := mux.Commit + mux.Commit = func(data CommitData) error { + if muxCommit != nil { + err := muxCommit(data) + if err != nil { + return err + } + } + + for _, commitChan := range commitChans { + err := <-commitChan + if err != nil { + return err + } + } + return nil + } + + return mux +} + +// AsyncListener returns a listener that forwards received events to the provided listener listening in asynchronously +// in a separate go routine. The listener that is returned will return nil for all methods including Commit and +// an error or nil will only be returned in commitChan once the sender has sent commit and the receiving listener has +// processed it. Thus commitChan can be used as a synchronization and error checking mechanism. The go routine +// that is being used for listening will exit when context.Done() returns and no more events will be received by the listener. +// bufferSize is the size of the buffer for the channel that is used to send events to the listener. +// Instead of using AsyncListener directly, it is recommended to use AsyncListenerMux which does coordination directly +// via its Commit callback. +func AsyncListener(opts AsyncListenerOptions, commitChan chan<- error, listener Listener) Listener { + packetChan := make(chan Packet, opts.BufferSize) + res := Listener{} + ctx := opts.Context + if ctx == nil { + ctx = context.Background() + } + done := ctx.Done() + + go func() { + if opts.DoneWaitGroup != nil { + opts.DoneWaitGroup.Add(1) + } + + var err error + for { + select { + case packet := <-packetChan: + if err != nil { + // if we have an error, don't process any more packets + // and return the error and finish when it's time to commit + if _, ok := packet.(CommitData); ok { + commitChan <- err + return + } + } else { + // process the packet + err = listener.SendPacket(packet) + // if it's a commit + if _, ok := packet.(CommitData); ok { + commitChan <- err + if err != nil { + return + } + } + } + + case <-done: + close(packetChan) + if opts.DoneWaitGroup != nil { + opts.DoneWaitGroup.Done() + } + return + } + } + }() + + if listener.InitializeModuleData != nil { + res.InitializeModuleData = func(data ModuleInitializationData) error { + packetChan <- data + return nil + } + } + + if listener.StartBlock != nil { + res.StartBlock = func(data StartBlockData) error { + packetChan <- data + return nil + } + } + + if listener.OnTx != nil { + res.OnTx = func(data TxData) error { + packetChan <- data + return nil + } + } + + if listener.OnEvent != nil { + res.OnEvent = func(data EventData) error { + packetChan <- data + return nil + } + } + + if listener.OnKVPair != nil { + res.OnKVPair = func(data KVPairData) error { + packetChan <- data + return nil + } + } + + if listener.OnObjectUpdate != nil { + res.OnObjectUpdate = func(data ObjectUpdateData) error { + packetChan <- data + return nil + } + } + + if listener.Commit != nil { + res.Commit = func(data CommitData) error { + packetChan <- data + return nil + } + } + + return res +} diff --git a/schema/appdata/async_test.go b/schema/appdata/async_test.go new file mode 100644 index 000000000000..c1df2d4ca6c0 --- /dev/null +++ b/schema/appdata/async_test.go @@ -0,0 +1,148 @@ +package appdata + +import ( + "context" + "fmt" + "sync" + "testing" +) + +func TestAsyncListenerMux(t *testing.T) { + t.Run("empty", func(t *testing.T) { + listener := AsyncListenerMux(AsyncListenerOptions{}, Listener{}, Listener{}) + + if listener.InitializeModuleData != nil { + t.Error("expected nil") + } + if listener.StartBlock != nil { + t.Error("expected nil") + } + if listener.OnTx != nil { + t.Error("expected nil") + } + if listener.OnEvent != nil { + t.Error("expected nil") + } + if listener.OnKVPair != nil { + t.Error("expected nil") + } + if listener.OnObjectUpdate != nil { + t.Error("expected nil") + } + + // commit is not expected to be nil + }) + + t.Run("call cancel", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + wg := &sync.WaitGroup{} + var calls1, calls2 []string + listener1 := callCollector(1, func(name string, _ int, _ Packet) { + calls1 = append(calls1, name) + }) + listener2 := callCollector(2, func(name string, _ int, _ Packet) { + calls2 = append(calls2, name) + }) + res := AsyncListenerMux(AsyncListenerOptions{ + BufferSize: 16, Context: ctx, DoneWaitGroup: wg, + }, listener1, listener2) + + callAllCallbacksOnces(t, res) + + expectedCalls := []string{ + "InitializeModuleData", + "StartBlock", + "OnTx", + "OnEvent", + "OnKVPair", + "OnObjectUpdate", + "Commit", + } + + checkExpectedCallOrder(t, calls1, expectedCalls) + checkExpectedCallOrder(t, calls2, expectedCalls) + + // cancel and expect the test to finish - if all goroutines aren't canceled the test will hang + cancel() + wg.Wait() + }) + + t.Run("error on commit", func(t *testing.T) { + var calls1, calls2 []string + listener1 := callCollector(1, func(name string, _ int, _ Packet) { + calls1 = append(calls1, name) + }) + listener1.Commit = func(data CommitData) error { + return fmt.Errorf("error") + } + listener2 := callCollector(2, func(name string, _ int, _ Packet) { + calls2 = append(calls2, name) + }) + res := AsyncListenerMux(AsyncListenerOptions{}, listener1, listener2) + + err := res.Commit(CommitData{}) + if err == nil || err.Error() != "error" { + t.Fatalf("expected error, got %v", err) + } + }) +} + +func TestAsyncListener(t *testing.T) { + t.Run("call cancel", func(t *testing.T) { + commitChan := make(chan error) + ctx, cancel := context.WithCancel(context.Background()) + wg := &sync.WaitGroup{} + var calls []string + listener := callCollector(1, func(name string, _ int, _ Packet) { + calls = append(calls, name) + }) + res := AsyncListener(AsyncListenerOptions{BufferSize: 16, Context: ctx, DoneWaitGroup: wg}, + commitChan, listener) + + callAllCallbacksOnces(t, res) + + err := <-commitChan + if err != nil { + t.Fatalf("expected nil, got %v", err) + } + + checkExpectedCallOrder(t, calls, []string{ + "InitializeModuleData", + "StartBlock", + "OnTx", + "OnEvent", + "OnKVPair", + "OnObjectUpdate", + "Commit", + }) + + calls = nil + + // expect wait group to return after cancel is called + cancel() + wg.Wait() + }) + + t.Run("error", func(t *testing.T) { + commitChan := make(chan error) + var calls []string + listener := callCollector(1, func(name string, _ int, _ Packet) { + calls = append(calls, name) + }) + + listener.OnKVPair = func(updates KVPairData) error { + return fmt.Errorf("error") + } + + res := AsyncListener(AsyncListenerOptions{BufferSize: 16}, commitChan, listener) + + callAllCallbacksOnces(t, res) + + err := <-commitChan + if err == nil || err.Error() != "error" { + t.Fatalf("expected error, got %v", err) + } + + checkExpectedCallOrder(t, calls, []string{"InitializeModuleData", "StartBlock", "OnTx", "OnEvent"}) + }) +} diff --git a/schema/appdata/mux.go b/schema/appdata/mux.go new file mode 100644 index 000000000000..8e6b886577d2 --- /dev/null +++ b/schema/appdata/mux.go @@ -0,0 +1,128 @@ +package appdata + +// ListenerMux returns a listener that forwards received events to all the provided listeners in order. +// A callback is only registered if a non-nil callback is present in at least one of the listeners. +func ListenerMux(listeners ...Listener) Listener { + mux := Listener{} + + initModDataCbs := make([]func(ModuleInitializationData) error, 0, len(listeners)) + for _, l := range listeners { + if l.InitializeModuleData != nil { + initModDataCbs = append(initModDataCbs, l.InitializeModuleData) + } + } + if len(initModDataCbs) > 0 { + mux.InitializeModuleData = func(data ModuleInitializationData) error { + for _, cb := range initModDataCbs { + if err := cb(data); err != nil { + return err + } + } + return nil + } + } + + startBlockCbs := make([]func(StartBlockData) error, 0, len(listeners)) + for _, l := range listeners { + if l.StartBlock != nil { + startBlockCbs = append(startBlockCbs, l.StartBlock) + } + } + if len(startBlockCbs) > 0 { + mux.StartBlock = func(data StartBlockData) error { + for _, cb := range startBlockCbs { + if err := cb(data); err != nil { + return err + } + } + return nil + } + } + + onTxCbs := make([]func(TxData) error, 0, len(listeners)) + for _, l := range listeners { + if l.OnTx != nil { + onTxCbs = append(onTxCbs, l.OnTx) + } + } + if len(onTxCbs) > 0 { + mux.OnTx = func(data TxData) error { + for _, cb := range onTxCbs { + if err := cb(data); err != nil { + return err + } + } + return nil + } + } + + onEventCbs := make([]func(EventData) error, 0, len(listeners)) + for _, l := range listeners { + if l.OnEvent != nil { + onEventCbs = append(onEventCbs, l.OnEvent) + } + } + if len(onEventCbs) > 0 { + mux.OnEvent = func(data EventData) error { + for _, cb := range onEventCbs { + if err := cb(data); err != nil { + return err + } + } + return nil + } + } + + onKvPairCbs := make([]func(KVPairData) error, 0, len(listeners)) + for _, l := range listeners { + if l.OnKVPair != nil { + onKvPairCbs = append(onKvPairCbs, l.OnKVPair) + } + } + if len(onKvPairCbs) > 0 { + mux.OnKVPair = func(data KVPairData) error { + for _, cb := range onKvPairCbs { + if err := cb(data); err != nil { + return err + } + } + return nil + } + } + + onObjectUpdateCbs := make([]func(ObjectUpdateData) error, 0, len(listeners)) + for _, l := range listeners { + if l.OnObjectUpdate != nil { + onObjectUpdateCbs = append(onObjectUpdateCbs, l.OnObjectUpdate) + } + } + if len(onObjectUpdateCbs) > 0 { + mux.OnObjectUpdate = func(data ObjectUpdateData) error { + for _, cb := range onObjectUpdateCbs { + if err := cb(data); err != nil { + return err + } + } + return nil + } + } + + commitCbs := make([]func(CommitData) error, 0, len(listeners)) + for _, l := range listeners { + if l.Commit != nil { + commitCbs = append(commitCbs, l.Commit) + } + } + if len(commitCbs) > 0 { + mux.Commit = func(data CommitData) error { + for _, cb := range commitCbs { + if err := cb(data); err != nil { + return err + } + } + return nil + } + } + + return mux +} diff --git a/schema/appdata/mux_test.go b/schema/appdata/mux_test.go new file mode 100644 index 000000000000..b5e3a95dd569 --- /dev/null +++ b/schema/appdata/mux_test.go @@ -0,0 +1,131 @@ +package appdata + +import ( + "fmt" + "testing" +) + +func TestListenerMux(t *testing.T) { + t.Run("empty", func(t *testing.T) { + listener := ListenerMux(Listener{}, Listener{}) + + if listener.InitializeModuleData != nil { + t.Error("expected nil") + } + if listener.StartBlock != nil { + t.Error("expected nil") + } + if listener.OnTx != nil { + t.Error("expected nil") + } + if listener.OnEvent != nil { + t.Error("expected nil") + } + if listener.OnKVPair != nil { + t.Error("expected nil") + } + if listener.OnObjectUpdate != nil { + t.Error("expected nil") + } + if listener.Commit != nil { + t.Error("expected nil") + } + }) + + t.Run("all called once", func(t *testing.T) { + var calls []string + onCall := func(name string, i int, _ Packet) { + calls = append(calls, fmt.Sprintf("%s %d", name, i)) + } + + res := ListenerMux(callCollector(1, onCall), callCollector(2, onCall)) + + callAllCallbacksOnces(t, res) + + checkExpectedCallOrder(t, calls, []string{ + "InitializeModuleData 1", + "InitializeModuleData 2", + "StartBlock 1", + "StartBlock 2", + "OnTx 1", + "OnTx 2", + "OnEvent 1", + "OnEvent 2", + "OnKVPair 1", + "OnKVPair 2", + "OnObjectUpdate 1", + "OnObjectUpdate 2", + "Commit 1", + "Commit 2", + }) + }) +} + +func callAllCallbacksOnces(t *testing.T, listener Listener) { + if err := listener.InitializeModuleData(ModuleInitializationData{}); err != nil { + t.Error(err) + } + if err := listener.StartBlock(StartBlockData{}); err != nil { + t.Error(err) + } + if err := listener.OnTx(TxData{}); err != nil { + t.Error(err) + } + if err := listener.OnEvent(EventData{}); err != nil { + t.Error(err) + } + if err := listener.OnKVPair(KVPairData{}); err != nil { + t.Error(err) + } + if err := listener.OnObjectUpdate(ObjectUpdateData{}); err != nil { + t.Error(err) + } + if err := listener.Commit(CommitData{}); err != nil { + t.Error(err) + } +} + +func callCollector(i int, onCall func(string, int, Packet)) Listener { + return Listener{ + InitializeModuleData: func(ModuleInitializationData) error { + onCall("InitializeModuleData", i, nil) + return nil + }, + StartBlock: func(StartBlockData) error { + onCall("StartBlock", i, nil) + return nil + }, + OnTx: func(TxData) error { + onCall("OnTx", i, nil) + return nil + }, + OnEvent: func(EventData) error { + onCall("OnEvent", i, nil) + return nil + }, + OnKVPair: func(KVPairData) error { + onCall("OnKVPair", i, nil) + return nil + }, + OnObjectUpdate: func(ObjectUpdateData) error { + onCall("OnObjectUpdate", i, nil) + return nil + }, + Commit: func(CommitData) error { + onCall("Commit", i, nil) + return nil + }, + } +} + +func checkExpectedCallOrder(t *testing.T, actual, expected []string) { + if len(actual) != len(expected) { + t.Fatalf("expected %d calls, got %d", len(expected), len(actual)) + } + + for i := range actual { + if actual[i] != expected[i] { + t.Errorf("expected %q, got %q", expected[i], actual[i]) + } + } +}