Skip to content

Commit

Permalink
Added WaitAny support for cancellation tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
sakno committed Jun 24, 2024
1 parent 97ed908 commit 14cb0cd
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 0 deletions.
39 changes: 39 additions & 0 deletions src/DotNext.Tests/Threading/AsyncBridgeTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,43 @@ public static async Task CancellationTokenAwaitCornerCases()
await new CancellationToken(true).WaitAsync();
await ThrowsAsync<ArgumentException>(new CancellationToken(false).WaitAsync().AsTask);
}

[Fact]
public static async Task WaitForCancellationSingleToken()
{
using var cts = new CancellationTokenSource();
var task = AsyncBridge.WaitAnyAsync([cts.Token]);
False(task.IsCompletedSuccessfully);

cts.Cancel();
Equal(cts.Token, await task);
}

[Fact]
public static async Task WaitForCancellationTwoTokens()
{
using var cts1 = new CancellationTokenSource();
using var cts2 = new CancellationTokenSource();
var task = AsyncBridge.WaitAnyAsync([cts1.Token, cts2.Token]);
False(task.IsCompletedSuccessfully);

cts2.Cancel();
cts1.Cancel();
Equal(cts2.Token, await task);
}

[Fact]
public static async Task WaitForCancellationMultipleTokens()
{
using var cts1 = new CancellationTokenSource();
using var cts2 = new CancellationTokenSource();
using var cts3 = new CancellationTokenSource();
var task = AsyncBridge.WaitAnyAsync([cts1.Token, cts2.Token, cts3.Token]);
False(task.IsCompletedSuccessfully);

cts3.Cancel();
cts2.Cancel();
cts1.Cancel();
Equal(cts3.Token, await task);
}
}
123 changes: 123 additions & 0 deletions src/DotNext.Threading/Threading/AsyncBridge.CancellationToken.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System.Buffers;
using System.Collections.Concurrent;
using System.Runtime.InteropServices;
using Debug = System.Diagnostics.Debug;
using Unsafe = System.Runtime.CompilerServices.Unsafe;

Expand Down Expand Up @@ -38,6 +40,127 @@ internal void Return(CancellationTokenValueTask vt)
Add(vt);
}
}

private abstract class CancellationTokenCompletionSource : TaskCompletionSource<CancellationToken>
{
protected static readonly Action<object?, CancellationToken> Callback = OnCanceled;

private bool initialized; // volatile

protected CancellationTokenCompletionSource(out InitializationFlag flag)
: base(TaskCreationOptions.RunContinuationsAsynchronously)
=> flag = new(ref initialized);

private static void OnCanceled(object? source, CancellationToken token)
{
Debug.Assert(source is CancellationTokenCompletionSource);

Unsafe.As<CancellationTokenCompletionSource>(source).OnCanceled(token);
}

private void OnCanceled(CancellationToken token)
{
if (Volatile.Read(ref initialized) && TrySetResult(token))
{
Cleanup();
}
}

private static void Unregister(ReadOnlySpan<CancellationTokenRegistration> registrations)
{
foreach (ref readonly var registration in registrations)
{
registration.Unregister();
}
}

private protected virtual void Cleanup() => Unregister(Registrations);

private protected abstract ReadOnlySpan<CancellationTokenRegistration> Registrations { get; }

[StructLayout(LayoutKind.Auto)]
protected readonly ref struct InitializationFlag
{
private readonly ref bool flag;

internal InitializationFlag(ref bool flag) => this.flag = ref flag;

internal CancellationToken InitializationCompleted(ReadOnlySpan<CancellationTokenRegistration> registrations)
{
Volatile.Write(ref flag, true);

foreach (ref readonly var registration in registrations)
{
if (registration.Token.IsCancellationRequested)
{
return registration.Token;
}
}

return new(canceled: false);
}
}
}

private sealed class CancellationTokenCompletionSource1 : CancellationTokenCompletionSource
{
private readonly CancellationTokenRegistration registration;

internal CancellationTokenCompletionSource1(CancellationToken token)
: base(out var flag)
{
registration = token.UnsafeRegister(Callback, this);

if (flag.InitializationCompleted(new(in registration)) is { IsCancellationRequested: true } canceledToken)
Callback(this, canceledToken);
}

private protected override ReadOnlySpan<CancellationTokenRegistration> Registrations => new(in registration);
}

private sealed class CancellationTokenCompletionSource2 : CancellationTokenCompletionSource
{
private readonly (CancellationTokenRegistration, CancellationTokenRegistration) registrations;

internal CancellationTokenCompletionSource2(CancellationToken token1, CancellationToken token2)
: base(out var flag)
{
registrations.Item1 = token1.UnsafeRegister(Callback, this);
registrations.Item2 = token2.UnsafeRegister(Callback, this);

if (flag.InitializationCompleted(registrations.AsReadOnlySpan()) is { IsCancellationRequested: true } canceledToken)
Callback(this, canceledToken);
}

private protected override ReadOnlySpan<CancellationTokenRegistration> Registrations => registrations.AsReadOnlySpan();
}

private sealed class CancellationTokenCompletionSourceN : CancellationTokenCompletionSource
{
private readonly CancellationTokenRegistration[] registrations;

internal CancellationTokenCompletionSourceN(ReadOnlySpan<CancellationToken> tokens)
: base(out var flag)
{
registrations = ArrayPool<CancellationTokenRegistration>.Shared.Rent(tokens.Length);

for (var i = 0; i < tokens.Length; i++)
{
registrations[i] = tokens[i].UnsafeRegister(Callback, this);
}

if (flag.InitializationCompleted(registrations) is { IsCancellationRequested: true } canceledToken)
Callback(this, canceledToken);
}

private protected override ReadOnlySpan<CancellationTokenRegistration> Registrations => new(registrations);

private protected override void Cleanup()
{
ArrayPool<CancellationTokenRegistration>.Shared.Return(registrations, clearArray: true);
base.Cleanup();
}
}

private static readonly Action<CancellationTokenValueTask> CancellationTokenValueTaskCompletionCallback = new CancellationTokenValueTaskPool().Return;

Expand Down
29 changes: 29 additions & 0 deletions src/DotNext.Threading/Threading/AsyncBridge.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,35 @@ public static ValueTask WaitAsync(this CancellationToken token, bool completeAsC
return result.CreateTask(InfiniteTimeSpan, token);
}

/// <summary>
/// Creates a task that will complete when any of the supplied tokens have canceled.
/// </summary>
/// <param name="tokens">The tokens to wait on for cancellation.</param>
/// <returns>The canceled token.</returns>
/// <exception cref="InvalidOperationException"><paramref name="tokens"/> is empty.</exception>
public static Task<CancellationToken> WaitAnyAsync(this ReadOnlySpan<CancellationToken> tokens)
{
Task<CancellationToken> result;
try
{
CancellationTokenCompletionSource source = tokens switch
{
[] => throw new InvalidOperationException(),
[var token] => new CancellationTokenCompletionSource1(token),
[var token1, var token2] => new CancellationTokenCompletionSource2(token1, token2),
_ => new CancellationTokenCompletionSourceN(tokens),
};

result = source.Task;
}
catch (Exception e)
{
result = Task.FromException<CancellationToken>(e);
}

return result;
}

private static WaitHandleValueTask GetCompletionSource(WaitHandle handle, TimeSpan timeout)
{
WaitHandleValueTask? result;
Expand Down

0 comments on commit 14cb0cd

Please sign in to comment.