diff --git a/src/context/benchmark_test.go b/src/context/benchmark_test.go index 5d56863050a9ad..c4c72f00f8033e 100644 --- a/src/context/benchmark_test.go +++ b/src/context/benchmark_test.go @@ -5,6 +5,7 @@ package context_test import ( + "context" . "context" "fmt" "runtime" @@ -138,3 +139,17 @@ func BenchmarkCheckCanceled(b *testing.B) { } }) } + +func BenchmarkContextCancelDone(b *testing.B) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + select { + case <-ctx.Done(): + default: + } + } + }) +} diff --git a/src/context/context.go b/src/context/context.go index b3fdb8277afc37..733c5f56d9274d 100644 --- a/src/context/context.go +++ b/src/context/context.go @@ -303,10 +303,8 @@ func parentCancelCtx(parent Context) (*cancelCtx, bool) { if !ok { return nil, false } - p.mu.Lock() - ok = p.done == done - p.mu.Unlock() - if !ok { + pdone, _ := p.done.Load().(chan struct{}) + if pdone != done { return nil, false } return p, true @@ -345,7 +343,7 @@ type cancelCtx struct { Context mu sync.Mutex // protects following fields - done chan struct{} // created lazily, closed by first cancel call + done atomic.Value // of chan struct{}, created lazily, closed by first cancel call children map[canceler]struct{} // set to nil by the first cancel call err error // set to non-nil by the first cancel call } @@ -358,13 +356,18 @@ func (c *cancelCtx) Value(key interface{}) interface{} { } func (c *cancelCtx) Done() <-chan struct{} { + d := c.done.Load() + if d != nil { + return d.(chan struct{}) + } c.mu.Lock() - if c.done == nil { - c.done = make(chan struct{}) + defer c.mu.Unlock() + d = c.done.Load() + if d == nil { + d = make(chan struct{}) + c.done.Store(d) } - d := c.done - c.mu.Unlock() - return d + return d.(chan struct{}) } func (c *cancelCtx) Err() error { @@ -401,10 +404,11 @@ func (c *cancelCtx) cancel(removeFromParent bool, err error) { return // already canceled } c.err = err - if c.done == nil { - c.done = closedchan + d, _ := c.done.Load().(chan struct{}) + if d == nil { + c.done.Store(closedchan) } else { - close(c.done) + close(d) } for child := range c.children { // NOTE: acquiring the child's lock while holding parent's lock.