diff --git a/schema/appdata/async.go b/schema/appdata/async.go index f76ab034909a..0d4126ed2a03 100644 --- a/schema/appdata/async.go +++ b/schema/appdata/async.go @@ -1,19 +1,36 @@ package appdata -import "context" +import ( + "context" + "sync" +) + +// AsyncListenerOptions are options for async listeners and listener mux's. +type AsyncListenerOptions struct { + // Context is the context whose Done() channel listeners use will use to listen for completion to close their + // goroutine. If it is nil, then context.Background() will be used and goroutines may be leaked. + Context context.Context + + // BufferSize is the buffer size of the channels to use. It defaults to 0. + BufferSize int + + // DoneWaitGroup is an optional wait-group that listener goroutines will notify via Add(1) when they are started + // and Done() after they are cancelled and completed. + DoneWaitGroup *sync.WaitGroup +} // AsyncListenerMux returns a listener that forwards received events to all the provided listeners asynchronously // with each listener processing in a separate go routine. All callbacks in the returned listener will return nil // except for Commit which will return an error or nil once all listeners have processed the commit. The context // is used to signal that the listeners should stop listening and return. bufferSize is the size of the buffer for the // channels used to send events to the listeners. -func AsyncListenerMux(listeners []Listener, bufferSize int, ctx context.Context) Listener { +func AsyncListenerMux(opts AsyncListenerOptions, listeners ...Listener) Listener { asyncListeners := make([]Listener, len(listeners)) commitChans := make([]chan error, len(listeners)) for i, l := range listeners { commitChan := make(chan error) commitChans[i] = commitChan - asyncListeners[i] = AsyncListener(l, bufferSize, commitChan, ctx) + asyncListeners[i] = AsyncListener(opts, commitChan, l) } mux := ListenerMux(asyncListeners...) muxCommit := mux.Commit @@ -45,11 +62,20 @@ func AsyncListenerMux(listeners []Listener, bufferSize int, ctx context.Context) // bufferSize is the size of the buffer for the channel that is used to send events to the listener. // Instead of using AsyncListener directly, it is recommended to use AsyncListenerMux which does coordination directly // via its Commit callback. -func AsyncListener(listener Listener, bufferSize int, commitChan chan<- error, ctx context.Context) Listener { - packetChan := make(chan Packet, bufferSize) +func AsyncListener(opts AsyncListenerOptions, commitChan chan<- error, listener Listener) Listener { + packetChan := make(chan Packet, opts.BufferSize) res := Listener{} + ctx := opts.Context + if ctx == nil { + ctx = context.Background() + } + done := ctx.Done() go func() { + if opts.DoneWaitGroup != nil { + opts.DoneWaitGroup.Add(1) + } + var err error for { select { @@ -73,8 +99,11 @@ func AsyncListener(listener Listener, bufferSize int, commitChan chan<- error, c } } - case <-ctx.Done(): + case <-done: close(packetChan) + if opts.DoneWaitGroup != nil { + opts.DoneWaitGroup.Done() + } return } } diff --git a/schema/appdata/async_test.go b/schema/appdata/async_test.go index f4706cab0ad3..c1df2d4ca6c0 100644 --- a/schema/appdata/async_test.go +++ b/schema/appdata/async_test.go @@ -3,12 +3,13 @@ package appdata import ( "context" "fmt" + "sync" "testing" ) func TestAsyncListenerMux(t *testing.T) { t.Run("empty", func(t *testing.T) { - listener := AsyncListenerMux([]Listener{{}, {}}, 16, context.Background()) + listener := AsyncListenerMux(AsyncListenerOptions{}, Listener{}, Listener{}) if listener.InitializeModuleData != nil { t.Error("expected nil") @@ -34,6 +35,7 @@ func TestAsyncListenerMux(t *testing.T) { t.Run("call cancel", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) + wg := &sync.WaitGroup{} var calls1, calls2 []string listener1 := callCollector(1, func(name string, _ int, _ Packet) { calls1 = append(calls1, name) @@ -41,7 +43,9 @@ func TestAsyncListenerMux(t *testing.T) { listener2 := callCollector(2, func(name string, _ int, _ Packet) { calls2 = append(calls2, name) }) - res := AsyncListenerMux([]Listener{listener1, listener2}, 16, ctx) + res := AsyncListenerMux(AsyncListenerOptions{ + BufferSize: 16, Context: ctx, DoneWaitGroup: wg, + }, listener1, listener2) callAllCallbacksOnces(t, res) @@ -58,15 +62,28 @@ func TestAsyncListenerMux(t *testing.T) { checkExpectedCallOrder(t, calls1, expectedCalls) checkExpectedCallOrder(t, calls2, expectedCalls) + // cancel and expect the test to finish - if all goroutines aren't canceled the test will hang cancel() + wg.Wait() + }) - // expect a panic if we try to write to the now closed channels - defer func() { - if err := recover(); err == nil { - t.Fatalf("expected panic") - } - }() - callAllCallbacksOnces(t, res) + t.Run("error on commit", func(t *testing.T) { + var calls1, calls2 []string + listener1 := callCollector(1, func(name string, _ int, _ Packet) { + calls1 = append(calls1, name) + }) + listener1.Commit = func(data CommitData) error { + return fmt.Errorf("error") + } + listener2 := callCollector(2, func(name string, _ int, _ Packet) { + calls2 = append(calls2, name) + }) + res := AsyncListenerMux(AsyncListenerOptions{}, listener1, listener2) + + err := res.Commit(CommitData{}) + if err == nil || err.Error() != "error" { + t.Fatalf("expected error, got %v", err) + } }) } @@ -74,11 +91,13 @@ func TestAsyncListener(t *testing.T) { t.Run("call cancel", func(t *testing.T) { commitChan := make(chan error) ctx, cancel := context.WithCancel(context.Background()) + wg := &sync.WaitGroup{} var calls []string listener := callCollector(1, func(name string, _ int, _ Packet) { calls = append(calls, name) }) - res := AsyncListener(listener, 16, commitChan, ctx) + res := AsyncListener(AsyncListenerOptions{BufferSize: 16, Context: ctx, DoneWaitGroup: wg}, + commitChan, listener) callAllCallbacksOnces(t, res) @@ -99,11 +118,9 @@ func TestAsyncListener(t *testing.T) { calls = nil + // expect wait group to return after cancel is called cancel() - - callAllCallbacksOnces(t, res) - - checkExpectedCallOrder(t, calls, nil) + wg.Wait() }) t.Run("error", func(t *testing.T) { @@ -117,7 +134,7 @@ func TestAsyncListener(t *testing.T) { return fmt.Errorf("error") } - res := AsyncListener(listener, 16, commitChan, context.Background()) + res := AsyncListener(AsyncListenerOptions{BufferSize: 16}, commitChan, listener) callAllCallbacksOnces(t, res)