Skip to content

Commit

Permalink
refactor(broadcast): pass get wait ch to Wait callback
Browse files Browse the repository at this point in the history
Signed-off-by: Christian Stewart <christian@aperture.us>
  • Loading branch information
paralin committed Aug 2, 2024
1 parent 24517dd commit 31a22d9
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
9 changes: 5 additions & 4 deletions broadcast/broadcast.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ func (c *Broadcast) HoldLockMaybeAsync(cb func(broadcast func(), getWaitCh func(
// Wait waits for the cb to return true or an error before returning.
// When the broadcast channel is broadcasted, re-calls cb again to re-check the value.
// cb is called while the mutex is locked.
// Returns false, context.Canceled if ctx is canceled.
// Return nil if and only if cb returned true, nil.
func (c *Broadcast) Wait(ctx context.Context, cb func(broadcast func()) (bool, error)) error {
// Returns the wait channel and any error.
// Returns context.Canceled if ctx is canceled.
// Return waitCh, nil if and only if cb returned true, nil.
func (c *Broadcast) Wait(ctx context.Context, cb func(broadcast func(), getWaitCh func() <-chan struct{}) (bool, error)) error {
if cb == nil || ctx == nil {
return errors.New("cb and ctx must be set")
}
Expand All @@ -65,7 +66,7 @@ func (c *Broadcast) Wait(ctx context.Context, cb func(broadcast func()) (bool, e
var done bool
var err error
c.HoldLock(func(broadcast func(), getWaitCh func() <-chan struct{}) {
done, err = cb(broadcast)
done, err = cb(broadcast, getWaitCh)
if !done && err == nil {
waitCh = getWaitCh()
}
Expand Down
6 changes: 5 additions & 1 deletion broadcast/broadcast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,14 @@ func ExampleBroadcast_Wait() {

ctx := context.Background()
var gotValue int
b.Wait(ctx, func(broadcast func()) (bool, error) {
err := b.Wait(ctx, func(broadcast func(), getWaitCh func() <-chan struct{}) (bool, error) {
gotValue = currValue
return gotValue == 9, nil
})
if err != nil {
fmt.Printf("failed to wait for value: %v", err.Error())
return
}

fmt.Printf("waited for value to increment: %v\n", gotValue)
// Output: waited for value to increment: 9
Expand Down

0 comments on commit 31a22d9

Please sign in to comment.