Skip to content

Commit

Permalink
Preserve context values
Browse files Browse the repository at this point in the history
Instead of passing a fresh `context.Background()` wrap context passed by
the user with a context that ignores cancellation.

This allows to drop the original cancellation while preserving context
values.

If built with Go 1.21+ the stdlib `context.WithoutCancel` is used.
Otherwise it fallbacks to an in-tree copy of the withoutCancel.

Signed-off-by: Paweł Gronowski <pawel.gronowski@docker.com>
  • Loading branch information
vvoland committed Sep 11, 2023
1 parent 12ee437 commit 6f0f39d
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 5 deletions.
14 changes: 9 additions & 5 deletions singleflight.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ type Group[K comparable, V any] struct {
// comes in, the duplicate caller waits for the original to complete and
// receives the same results.
//
// The context passed to the fn function is a new context which is canceled when
// contexts from all callers are canceled, so that no caller is expecting the
// result. If there are multiple callers, context passed to one caller does not
// effect the execution and returned values of others.
// The context passed to the fn function is a context that preserves all values
// from the passed context but is cancelled by the singleflight only when all
// awaiting caller's contexts are cancelled (no caller is awaiting the result).
// If there are multiple callers, context passed to one caller does not affect
// the execution and returned values of others except if the function result is
// dependent on the context values.
//
// The return value shared indicates whether v was given to multiple callers.
func (g *Group[K, V]) Do(ctx context.Context, key K, fn func(ctx context.Context) (V, error)) (v V, shared bool, err error) {
Expand All @@ -45,7 +47,9 @@ func (g *Group[K, V]) Do(ctx context.Context, key K, fn func(ctx context.Context
return g.wait(ctx, key, c)
}

callCtx, cancel := context.WithCancel(context.Background())
// Replace cancellation from the user context with a cancellation
// controlled by the singleflight and preserve context values.
callCtx, cancel := context.WithCancel(withoutCancel(ctx))

c := &call[V]{
done: make(chan struct{}),
Expand Down
20 changes: 20 additions & 0 deletions singleflight_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,26 @@ func TestDo_multipleCallsCanceled(t *testing.T) {
}
}

func TestDo_preserveContextValues(t *testing.T) {
var g singleflight.Group[string, any]

type KeyType string
const key KeyType = "foo"

callerCtx := context.WithValue(context.Background(), key, "bar")

val, _, err := g.Do(callerCtx, "key", func(ctx context.Context) (any, error) {
return ctx.Value(key), nil
})

if err != nil {
t.Fatal(err)
}
if val != "bar" {
t.Error("the context should not lose the values")
}
}

func waitStacks(t *testing.T, loc string, count int, timeout time.Duration) {
t.Helper()

Expand Down
12 changes: 12 additions & 0 deletions withoutcancel.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
//go:build go1.21

package singleflight

import "context"

// withoutCancel returns a copy of parent that is not canceled when parent is canceled.
// The returned context returns no Deadline or Err, and its Done channel is nil.
// Calling [Cause] on the returned context returns nil.
func withoutCancel(ctx context.Context) context.Context {
return context.WithoutCancel(ctx)
}
87 changes: 87 additions & 0 deletions withoutcancel_backport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
//go:build !go1.21

// Copyright (c) 2009 The Go Authors. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// Source: https://cs.opensource.google/go/go/+/refs/tags/go1.21.1:src/context/context.go
// The only modifications to the original source were:
// - renaming WithoutCancel to withoutCancel
// - replacing the usage of internal reflectlite with reflect
// - replacing the usage of private value function with Value method call
package singleflight

import (
"context"
"reflect"
"time"
)

// withoutCancel returns a copy of parent that is not canceled when parent is canceled.
// The returned context returns no Deadline or Err, and its Done channel is nil.
// Calling [Cause] on the returned context returns nil.
func withoutCancel(parent context.Context) context.Context {
if parent == nil {
panic("cannot create context from nil parent")
}
return withoutCancelCtx{parent}
}

type withoutCancelCtx struct {
c context.Context
}

func (withoutCancelCtx) Deadline() (deadline time.Time, ok bool) {
return
}

func (withoutCancelCtx) Done() <-chan struct{} {
return nil
}

func (withoutCancelCtx) Err() error {
return nil
}

func (c withoutCancelCtx) Value(key any) any {
return c.c.Value(key)
}

func (c withoutCancelCtx) String() string {
return contextName(c.c) + ".WithoutCancel"
}

type stringer interface {
String() string
}

func contextName(c context.Context) string {
if s, ok := c.(stringer); ok {
return s.String()
}
return reflect.TypeOf(c).String()
}

0 comments on commit 6f0f39d

Please sign in to comment.