From 6f0f39d99b1cda9efa5ef6d5c2717d65e5575b73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gronowski?= Date: Mon, 11 Sep 2023 12:15:21 +0200 Subject: [PATCH] Preserve context values MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Instead of passing a fresh `context.Background()` wrap context passed by the user with a context that ignores cancellation. This allows to drop the original cancellation while preserving context values. If built with Go 1.21+ the stdlib `context.WithoutCancel` is used. Otherwise it fallbacks to an in-tree copy of the withoutCancel. Signed-off-by: Paweł Gronowski --- singleflight.go | 14 ++++--- singleflight_test.go | 20 +++++++++ withoutcancel.go | 12 ++++++ withoutcancel_backport.go | 87 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 128 insertions(+), 5 deletions(-) create mode 100644 withoutcancel.go create mode 100644 withoutcancel_backport.go diff --git a/singleflight.go b/singleflight.go index ad89737..0dde54b 100644 --- a/singleflight.go +++ b/singleflight.go @@ -25,10 +25,12 @@ type Group[K comparable, V any] struct { // comes in, the duplicate caller waits for the original to complete and // receives the same results. // -// The context passed to the fn function is a new context which is canceled when -// contexts from all callers are canceled, so that no caller is expecting the -// result. If there are multiple callers, context passed to one caller does not -// effect the execution and returned values of others. +// The context passed to the fn function is a context that preserves all values +// from the passed context but is cancelled by the singleflight only when all +// awaiting caller's contexts are cancelled (no caller is awaiting the result). +// If there are multiple callers, context passed to one caller does not affect +// the execution and returned values of others except if the function result is +// dependent on the context values. // // The return value shared indicates whether v was given to multiple callers. func (g *Group[K, V]) Do(ctx context.Context, key K, fn func(ctx context.Context) (V, error)) (v V, shared bool, err error) { @@ -45,7 +47,9 @@ func (g *Group[K, V]) Do(ctx context.Context, key K, fn func(ctx context.Context return g.wait(ctx, key, c) } - callCtx, cancel := context.WithCancel(context.Background()) + // Replace cancellation from the user context with a cancellation + // controlled by the singleflight and preserve context values. + callCtx, cancel := context.WithCancel(withoutCancel(ctx)) c := &call[V]{ done: make(chan struct{}), diff --git a/singleflight_test.go b/singleflight_test.go index e05d9a5..b9af8a1 100644 --- a/singleflight_test.go +++ b/singleflight_test.go @@ -393,6 +393,26 @@ func TestDo_multipleCallsCanceled(t *testing.T) { } } +func TestDo_preserveContextValues(t *testing.T) { + var g singleflight.Group[string, any] + + type KeyType string + const key KeyType = "foo" + + callerCtx := context.WithValue(context.Background(), key, "bar") + + val, _, err := g.Do(callerCtx, "key", func(ctx context.Context) (any, error) { + return ctx.Value(key), nil + }) + + if err != nil { + t.Fatal(err) + } + if val != "bar" { + t.Error("the context should not lose the values") + } +} + func waitStacks(t *testing.T, loc string, count int, timeout time.Duration) { t.Helper() diff --git a/withoutcancel.go b/withoutcancel.go new file mode 100644 index 0000000..bd2b2df --- /dev/null +++ b/withoutcancel.go @@ -0,0 +1,12 @@ +//go:build go1.21 + +package singleflight + +import "context" + +// withoutCancel returns a copy of parent that is not canceled when parent is canceled. +// The returned context returns no Deadline or Err, and its Done channel is nil. +// Calling [Cause] on the returned context returns nil. +func withoutCancel(ctx context.Context) context.Context { + return context.WithoutCancel(ctx) +} diff --git a/withoutcancel_backport.go b/withoutcancel_backport.go new file mode 100644 index 0000000..2f78766 --- /dev/null +++ b/withoutcancel_backport.go @@ -0,0 +1,87 @@ +//go:build !go1.21 + +// Copyright (c) 2009 The Go Authors. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Source: https://cs.opensource.google/go/go/+/refs/tags/go1.21.1:src/context/context.go +// The only modifications to the original source were: +// - renaming WithoutCancel to withoutCancel +// - replacing the usage of internal reflectlite with reflect +// - replacing the usage of private value function with Value method call +package singleflight + +import ( + "context" + "reflect" + "time" +) + +// withoutCancel returns a copy of parent that is not canceled when parent is canceled. +// The returned context returns no Deadline or Err, and its Done channel is nil. +// Calling [Cause] on the returned context returns nil. +func withoutCancel(parent context.Context) context.Context { + if parent == nil { + panic("cannot create context from nil parent") + } + return withoutCancelCtx{parent} +} + +type withoutCancelCtx struct { + c context.Context +} + +func (withoutCancelCtx) Deadline() (deadline time.Time, ok bool) { + return +} + +func (withoutCancelCtx) Done() <-chan struct{} { + return nil +} + +func (withoutCancelCtx) Err() error { + return nil +} + +func (c withoutCancelCtx) Value(key any) any { + return c.c.Value(key) +} + +func (c withoutCancelCtx) String() string { + return contextName(c.c) + ".WithoutCancel" +} + +type stringer interface { + String() string +} + +func contextName(c context.Context) string { + if s, ok := c.(stringer); ok { + return s.String() + } + return reflect.TypeOf(c).String() +}