Skip to content

Commit

Permalink
Clockwork contexts respect parent cancellation
Browse files Browse the repository at this point in the history
  • Loading branch information
DPJacques committed Nov 28, 2024
1 parent 42e854e commit a013058
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 84 deletions.
60 changes: 34 additions & 26 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package clockwork

import (
"context"
"errors"
"fmt"
"sync"
"time"
)
Expand Down Expand Up @@ -34,6 +34,22 @@ func FromContext(ctx context.Context) Clock {
return NewRealClock()
}

// ErrFakeClockDeadlineExceeded is the error returned by [context.Context] when
// the deadline passes on a context which uses a [FakeClock].
//
// It wraps a [context.DeadlineExceeded] error, i.e.:
//
// // The following is true for any Context whose deadline has been exceeded,
// // including contexts made with clockwork.WithDeadline or clockwork.WithTimeout.
//
// errors.Is(ctx.Err(), context.DeadlineExceeded)
//
// // The following can only be true for contexts made
// // with clockwork.WithDeadline or clockwork.WithTimeout.
//
// errors.Is(ctx.Err(), clockwork.ErrFakeClockDeadlineExceeded)
var ErrFakeClockDeadlineExceeded error = fmt.Errorf("clockwork.FakeClock: %w", context.DeadlineExceeded)

// WithDeadline returns a context with a deadline based on a [FakeClock].
//
// The returned context ignores parent cancelation if the parent was cancelled
Expand All @@ -43,16 +59,22 @@ func FromContext(ctx context.Context) Clock {
// If the parent is cancelled with a [context.DeadlineExceeded] error, the only
// way to then cancel the returned context is by calling the returned
// context.CancelFunc.
func WithDeadline(parent context.Context, clock *FakeClock, t time.Time) (context.Context, context.CancelFunc) {
return newFakeClockContext(parent, t, clock.newTimerAtTime(t, nil).Chan())
func WithDeadline(parent context.Context, clock Clock, t time.Time) (context.Context, context.CancelFunc) {
if fc, ok := clock.(*FakeClock); ok {
return newFakeClockContext(parent, t, fc.newTimerAtTime(t, nil).Chan())
}
return context.WithDeadline(parent, t)
}

// WithTimeout returns a context with a timeout based on a [FakeClock].
//
// The returned context follows the same behaviors as [WithDeadline].
func WithTimeout(parent context.Context, clock *FakeClock, d time.Duration) (context.Context, context.CancelFunc) {
t, deadline := clock.newTimer(d, nil)
return newFakeClockContext(parent, deadline, t.Chan())
func WithTimeout(parent context.Context, clock Clock, d time.Duration) (context.Context, context.CancelFunc) {
if fc, ok := clock.(*FakeClock); ok {
t, deadline := fc.newTimer(d, nil)
return newFakeClockContext(parent, deadline, t.Chan())
}
return context.WithTimeout(parent, d)
}

// fakeClockContext implements context.Context, using a fake clock for its
Expand Down Expand Up @@ -90,7 +112,7 @@ func newFakeClockContext(parent context.Context, deadline time.Time, timer <-cha
}
ready := make(chan struct{}, 1)
go ctx.runCancel(ready)
<-ready // Cancellation goroutine is running.
<-ready // Wait until the cancellation goroutine is running.
return ctx, ctx.cancel
}

Expand Down Expand Up @@ -127,35 +149,21 @@ func (c *fakeClockContext) runCancel(ready chan struct{}) {
// branches of our select statement below.
defer close(ready)

var ctxErr error
for ctxErr == nil {
for c.err == nil {
select {
case <-c.timerDone:
ctxErr = context.DeadlineExceeded

c.err = ErrFakeClockDeadlineExceeded
case <-c.cancelCalled:
ctxErr = context.Canceled

c.err = context.Canceled
case <-parentDone:
parentDone = nil // This case statement can only fire once.
c.err = c.parent.Err()

if err := c.parent.Err(); !errors.Is(err, context.DeadlineExceeded) {
// The parent context was canceled with some error other than deadline
// exceeded, so we respect it.
ctxErr = err
}
case ready <- struct{}{}:
// Signals the cancellation goroutine has begun, in an attempt to minimize
// race conditions related to goroutine startup time.
ready = nil // This case statement can only fire once.
}
}

c.setError(ctxErr)
return
}

func (c *fakeClockContext) setError(err error) {
c.err = err
close(c.ctxDone)
return
}
134 changes: 76 additions & 58 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ func TestContextOps(t *testing.T) {

ctx = AddToContext(ctx, NewFakeClock())
assertIsType(t, NewFakeClock(), FromContext(ctx))

ctx = AddToContext(ctx, NewRealClock())
assertIsType(t, NewRealClock(), FromContext(ctx))
}

func assertIsType(t *testing.T, expectedType, object any) {
Expand Down Expand Up @@ -80,7 +77,7 @@ func TestWithDeadlineDone(t *testing.T) {

select {
case <-child.Done():
t.Fatalf("WithDeadline context finished early.")
t.Fatal("WithDeadline context finished early.")
default:
}

Expand All @@ -92,39 +89,11 @@ func TestWithDeadlineDone(t *testing.T) {
t.Errorf("WithDeadline context returned %v, want %v", got, tc.want)
}
case <-base.Done():
t.Errorf("WithDeadline context was never canceled.")
t.Error("WithDeadline context was never canceled.")
}
})
}
}

func TestWithDeadlineParentDeadlineDoesNotCancelChild(t *testing.T) {
t.Parallel()
base, cancelBase := context.WithTimeout(context.Background(), timeout)
defer cancelBase()

// Parent context hits deadline effectively immediately.
parent, cancelParent := context.WithTimeout(base, time.Nanosecond)
defer cancelParent()

clock := NewFakeClockAt(time.Unix(10, 0))
child, cancelChild := WithDeadline(parent, clock, time.Unix(20, 0))
defer cancelChild()

// TODO(https://github.com/jonboulle/clockwork/issues/67): The time.After()
// below makes the case for having a way to validate that no timers have
// fired, rather than waiting an arbitrary amount of time and hoping you've
// waiting long enough to cover any race conditions.
//
// An abandoned attempt to do this can be found in
// https://github.com/jonboulle/clockwork/pull/69.
select {
case <-time.After(50 * time.Millisecond): // Sleeping in tests, yuck.
case <-child.Done():
t.Errorf("WithDeadline context respected parenet deadline, returning %v, want parent deadline to be ignored.", child.Err())
}
}

func TestWithTimeoutDone(t *testing.T) {
t.Parallel()
cases := []struct {
Expand Down Expand Up @@ -156,7 +125,7 @@ func TestWithTimeoutDone(t *testing.T) {
want: context.Canceled,
},
{
name: "advancing past timeout cancels child",
name: "advancing past deadline cancels child",
start: time.Unix(10, 0),
timeout: 10 * time.Second,
action: func(_, _ context.CancelFunc, clock *FakeClock) {
Expand All @@ -179,7 +148,7 @@ func TestWithTimeoutDone(t *testing.T) {

select {
case <-child.Done():
t.Fatalf("WithTimeout context finished early.")
t.Fatal("WithTimeout context finished early.")
default:
}

Expand All @@ -191,35 +160,84 @@ func TestWithTimeoutDone(t *testing.T) {
t.Errorf("WithTimeout context returned %v, want %v", got, tc.want)
}
case <-background.Done():
t.Errorf("WithTimeout context was never canceled.")
t.Error("WithTimeout context was never canceled.")
}
})
}
}

func TestWithTimeoutParentTimeoutDoesNotCancelChild(t *testing.T) {
func TestParentCancellationIsRespected(t *testing.T) {
t.Parallel()
base, cancelBase := context.WithTimeout(context.Background(), timeout)
defer cancelBase()
cases := []struct {
name string

contextFunc func(context.Context, *FakeClock) (context.Context, context.CancelFunc)

// Parent context hits deadline effectively immediately.
parent, cancelParent := context.WithTimeout(base, time.Nanosecond)
defer cancelParent()

clock := NewFakeClockAt(time.Unix(10, 0))
child, cancelChild := WithTimeout(parent, clock, 10*time.Second)
defer cancelChild()

// TODO(https://github.com/jonboulle/clockwork/issues/67): The time.After()
// below makes the case for having a way to validate that no timers have
// fired, rather than waiting an arbitrary amount of time and hoping you've
// waiting long enough to cover any race conditions.
//
// An abandoned attempt to do this can be found in
// https://github.com/jonboulle/clockwork/pull/69.
select {
case <-time.After(50 * time.Millisecond): // Sleeping in tests, yuck.
case <-child.Done():
t.Errorf("WithTimeout context respected parenet deadline, returning %v, want parent deadline to be ignored.", child.Err())
requireContextDeadlineExceeded bool
}{
{
name: "WithDeadline in the future",
contextFunc: func(ctx context.Context, fc *FakeClock) (context.Context, context.CancelFunc) {
return WithDeadline(ctx, fc, time.Now().Add(time.Hour))
},
// The FakeClock does not hit its deadline, so the error must be context.DeadlineExceeded.
requireContextDeadlineExceeded: true,
},
{
name: "WithDeadline in the past",
contextFunc: func(ctx context.Context, fc *FakeClock) (context.Context, context.CancelFunc) {
return WithDeadline(ctx, fc, time.Now().Add(-time.Hour))
},
},
{
name: "WithTimeout in the future",
contextFunc: func(ctx context.Context, fc *FakeClock) (context.Context, context.CancelFunc) {
return WithTimeout(ctx, fc, time.Hour)
},
// The FakeClock does not hit its deadline, so the error must be context.DeadlineExceeded.
requireContextDeadlineExceeded: true,
},
{
name: "WithTimeout immediately",
contextFunc: func(ctx context.Context, fc *FakeClock) (context.Context, context.CancelFunc) {
return WithTimeout(ctx, fc, 0)
},
},
{
name: "WithTimeout in the past",
contextFunc: func(ctx context.Context, fc *FakeClock) (context.Context, context.CancelFunc) {
return WithTimeout(ctx, fc, -time.Hour)
},
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
base, cancelBase := context.WithTimeout(context.Background(), timeout)
defer cancelBase()

// Parent context hits deadline effectively immediately.
parent, cancelParent := context.WithTimeout(base, time.Nanosecond)
defer cancelParent()

clock := NewFakeClockAt(time.Unix(10, 0))
child, cancelChild := tc.contextFunc(parent, clock)
defer cancelChild()

select {
case <-child.Done():
case <-base.Done():
t.Fatal("context did not respect parnet deadline")
}

if err := child.Err(); !errors.Is(err, context.DeadlineExceeded) {
t.Errorf("errors.Is(Context.Err(), context.DeadlineExceeded) == falst, want true, error: %v", err)
}
if tc.requireContextDeadlineExceeded {
if err := child.Err(); errors.Is(err, ErrFakeClockDeadlineExceeded) {
t.Errorf("errors.Is(Context.Err(), ErrFakeClockDeadlineExceeded) == true, want false, error: %v", err)
}
}
})
}
}

0 comments on commit a013058

Please sign in to comment.