diff --git a/semaphore/semaphore.go b/semaphore/semaphore.go index 30f632c..b618162 100644 --- a/semaphore/semaphore.go +++ b/semaphore/semaphore.go @@ -35,11 +35,25 @@ type Weighted struct { // Acquire acquires the semaphore with a weight of n, blocking until resources // are available or ctx is done. On success, returns nil. On failure, returns // ctx.Err() and leaves the semaphore unchanged. -// -// If ctx is already done, Acquire may still succeed without blocking. func (s *Weighted) Acquire(ctx context.Context, n int64) error { + done := ctx.Done() + s.mu.Lock() + select { + case <-done: + // ctx becoming done has "happened before" acquiring the semaphore, + // whether it became done before the call began or while we were + // waiting for the mutex. We prefer to fail even if we could acquire + // the mutex without blocking. + s.mu.Unlock() + return ctx.Err() + default: + } if s.size-s.cur >= n && s.waiters.Len() == 0 { + // Since we hold s.mu and haven't synchronized since checking done, if + // ctx becomes done before we return here, it becoming done must have + // "happened concurrently" with this call - it cannot "happen before" + // we return in this branch. So, we're ok to always acquire here. s.cur += n s.mu.Unlock() return nil @@ -48,7 +62,7 @@ func (s *Weighted) Acquire(ctx context.Context, n int64) error { if n > s.size { // Don't make other Acquire calls block on one that's doomed to fail. s.mu.Unlock() - <-ctx.Done() + <-done return ctx.Err() } @@ -58,14 +72,14 @@ func (s *Weighted) Acquire(ctx context.Context, n int64) error { s.mu.Unlock() select { - case <-ctx.Done(): - err := ctx.Err() + case <-done: s.mu.Lock() select { case <-ready: - // Acquired the semaphore after we were canceled. Rather than trying to - // fix up the queue, just pretend we didn't notice the cancelation. - err = nil + // Acquired the semaphore after we were canceled. + // Pretend we didn't and put the tokens back. + s.cur -= n + s.notifyWaiters() default: isFront := s.waiters.Front() == elem s.waiters.Remove(elem) @@ -75,9 +89,19 @@ func (s *Weighted) Acquire(ctx context.Context, n int64) error { } } s.mu.Unlock() - return err + return ctx.Err() case <-ready: + // Acquired the semaphore. Check that ctx isn't already done. + // We check the done channel instead of calling ctx.Err because we + // already have the channel, and ctx.Err is O(n) with the nesting + // depth of ctx. + select { + case <-done: + s.Release(n) + return ctx.Err() + default: + } return nil } } diff --git a/semaphore/semaphore_test.go b/semaphore/semaphore_test.go index 6e8eca2..61012d6 100644 --- a/semaphore/semaphore_test.go +++ b/semaphore/semaphore_test.go @@ -200,3 +200,38 @@ func TestAllocCancelDoesntStarve(t *testing.T) { } sem.Release(1) } + +func TestWeightedAcquireCanceled(t *testing.T) { + // https://go.dev/issue/63615 + sem := semaphore.NewWeighted(2) + ctx, cancel := context.WithCancel(context.Background()) + sem.Acquire(context.Background(), 1) + ch := make(chan struct{}) + go func() { + // Synchronize with the Acquire(2) below. + for sem.TryAcquire(1) { + sem.Release(1) + } + // Now cancel ctx, and then release the token. + cancel() + sem.Release(1) + close(ch) + }() + // Since the context closing happens before enough tokens become available, + // this Acquire must fail. + if err := sem.Acquire(ctx, 2); err != context.Canceled { + t.Errorf("Acquire with canceled context returned wrong error: want context.Canceled, got %v", err) + } + // There must always be two tokens in the semaphore after the other + // goroutine releases the one we held at the start. + <-ch + if !sem.TryAcquire(2) { + t.Fatal("TryAcquire after canceled Acquire failed") + } + // Additionally verify that we don't acquire with a done context even when + // we wouldn't need to block to do so. + sem.Release(2) + if err := sem.Acquire(ctx, 1); err != context.Canceled { + t.Errorf("Acquire with canceled context returned wrong error: want context.Canceled, got %v", err) + } +}