Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Task.WaitAsync methods #48842

Merged
merged 1 commit into from
Mar 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -558,9 +558,12 @@ protected async Task WhenAllOrAnyFailed(Task task1, Task task2)
}
else
{
var cts = new CancellationTokenSource();
await Task.WhenAny(incomplete, Task.Delay(500, cts.Token)); // give second task a chance to complete
cts.Cancel();
try
{
await incomplete.WaitAsync(TimeSpan.FromMilliseconds(500)); // give second task a chance to complete
}
catch (TimeoutException) { }

await (incomplete.IsCompleted ? Task.WhenAll(completed, incomplete) : completed);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public Task CreateClientAndServerAsync(Func<Uri, Task> clientFunc, Func<GenericL
Task serverTask = serverFunc(server);

await new Task[] { clientTask, serverTask }.WhenAllOrAnyFailed().ConfigureAwait(false);
}, options: options).TimeoutAfter(millisecondsTimeout);
}, options: options).WaitAsync(TimeSpan.FromMilliseconds(millisecondsTimeout));
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,7 @@ public async Task ExpectSettingsAckAsync(int timeoutMs = 5000)
Task currentTask = _ignoredSettingsAckPromise?.Task;
if (currentTask != null)
{
var timeout = TimeSpan.FromMilliseconds(timeoutMs);
await currentTask.TimeoutAfter(timeout);
await currentTask.WaitAsync(TimeSpan.FromMilliseconds(timeoutMs));
}

_ignoredSettingsAckPromise = new TaskCompletionSource<bool>();
Expand Down Expand Up @@ -909,7 +908,7 @@ public override async Task WaitForCancellationAsync(bool ignoreIncomingData = tr
Frame frame;
do
{
frame = await ReadFrameAsync(TimeSpan.FromMilliseconds(TestHelper.PassingTestTimeoutMilliseconds));
frame = await ReadFrameAsync(TestHelper.PassingTestTimeout);
Assert.NotNull(frame); // We should get Rst before closing connection.
Assert.Equal(0, (int)(frame.Flags & FrameFlags.EndStream));
if (ignoreIncomingData)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ public static async Task CreateServerAsync(Func<Http2LoopbackServer, Uri, Task>
{
using (var server = Http2LoopbackServer.CreateServer())
{
await funcAsync(server, server.Address).TimeoutAfter(millisecondsTimeout).ConfigureAwait(false);
await funcAsync(server, server.Address).WaitAsync(TimeSpan.FromMilliseconds(millisecondsTimeout));
}
}

Expand Down Expand Up @@ -223,7 +223,7 @@ public override async Task CreateServerAsync(Func<GenericLoopbackServer, Uri, Ta
{
using (var server = CreateServer(options))
{
await funcAsync(server, server.Address).TimeoutAfter(millisecondsTimeout).ConfigureAwait(false);
await funcAsync(server, server.Address).WaitAsync(TimeSpan.FromMilliseconds(millisecondsTimeout));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ public override GenericLoopbackServer CreateServer(GenericLoopbackOptions option
public override async Task CreateServerAsync(Func<GenericLoopbackServer, Uri, Task> funcAsync, int millisecondsTimeout = 60000, GenericLoopbackOptions options = null)
{
using GenericLoopbackServer server = CreateServer(options);
await funcAsync(server, server.Address).TimeoutAfter(millisecondsTimeout).ConfigureAwait(false);
await funcAsync(server, server.Address).WaitAsync(TimeSpan.FromMilliseconds(millisecondsTimeout));
}

public override Task<GenericLoopbackConnection> CreateConnectionAsync(Socket socket, Stream stream, GenericLoopbackOptions options = null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ public static async Task CreateServerAsync(Func<HttpAgnosticLoopbackServer, Uri,
{
using (var server = HttpAgnosticLoopbackServer.CreateServer())
{
await funcAsync(server, server.Address).TimeoutAfter(millisecondsTimeout).ConfigureAwait(false);
await funcAsync(server, server.Address).WaitAsync(TimeSpan.FromMilliseconds(millisecondsTimeout));
}
}

Expand Down Expand Up @@ -240,7 +240,7 @@ public override async Task CreateServerAsync(Func<GenericLoopbackServer, Uri, Ta
{
using (var server = CreateServer(options))
{
await funcAsync(server, server.Address).TimeoutAfter(millisecondsTimeout).ConfigureAwait(false);
await funcAsync(server, server.Address).WaitAsync(TimeSpan.FromMilliseconds(millisecondsTimeout));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ await LoopbackServer.CreateServerAsync(async (proxyServer, proxyUrl) =>

// Send Digest challenge.
Task<List<string>> serverTask = proxyServer.AcceptConnectionSendResponseAndCloseAsync(HttpStatusCode.ProxyAuthenticationRequired, authHeader);
if (clientTask == await Task.WhenAny(clientTask, serverTask).TimeoutAfter(TestHelper.PassingTestTimeoutMilliseconds))
if (clientTask == await Task.WhenAny(clientTask, serverTask).WaitAsync(TestHelper.PassingTestTimeout))
{
// Client task shouldn't have completed successfully; propagate failure.
Assert.NotEqual(TaskStatus.RanToCompletion, clientTask.Status);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1513,7 +1513,7 @@ await server.AcceptConnectionAsync(async connection =>
await connection.ReadRequestDataAsync(readBody: true);
}
catch { } // Eat errors from client disconnect.
await clientFinished.Task.TimeoutAfter(TimeSpan.FromMinutes(2));
await clientFinished.Task.WaitAsync(TimeSpan.FromMinutes(2));
});
});
}
Expand Down
2 changes: 2 additions & 0 deletions src/libraries/Common/tests/System/Net/Http/TestHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ namespace System.Net.Http.Functional.Tests
{
public static class TestHelper
{
public static TimeSpan PassingTestTimeout => TimeSpan.FromMilliseconds(PassingTestTimeoutMilliseconds);
public static int PassingTestTimeoutMilliseconds => 60 * 1000;
stephentoub marked this conversation as resolved.
Show resolved Hide resolved

public static bool JsonMessageContainsKeyValue(string message, string key, string value)
{
// Deal with JSON encoding of '\' and '"' in value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,91 +3,50 @@

using System.Collections.Generic;
using System.Diagnostics;
using System.Runtime.CompilerServices;

/// <summary>
/// Task timeout helper based on https://devblogs.microsoft.com/pfxteam/crafting-a-task-timeoutafter-method/
/// </summary>
namespace System.Threading.Tasks
{
public static class TaskTimeoutExtensions
{
public static async Task WithCancellation(this Task task, CancellationToken cancellationToken)
{
var tcs = new TaskCompletionSource<bool>();
using (cancellationToken.Register(s => ((TaskCompletionSource<bool>)s).TrySetResult(true), tcs))
{
if (task != await Task.WhenAny(task, tcs.Task).ConfigureAwait(false))
{
throw new OperationCanceledException(cancellationToken);
}
await task; // already completed; propagate any exception
}
}

public static Task TimeoutAfter(this Task task, int millisecondsTimeout)
=> task.TimeoutAfter(TimeSpan.FromMilliseconds(millisecondsTimeout));

public static async Task TimeoutAfter(this Task task, TimeSpan timeout)
{
var cts = new CancellationTokenSource();
#region WaitAsync polyfills
// Test polyfills when targeting a platform that doesn't have these ConfigureAwait overloads on Task

if (task == await Task.WhenAny(task, Task.Delay(timeout, cts.Token)).ConfigureAwait(false))
{
cts.Cancel();
await task.ConfigureAwait(false);
}
else
{
throw new TimeoutException($"Task timed out after {timeout}");
}
}
public static Task WaitAsync(this Task task, TimeSpan timeout) =>
WaitAsync(task, timeout, default);

public static Task<TResult> TimeoutAfter<TResult>(this Task<TResult> task, int millisecondsTimeout)
=> task.TimeoutAfter(TimeSpan.FromMilliseconds(millisecondsTimeout));
public static Task WaitAsync(this Task task, CancellationToken cancellationToken) =>
WaitAsync(task, Timeout.InfiniteTimeSpan, cancellationToken);

public static async Task<TResult> TimeoutAfter<TResult>(this Task<TResult> task, TimeSpan timeout)
public async static Task WaitAsync(this Task task, TimeSpan timeout, CancellationToken cancellationToken)
{
var cts = new CancellationTokenSource();

if (task == await Task<TResult>.WhenAny(task, Task<TResult>.Delay(timeout, cts.Token)).ConfigureAwait(false))
{
cts.Cancel();
return await task.ConfigureAwait(false);
}
else
var tcs = new TaskCompletionSource<bool>();
using (new Timer(s => ((TaskCompletionSource<bool>)s).TrySetException(new TimeoutException()), tcs, timeout, Timeout.InfiniteTimeSpan))
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
using (cancellationToken.Register(s => ((TaskCompletionSource<bool>)s).TrySetCanceled(), tcs))
{
throw new TimeoutException($"Task timed out after {timeout}");
await(await Task.WhenAny(task, tcs.Task).ConfigureAwait(false)).ConfigureAwait(false);
}
}

#if !NETFRAMEWORK
public static Task TimeoutAfter(this ValueTask task, int millisecondsTimeout)
=> task.AsTask().TimeoutAfter(TimeSpan.FromMilliseconds(millisecondsTimeout));

public static Task TimeoutAfter(this ValueTask task, TimeSpan timeout)
=> task.AsTask().TimeoutAfter(timeout);

public static Task<TResult> TimeoutAfter<TResult>(this ValueTask<TResult> task, int millisecondsTimeout)
=> task.AsTask().TimeoutAfter(TimeSpan.FromMilliseconds(millisecondsTimeout));
public static Task<TResult> WaitAsync<TResult>(this Task<TResult> task, TimeSpan timeout) =>
WaitAsync(task, timeout, default);

public static Task<TResult> TimeoutAfter<TResult>(this ValueTask<TResult> task, TimeSpan timeout)
=> task.AsTask().TimeoutAfter(timeout);
#endif
public static Task<TResult> WaitAsync<TResult>(this Task<TResult> task, CancellationToken cancellationToken) =>
WaitAsync(task, Timeout.InfiniteTimeSpan, cancellationToken);

public static async Task WhenAllOrAnyFailed(this Task[] tasks, int millisecondsTimeout)
public static async Task<TResult> WaitAsync<TResult>(this Task<TResult> task, TimeSpan timeout, CancellationToken cancellationToken)
{
var cts = new CancellationTokenSource();
Task task = tasks.WhenAllOrAnyFailed();
if (task == await Task.WhenAny(task, Task.Delay(millisecondsTimeout, cts.Token)).ConfigureAwait(false))
var tcs = new TaskCompletionSource<TResult>();
using (new Timer(s => ((TaskCompletionSource<TResult>)s).TrySetException(new TimeoutException()), tcs, timeout, Timeout.InfiniteTimeSpan))
using (cancellationToken.Register(s => ((TaskCompletionSource<TResult>)s).TrySetCanceled(), tcs))
{
cts.Cancel();
await task.ConfigureAwait(false);
}
else
{
throw new TimeoutException($"{nameof(WhenAllOrAnyFailed)} timed out after {millisecondsTimeout}ms");
return await (await Task.WhenAny(task, tcs.Task).ConfigureAwait(false)).ConfigureAwait(false);
}
}
#endregion

public static async Task WhenAllOrAnyFailed(this Task[] tasks, int millisecondsTimeout) =>
await tasks.WhenAllOrAnyFailed().WaitAsync(TimeSpan.FromMilliseconds(millisecondsTimeout));

public static async Task WhenAllOrAnyFailed(this Task[] tasks)
{
Expand All @@ -99,12 +58,11 @@ public static async Task WhenAllOrAnyFailed(this Task[] tasks)
{
// Wait a bit to allow other tasks to complete so we can include their exceptions
// in the error we throw.
using (var cts = new CancellationTokenSource())
try
{
await Task.WhenAny(
Task.WhenAll(tasks),
Task.Delay(3_000, cts.Token)).ConfigureAwait(false); // arbitrary delay; can be dialed up or down in the future
await Task.WhenAll(tasks).WaitAsync(TimeSpan.FromSeconds(3)); // arbitrary delay; can be dialed up or down in the future
ManickaP marked this conversation as resolved.
Show resolved Hide resolved
}
catch { }
ManickaP marked this conversation as resolved.
Show resolved Hide resolved

var exceptions = new List<Exception>();
foreach (Task t in tasks)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1286,7 +1286,7 @@ void Fail(object state)

fileSystemWatcher.CallOnRenamed(new RenamedEventArgs(WatcherChangeTypes.Renamed, root.RootPath, newDirectoryName, oldDirectoryName));

await Task.WhenAll(oldDirectoryTcs.Task, newDirectoryTcs.Task, newSubDirectoryTcs.Task, newFileTcs.Task).TimeoutAfter(TimeSpan.FromSeconds(30));
await Task.WhenAll(oldDirectoryTcs.Task, newDirectoryTcs.Task, newSubDirectoryTcs.Task, newFileTcs.Task).WaitAsync(TimeSpan.FromSeconds(30));

Assert.False(oldSubDirectoryToken.HasChanged, "Old subdirectory token should not have changed");
Assert.False(oldFileToken.HasChanged, "Old file token should not have changed");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ protected async Task<CancellationToken> StartSelfHostAsync()
// The timeout here is large, because we don't know how long the test could need
// We cover a lot of error cases above, but I want to make sure we eventually give up and don't hang the build
// just in case we missed one -anurse
await started.Task.TimeoutAfter(TimeSpan.FromMinutes(10));
await started.Task.WaitAsync(TimeSpan.FromMinutes(10));
}

return hostExitTokenSource.Token;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,11 @@ private async Task ExecuteShutdownTest(string testName, string shutdownMechanic)
}
};

await started.Task.TimeoutAfter(TimeSpan.FromSeconds(60));
await started.Task.WaitAsync(TimeSpan.FromSeconds(60));

SendShutdownSignal(deployer.HostProcess);

await completed.Task.TimeoutAfter(TimeSpan.FromSeconds(60));
await completed.Task.WaitAsync(TimeSpan.FromSeconds(60));

WaitForExitOrKill(deployer.HostProcess);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ private async Task DictionaryConcurrentAccessDetection<TKey, TValue>(Dictionary<
}, TaskCreationOptions.LongRunning);

// If Dictionary regresses, we do not want to hang here indefinitely
Assert.True((await Task.WhenAny(task, Task.Delay(TimeSpan.FromSeconds(60))) == task) && task.IsCompletedSuccessfully);
await task.WaitAsync(TimeSpan.FromSeconds(60));
}

[ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ public async Task RunWorkerAsync_NoOnWorkHandler_SetsResultToNull()

backgroundWorker.RunWorkerAsync();

await Task.WhenAny(tcs.Task, Task.Delay(TimeSpan.FromSeconds(10))); // Usually takes 100th of a sec
Assert.True(tcs.Task.IsCompleted);
await tcs.Task.WaitAsync(TimeSpan.FromSeconds(10)); // Usually takes 100th of a sec
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
}

#region TestCancelAsync
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,6 @@
Link="System\PasteArguments.cs" />
<Compile Include="$(CommonPath)Interop\Windows\Interop.Errors.cs"
Link="Common\Interop\Windows\Interop.Errors.cs" />
<Compile Include="$(CommonPath)System\Threading\Tasks\TaskCompletionSourceWithCancellation.cs"
Link="Common\System\Threading\Tasks\TaskCompletionSourceWithCancellation.cs" />
<Compile Include="$(CommonPath)System\Threading\Tasks\TaskTimeoutExtensions.cs"
Link="Common\System\Threading\Tasks\TaskTimeoutExtensions.cs" />
<Compile Include="$(CommonPath)System\Text\ValueStringBuilder.cs"
Link="Common\System\Text\ValueStringBuilder.cs" />
</ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,25 +251,7 @@ private bool FlushMessageQueue(bool rethrowInNewThread)
}
}

// Wait until we hit EOF. This is called from Process.WaitForExit
// We will lose some information if we don't do this.
internal void WaitUntilEOF()
{
if (_readToBufferTask is Task task)
{
task.GetAwaiter().GetResult();
}
}

internal Task WaitUntilEOFAsync(CancellationToken cancellationToken)
{
if (_readToBufferTask is Task task)
{
return task.WithCancellation(cancellationToken);
}

return Task.CompletedTask;
}
internal Task EOF => _readToBufferTask ?? Task.CompletedTask;

public void Dispose()
{
Expand Down
Loading