Skip to content

Commit

Permalink
feat!: use context parameter in Future.Get
Browse files Browse the repository at this point in the history
  • Loading branch information
reugn committed Aug 21, 2024
1 parent 6d948a5 commit c8a6eca
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 23 deletions.
27 changes: 12 additions & 15 deletions future.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package async

import (
"fmt"
"context"
"sync"
"time"
)

// Future represents a value which may or may not currently be available,
Expand All @@ -23,9 +22,9 @@ type Future[T any] interface {
// or an error.
Join() (T, error)

// Get blocks for at most the given time duration for this Future to
// complete and returns either a result or an error.
Get(time.Duration) (T, error)
// Get blocks until the Future is completed or context is canceled and
// returns either a result or an error.
Get(context.Context) (T, error)

// Recover handles any error that this Future might contain using a
// resolver function.
Expand Down Expand Up @@ -68,16 +67,14 @@ func (fut *futureImpl[T]) accept() {
}

// acceptTimeout blocks once, until the Future result is available or until
// the timeout expires.
func (fut *futureImpl[T]) acceptTimeout(timeout time.Duration) {
// the context is canceled.
func (fut *futureImpl[T]) acceptContext(ctx context.Context) {
fut.acceptOnce.Do(func() {
timer := time.NewTimer(timeout)
defer timer.Stop()
select {
case result := <-fut.done:
fut.setResult(result)
case <-timer.C:
fut.setResult(fmt.Errorf("Future timeout after %s", timeout))
case <-ctx.Done():
fut.setResult(ctx.Err())
}
})
}
Expand Down Expand Up @@ -137,10 +134,10 @@ func (fut *futureImpl[T]) Join() (T, error) {
return fut.value, fut.err
}

// Get blocks for at most the given time duration for this Future to
// complete and returns either a result or an error.
func (fut *futureImpl[T]) Get(timeout time.Duration) (T, error) {
fut.acceptTimeout(timeout)
// Get blocks until the Future is completed or context is canceled and
// returns either a result or an error.
func (fut *futureImpl[T]) Get(ctx context.Context) (T, error) {
fut.acceptContext(ctx)
return fut.value, fut.err
}

Expand Down
15 changes: 10 additions & 5 deletions future_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package async

import (
"context"
"errors"
"fmt"
"runtime"
Expand Down Expand Up @@ -79,7 +80,7 @@ func TestFuture_Transform(t *testing.T) {
return util.Ptr(5), nil
})

res, _ := future.Get(time.Second * 5)
res, _ := future.Get(context.Background())
assert.Equal(t, 3, *res)

res, _ = future.Join()
Expand Down Expand Up @@ -136,11 +137,15 @@ func TestFuture_Timeout(t *testing.T) {
}()
future := p.Future()

_, err := future.Get(10 * time.Millisecond)
assert.ErrorContains(t, err, "timeout")
ctx, cancel := context.WithTimeout(context.Background(),
10*time.Millisecond)
defer cancel()

_, err := future.Get(ctx)
assert.ErrorIs(t, err, context.DeadlineExceeded)

_, err = future.Join()
assert.ErrorContains(t, err, "timeout")
assert.ErrorIs(t, err, context.DeadlineExceeded)
}

func TestFuture_GoroutineLeak(t *testing.T) {
Expand All @@ -161,7 +166,7 @@ func TestFuture_GoroutineLeak(t *testing.T) {
go func() {
defer wg.Done()
fut := promise.Future()
_, _ = fut.Get(10 * time.Millisecond)
_, _ = fut.Get(context.Background())
time.Sleep(100 * time.Millisecond)
_, _ = fut.Join()
}()
Expand Down
5 changes: 2 additions & 3 deletions future_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,10 @@ func FutureFirstCompletedOf[T any](futures ...Future[T]) Future[T] {
func FutureTimer[T any](d time.Duration) Future[T] {
next := newFuture[T]()
go func() {
timer := time.NewTimer(d)
<-timer.C
<-time.After(d)
var zero T
next.(*futureImpl[T]).
complete(zero, fmt.Errorf("FutureTimer %s timeout", d))
complete(zero, fmt.Errorf("future timeout after %s", d))
}()
return next
}
9 changes: 9 additions & 0 deletions internal/assert/assertions.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package assert

import (
"errors"
"fmt"
"reflect"
"strings"
Expand Down Expand Up @@ -78,6 +79,14 @@ func ErrorContains(t *testing.T, err error, str string) {
}
}

// ErrorIs checks whether any error in err's tree matches target.
func ErrorIs(t *testing.T, err error, target error) {
if !errors.Is(err, target) {
t.Helper()
t.Fatalf("Error type mismatch: %v != %v", err, target)
}
}

// Panics checks whether the given function panics.
func Panics(t *testing.T, f func()) {
defer func() {
Expand Down

0 comments on commit c8a6eca

Please sign in to comment.