Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update RateLimiter queues on cancellation #64825

Merged
merged 3 commits into from
Feb 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,13 @@ protected override ValueTask<RateLimitLease> WaitAsyncCore(int permitCount, Canc
}
}

TaskCompletionSource<RateLimitLease> tcs = new TaskCompletionSource<RateLimitLease>(TaskCreationOptions.RunContinuationsAsynchronously);
CancelQueueState tcs = new CancelQueueState(permitCount, this, cancellationToken);
CancellationTokenRegistration ctr = default;
if (cancellationToken.CanBeCanceled)
{
ctr = cancellationToken.Register(static obj =>
{
((TaskCompletionSource<RateLimitLease>)obj!).TrySetException(new OperationCanceledException());
((CancelQueueState)obj!).TrySetCanceled();
}, tcs);
}

Expand Down Expand Up @@ -194,7 +194,6 @@ private void Release(int releaseCount)

_permitCount -= nextPendingRequest.Count;
_queueCount -= nextPendingRequest.Count;
Debug.Assert(_queueCount >= 0);
Debug.Assert(_permitCount >= 0);

ConcurrencyLease lease = nextPendingRequest.Count == 0 ? SuccessfulLease : new ConcurrencyLease(true, this, nextPendingRequest.Count);
Expand All @@ -203,8 +202,11 @@ private void Release(int releaseCount)
{
// Queued item was canceled so add count back
_permitCount += nextPendingRequest.Count;
// Updating queue count is handled by the cancellation code
_queueCount += nextPendingRequest.Count;
}
nextPendingRequest.CancellationTokenRegistration.Dispose();
Debug.Assert(_queueCount >= 0);
}
else
{
Expand Down Expand Up @@ -319,5 +321,33 @@ public RequestRegistration(int requestedCount, TaskCompletionSource<RateLimitLea

public CancellationTokenRegistration CancellationTokenRegistration { get; }
}

private sealed class CancelQueueState : TaskCompletionSource<RateLimitLease>
{
private readonly int _permitCount;
private readonly ConcurrencyLimiter _limiter;
private readonly CancellationToken _cancellationToken;

public CancelQueueState(int permitCount, ConcurrencyLimiter limiter, CancellationToken cancellationToken)
: base(TaskCreationOptions.RunContinuationsAsynchronously)
{
_permitCount = permitCount;
_limiter = limiter;
_cancellationToken = cancellationToken;
}

public new bool TrySetCanceled()
{
if (TrySetCanceled(_cancellationToken))
{
lock (_limiter.Lock)
{
_limiter._queueCount -= _permitCount;
}
return true;
}
return false;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,13 @@ protected override ValueTask<RateLimitLease> WaitAsyncCore(int tokenCount, Cance
}
}

TaskCompletionSource<RateLimitLease> tcs = new TaskCompletionSource<RateLimitLease>(TaskCreationOptions.RunContinuationsAsynchronously);

CancelQueueState tcs = new CancelQueueState(tokenCount, this, cancellationToken);
CancellationTokenRegistration ctr = default;
if (cancellationToken.CanBeCanceled)
{
ctr = cancellationToken.Register(static obj =>
{
((TaskCompletionSource<RateLimitLease>)obj!).TrySetException(new OperationCanceledException());
((CancelQueueState)obj!).TrySetCanceled();
}, tcs);
}

Expand All @@ -140,7 +139,6 @@ protected override ValueTask<RateLimitLease> WaitAsyncCore(int tokenCount, Cance
_queueCount += tokenCount;
Debug.Assert(_queueCount <= _options.QueueLimit);

// handle cancellation
return new ValueTask<RateLimitLease>(registration.Tcs.Task);
}
}
Expand Down Expand Up @@ -276,15 +274,17 @@ private void ReplenishInternal(uint nowTicks)

_queueCount -= nextPendingRequest.Count;
_tokenCount -= nextPendingRequest.Count;
Debug.Assert(_queueCount >= 0);
Debug.Assert(_tokenCount >= 0);

if (!nextPendingRequest.Tcs.TrySetResult(SuccessfulLease))
{
// Queued item was canceled so add count back
_tokenCount += nextPendingRequest.Count;
// Updating queue count is handled by the cancellation code
_queueCount += nextPendingRequest.Count;
}
nextPendingRequest.CancellationTokenRegistration.Dispose();
Debug.Assert(_queueCount >= 0);
}
else
{
Expand Down Expand Up @@ -380,7 +380,34 @@ public RequestRegistration(int tokenCount, TaskCompletionSource<RateLimitLease>
public TaskCompletionSource<RateLimitLease> Tcs { get; }

public CancellationTokenRegistration CancellationTokenRegistration { get; }
}

private sealed class CancelQueueState : TaskCompletionSource<RateLimitLease>
{
private readonly int _tokenCount;
private readonly TokenBucketRateLimiter _limiter;
private readonly CancellationToken _cancellationToken;

public CancelQueueState(int tokenCount, TokenBucketRateLimiter limiter, CancellationToken cancellationToken)
: base(TaskCreationOptions.RunContinuationsAsynchronously)
{
_tokenCount = tokenCount;
_limiter = limiter;
_cancellationToken = cancellationToken;
}

public new bool TrySetCanceled()
{
if (TrySetCanceled(_cancellationToken))
{
lock (_limiter.Lock)
{
_limiter._queueCount -= _tokenCount;
}
return true;
}
return false;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ public abstract class BaseRateLimiterTests
[Fact]
public abstract Task CanCancelWaitAsyncBeforeQueuing();

[Fact]
public abstract Task CancelUpdatesQueueLimit();

[Fact]
public abstract Task CanAcquireResourcesWithAcquireWithQueuedItemsIfNewestFirst();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,8 @@ public override async Task CanCancelWaitAsyncAfterQueuing()
var wait = limiter.WaitAsync(1, cts.Token);

cts.Cancel();
await Assert.ThrowsAsync<OperationCanceledException>(() => wait.AsTask());
var ex = await Assert.ThrowsAsync<TaskCanceledException>(() => wait.AsTask());
Assert.Equal(cts.Token, ex.CancellationToken);

lease.Dispose();

Expand All @@ -418,13 +419,36 @@ public override async Task CanCancelWaitAsyncBeforeQueuing()
var cts = new CancellationTokenSource();
cts.Cancel();

await Assert.ThrowsAsync<TaskCanceledException>(() => limiter.WaitAsync(1, cts.Token).AsTask());
var ex = await Assert.ThrowsAsync<TaskCanceledException>(() => limiter.WaitAsync(1, cts.Token).AsTask());
Assert.Equal(cts.Token, ex.CancellationToken);

lease.Dispose();

Assert.Equal(1, limiter.GetAvailablePermits());
}

[Fact]
public override async Task CancelUpdatesQueueLimit()
{
var limiter = new ConcurrencyLimiter(new ConcurrencyLimiterOptions(1, QueueProcessingOrder.OldestFirst, 1));
var lease = limiter.Acquire(1);
Assert.True(lease.IsAcquired);

var cts = new CancellationTokenSource();
var wait = limiter.WaitAsync(1, cts.Token);

cts.Cancel();
var ex = await Assert.ThrowsAsync<TaskCanceledException>(() => wait.AsTask());
Assert.Equal(cts.Token, ex.CancellationToken);

wait = limiter.WaitAsync(1);
Assert.False(wait.IsCompleted);

lease.Dispose();
lease = await wait;
Assert.True(lease.IsAcquired);
}

[Fact]
public override void NoMetadataOnAcquiredLease()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,8 @@ public override async Task CanCancelWaitAsyncAfterQueuing()
var wait = limiter.WaitAsync(1, cts.Token);

cts.Cancel();
await Assert.ThrowsAsync<OperationCanceledException>(() => wait.AsTask());
var ex = await Assert.ThrowsAsync<TaskCanceledException>(() => wait.AsTask());
Assert.Equal(cts.Token, ex.CancellationToken);

lease.Dispose();
Assert.True(limiter.TryReplenish());
Expand All @@ -373,14 +374,38 @@ public override async Task CanCancelWaitAsyncBeforeQueuing()
var cts = new CancellationTokenSource();
cts.Cancel();

await Assert.ThrowsAsync<TaskCanceledException>(() => limiter.WaitAsync(1, cts.Token).AsTask());
var ex = await Assert.ThrowsAsync<TaskCanceledException>(() => limiter.WaitAsync(1, cts.Token).AsTask());
Assert.Equal(cts.Token, ex.CancellationToken);

lease.Dispose();
Assert.True(limiter.TryReplenish());

Assert.Equal(1, limiter.GetAvailablePermits());
}

[Fact]
public override async Task CancelUpdatesQueueLimit()
{
var limiter = new TokenBucketRateLimiter(new TokenBucketRateLimiterOptions(1, QueueProcessingOrder.OldestFirst, 1,
TimeSpan.Zero, 1, autoReplenishment: false));
var lease = limiter.Acquire(1);
Assert.True(lease.IsAcquired);

var cts = new CancellationTokenSource();
var wait = limiter.WaitAsync(1, cts.Token);

cts.Cancel();
var ex = await Assert.ThrowsAsync<TaskCanceledException>(() => wait.AsTask());
Assert.Equal(cts.Token, ex.CancellationToken);

wait = limiter.WaitAsync(1);
Assert.False(wait.IsCompleted);

limiter.TryReplenish();
lease = await wait;
Assert.True(lease.IsAcquired);
}

[Fact]
public override void NoMetadataOnAcquiredLease()
{
Expand Down