Skip to content

Commit

Permalink
docs
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronc committed Jul 10, 2024
1 parent a50f48c commit 78f1640
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 21 deletions.
41 changes: 35 additions & 6 deletions schema/appdata/async.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,36 @@
package appdata

import "context"
import (
"context"
"sync"
)

// AsyncListenerOptions are options for async listeners and listener mux's.
type AsyncListenerOptions struct {
// Context is the context whose Done() channel listeners use will use to listen for completion to close their
// goroutine. If it is nil, then context.Background() will be used and goroutines may be leaked.
Context context.Context

// BufferSize is the buffer size of the channels to use. It defaults to 0.
BufferSize int

// DoneWaitGroup is an optional wait-group that listener goroutines will notify via Add(1) when they are started
// and Done() after they are cancelled and completed.
DoneWaitGroup *sync.WaitGroup
}

// AsyncListenerMux returns a listener that forwards received events to all the provided listeners asynchronously
// with each listener processing in a separate go routine. All callbacks in the returned listener will return nil
// except for Commit which will return an error or nil once all listeners have processed the commit. The context
// is used to signal that the listeners should stop listening and return. bufferSize is the size of the buffer for the
// channels used to send events to the listeners.
func AsyncListenerMux(listeners []Listener, bufferSize int, ctx context.Context) Listener {
func AsyncListenerMux(opts AsyncListenerOptions, listeners ...Listener) Listener {
asyncListeners := make([]Listener, len(listeners))
commitChans := make([]chan error, len(listeners))
for i, l := range listeners {
commitChan := make(chan error)
commitChans[i] = commitChan
asyncListeners[i] = AsyncListener(l, bufferSize, commitChan, ctx)
asyncListeners[i] = AsyncListener(opts, commitChan, l)
}
mux := ListenerMux(asyncListeners...)
muxCommit := mux.Commit
Expand Down Expand Up @@ -45,11 +62,20 @@ func AsyncListenerMux(listeners []Listener, bufferSize int, ctx context.Context)
// bufferSize is the size of the buffer for the channel that is used to send events to the listener.
// Instead of using AsyncListener directly, it is recommended to use AsyncListenerMux which does coordination directly
// via its Commit callback.
func AsyncListener(listener Listener, bufferSize int, commitChan chan<- error, ctx context.Context) Listener {
packetChan := make(chan Packet, bufferSize)
func AsyncListener(opts AsyncListenerOptions, commitChan chan<- error, listener Listener) Listener {
packetChan := make(chan Packet, opts.BufferSize)
res := Listener{}
ctx := opts.Context
if ctx == nil {
ctx = context.Background()
}
done := ctx.Done()

go func() {
if opts.DoneWaitGroup != nil {
opts.DoneWaitGroup.Add(1)
}

var err error
for {
select {
Expand All @@ -73,8 +99,11 @@ func AsyncListener(listener Listener, bufferSize int, commitChan chan<- error, c
}
}

case <-ctx.Done():
case <-done:
close(packetChan)
if opts.DoneWaitGroup != nil {
opts.DoneWaitGroup.Done()
}
return
}
}
Expand Down
47 changes: 32 additions & 15 deletions schema/appdata/async_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ package appdata
import (
"context"
"fmt"
"sync"
"testing"
)

func TestAsyncListenerMux(t *testing.T) {
t.Run("empty", func(t *testing.T) {
listener := AsyncListenerMux([]Listener{{}, {}}, 16, context.Background())
listener := AsyncListenerMux(AsyncListenerOptions{}, Listener{}, Listener{})

if listener.InitializeModuleData != nil {
t.Error("expected nil")
Expand All @@ -34,14 +35,17 @@ func TestAsyncListenerMux(t *testing.T) {

t.Run("call cancel", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
wg := &sync.WaitGroup{}
var calls1, calls2 []string
listener1 := callCollector(1, func(name string, _ int, _ Packet) {
calls1 = append(calls1, name)
})
listener2 := callCollector(2, func(name string, _ int, _ Packet) {
calls2 = append(calls2, name)
})
res := AsyncListenerMux([]Listener{listener1, listener2}, 16, ctx)
res := AsyncListenerMux(AsyncListenerOptions{
BufferSize: 16, Context: ctx, DoneWaitGroup: wg,
}, listener1, listener2)

callAllCallbacksOnces(t, res)

Expand All @@ -58,27 +62,42 @@ func TestAsyncListenerMux(t *testing.T) {
checkExpectedCallOrder(t, calls1, expectedCalls)
checkExpectedCallOrder(t, calls2, expectedCalls)

// cancel and expect the test to finish - if all goroutines aren't canceled the test will hang
cancel()
wg.Wait()
})

// expect a panic if we try to write to the now closed channels
defer func() {
if err := recover(); err == nil {
t.Fatalf("expected panic")
}
}()
callAllCallbacksOnces(t, res)
t.Run("error on commit", func(t *testing.T) {
var calls1, calls2 []string
listener1 := callCollector(1, func(name string, _ int, _ Packet) {
calls1 = append(calls1, name)
})
listener1.Commit = func(data CommitData) error {
return fmt.Errorf("error")
}
listener2 := callCollector(2, func(name string, _ int, _ Packet) {
calls2 = append(calls2, name)
})
res := AsyncListenerMux(AsyncListenerOptions{}, listener1, listener2)

err := res.Commit(CommitData{})
if err == nil || err.Error() != "error" {
t.Fatalf("expected error, got %v", err)
}
})
}

func TestAsyncListener(t *testing.T) {
t.Run("call cancel", func(t *testing.T) {
commitChan := make(chan error)
ctx, cancel := context.WithCancel(context.Background())
wg := &sync.WaitGroup{}
var calls []string
listener := callCollector(1, func(name string, _ int, _ Packet) {
calls = append(calls, name)
})
res := AsyncListener(listener, 16, commitChan, ctx)
res := AsyncListener(AsyncListenerOptions{BufferSize: 16, Context: ctx, DoneWaitGroup: wg},
commitChan, listener)

callAllCallbacksOnces(t, res)

Expand All @@ -99,11 +118,9 @@ func TestAsyncListener(t *testing.T) {

calls = nil

// expect wait group to return after cancel is called
cancel()

callAllCallbacksOnces(t, res)

checkExpectedCallOrder(t, calls, nil)
wg.Wait()
})

t.Run("error", func(t *testing.T) {
Expand All @@ -117,7 +134,7 @@ func TestAsyncListener(t *testing.T) {
return fmt.Errorf("error")
}

res := AsyncListener(listener, 16, commitChan, context.Background())
res := AsyncListener(AsyncListenerOptions{BufferSize: 16}, commitChan, listener)

callAllCallbacksOnces(t, res)

Expand Down

0 comments on commit 78f1640

Please sign in to comment.