From a7affacaf07488d9b49680e2b3f6a24b12186f32 Mon Sep 17 00:00:00 2001 From: Janos Guljas Date: Tue, 5 Dec 2023 11:31:09 +0100 Subject: [PATCH] Protect the shared value for concurrent access --- singleflight.go | 3 ++- singleflight_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/singleflight.go b/singleflight.go index 0dde54b..ff79a52 100644 --- a/singleflight.go +++ b/singleflight.go @@ -82,8 +82,9 @@ func (g *Group[K, V]) wait(ctx context.Context, key K, c *call[V]) (v V, shared c.cancel() delete(g.calls, key) } + shared = c.shared g.mu.Unlock() - return v, c.shared, err + return v, shared, err } // Forget tells the singleflight to forget about a key. Future calls diff --git a/singleflight_test.go b/singleflight_test.go index b9af8a1..475a6bd 100644 --- a/singleflight_test.go +++ b/singleflight_test.go @@ -39,6 +39,34 @@ func TestDo(t *testing.T) { } } +func TestDo_concurrentAccess(t *testing.T) { + var g singleflight.Group[string, string] + + want := "val" + key := "key" + var wg sync.WaitGroup + n := 100 + + wg.Add(n) + for i := 0; i < n; i++ { + go func(i int) { + defer wg.Done() + got, shared, err := g.Do(context.Background(), key, func(_ context.Context) (string, error) { + return want, nil + }) + if err != nil { + t.Error(err) + } + _ = shared // read the shared to test the concurrent access + if got != want { + t.Errorf("got value %v, want %v", got, want) + } + time.Sleep(5 * time.Millisecond) + }(i) + } + wg.Wait() +} + func TestDo_error(t *testing.T) { var g singleflight.Group[string, string] wantErr := errors.New("test error")