diff --git a/src/Microsoft.VisualStudio.Threading/TplExtensions.cs b/src/Microsoft.VisualStudio.Threading/TplExtensions.cs index d929ff1ea..7fe2a473d 100644 --- a/src/Microsoft.VisualStudio.Threading/TplExtensions.cs +++ b/src/Microsoft.VisualStudio.Threading/TplExtensions.cs @@ -255,6 +255,38 @@ public static NoThrowValueTaskAwaitable NoThrowAwaitable(this ValueTask task, bo return new NoThrowValueTaskAwaitable(task, captureContext); } + /// + /// Returns an awaitable for the specified task that will never throw, even if the source task + /// faults or is canceled. + /// + /// + /// The awaitable returned by this method does not provide access to the result of a successfully-completed + /// . To await without throwing and use the resulting value, the following + /// pattern may be used: + /// + /// + /// var methodValueTask = MethodAsync().Preserve(); + /// await methodValueTask.NoThrowAwaitable(true); + /// if (methodValueTask.IsCompletedSuccessfully) + /// { + /// var result = methodValueTask.Result; + /// } + /// else + /// { + /// var exception = methodValueTask.AsTask().Exception.InnerException; + /// } + /// + /// + /// The task whose completion should signal the completion of the returned awaitable. + /// if set to the continuation will be scheduled on the caller's context; to always execute the continuation on the threadpool. + /// An awaitable. + /// The type of the result. + [SuppressMessage("ApiDesign", "RS0026:Do not add multiple public overloads with optional parameters", Justification = "The receiver type is disjoint.")] + public static NoThrowValueTaskAwaitable NoThrowAwaitable(this ValueTask task, bool captureContext = true) + { + return new NoThrowValueTaskAwaitable(task, captureContext); + } + /// /// Consumes a task and doesn't do anything with it. Useful for fire-and-forget calls to async methods within async methods. /// @@ -876,6 +908,106 @@ public void GetResult() } } + /// + /// An awaitable that wraps a and never throws an exception when waited on. + /// + /// The type of the result. + public readonly struct NoThrowValueTaskAwaitable + { + /// + /// The task. + /// + private readonly ValueTask task; + + /// + /// A value indicating whether the continuation should be scheduled on the current sync context. + /// + private readonly bool captureContext; + + /// + /// Initializes a new instance of the struct. + /// + /// The task. + /// Whether the continuation should be scheduled on the current sync context. + public NoThrowValueTaskAwaitable(ValueTask task, bool captureContext) + { + this.task = task.Preserve(); + this.captureContext = captureContext; + } + + /// + /// Gets the awaiter. + /// + /// The awaiter. + public NoThrowValueTaskAwaiter GetAwaiter() + { + return new NoThrowValueTaskAwaiter(this.task, this.captureContext); + } + } + + /// + /// An awaiter that wraps a task and never throws an exception when waited on. + /// + /// The type of the result. + public readonly struct NoThrowValueTaskAwaiter : ICriticalNotifyCompletion + { + /// + /// The task. + /// + private readonly ValueTask task; + + /// + /// A value indicating whether the continuation should be scheduled on the current sync context. + /// + private readonly bool captureContext; + + /// + /// Initializes a new instance of the struct. + /// + /// The task. + /// if set to [capture context]. + public NoThrowValueTaskAwaiter(ValueTask task, bool captureContext) + { + this.task = task; + this.captureContext = captureContext; + } + + /// + /// Gets a value indicating whether the task has completed. + /// + public bool IsCompleted + { + get { return this.task.IsCompleted; } + } + + /// + /// Schedules a delegate for execution at the conclusion of a task's execution. + /// + /// The action. + public void OnCompleted(Action continuation) + { + this.task.ConfigureAwait(this.captureContext).GetAwaiter().OnCompleted(continuation); + } + + /// + /// Schedules a delegate for execution at the conclusion of a task's execution + /// without capturing the ExecutionContext. + /// + /// The action. + public void UnsafeOnCompleted(Action continuation) + { + this.task.ConfigureAwait(this.captureContext).GetAwaiter().UnsafeOnCompleted(continuation); + } + + /// + /// Does nothing. + /// + public void GetResult() + { + // No need to do anything with 'task' because we already called Preserve on it. + } + } + /// /// A state bag for the method. /// diff --git a/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt b/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt index 753755055..4b6e53cb5 100644 --- a/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt +++ b/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt @@ -2,13 +2,23 @@ Microsoft.VisualStudio.Threading.JoinableTaskFactory.MainThreadAwaitable.NoThrow Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable.GetAwaiter() -> Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable.NoThrowValueTaskAwaitable(System.Threading.Tasks.ValueTask task, bool captureContext) -> void +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable.GetAwaiter() -> Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable.NoThrowValueTaskAwaitable(System.Threading.Tasks.ValueTask task, bool captureContext) -> void Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.GetResult() -> void Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.IsCompleted.get -> bool Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.NoThrowValueTaskAwaiter(System.Threading.Tasks.ValueTask task, bool captureContext) -> void Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.OnCompleted(System.Action! continuation) -> void Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.UnsafeOnCompleted(System.Action! continuation) -> void +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.GetResult() -> void +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.IsCompleted.get -> bool +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.NoThrowValueTaskAwaiter(System.Threading.Tasks.ValueTask task, bool captureContext) -> void +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.OnCompleted(System.Action! continuation) -> void +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.UnsafeOnCompleted(System.Action! continuation) -> void static Microsoft.VisualStudio.Threading.TplExtensions.NoThrowAwaitable(this System.Threading.Tasks.ValueTask task, bool captureContext = true) -> Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable +static Microsoft.VisualStudio.Threading.TplExtensions.NoThrowAwaitable(this System.Threading.Tasks.ValueTask task, bool captureContext = true) -> Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable virtual Microsoft.VisualStudio.Threading.AsyncReaderWriterResourceLock.GetTaskSchedulerToPrepareResourcesForConcurrentAccess(TResource! resource) -> System.Threading.Tasks.TaskScheduler! Microsoft.VisualStudio.Threading.JoinableTaskContext.Capture() -> string? Microsoft.VisualStudio.Threading.JoinableTaskFactory.RunAsync(System.Func! asyncMethod, string? parentToken, Microsoft.VisualStudio.Threading.JoinableTaskCreationOptions creationOptions) -> Microsoft.VisualStudio.Threading.JoinableTask! diff --git a/src/Microsoft.VisualStudio.Threading/net6.0-windows/PublicAPI.Unshipped.txt b/src/Microsoft.VisualStudio.Threading/net6.0-windows/PublicAPI.Unshipped.txt index 753755055..4b6e53cb5 100644 --- a/src/Microsoft.VisualStudio.Threading/net6.0-windows/PublicAPI.Unshipped.txt +++ b/src/Microsoft.VisualStudio.Threading/net6.0-windows/PublicAPI.Unshipped.txt @@ -2,13 +2,23 @@ Microsoft.VisualStudio.Threading.JoinableTaskFactory.MainThreadAwaitable.NoThrow Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable.GetAwaiter() -> Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable.NoThrowValueTaskAwaitable(System.Threading.Tasks.ValueTask task, bool captureContext) -> void +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable.GetAwaiter() -> Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable.NoThrowValueTaskAwaitable(System.Threading.Tasks.ValueTask task, bool captureContext) -> void Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.GetResult() -> void Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.IsCompleted.get -> bool Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.NoThrowValueTaskAwaiter(System.Threading.Tasks.ValueTask task, bool captureContext) -> void Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.OnCompleted(System.Action! continuation) -> void Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.UnsafeOnCompleted(System.Action! continuation) -> void +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.GetResult() -> void +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.IsCompleted.get -> bool +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.NoThrowValueTaskAwaiter(System.Threading.Tasks.ValueTask task, bool captureContext) -> void +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.OnCompleted(System.Action! continuation) -> void +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.UnsafeOnCompleted(System.Action! continuation) -> void static Microsoft.VisualStudio.Threading.TplExtensions.NoThrowAwaitable(this System.Threading.Tasks.ValueTask task, bool captureContext = true) -> Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable +static Microsoft.VisualStudio.Threading.TplExtensions.NoThrowAwaitable(this System.Threading.Tasks.ValueTask task, bool captureContext = true) -> Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable virtual Microsoft.VisualStudio.Threading.AsyncReaderWriterResourceLock.GetTaskSchedulerToPrepareResourcesForConcurrentAccess(TResource! resource) -> System.Threading.Tasks.TaskScheduler! Microsoft.VisualStudio.Threading.JoinableTaskContext.Capture() -> string? Microsoft.VisualStudio.Threading.JoinableTaskFactory.RunAsync(System.Func! asyncMethod, string? parentToken, Microsoft.VisualStudio.Threading.JoinableTaskCreationOptions creationOptions) -> Microsoft.VisualStudio.Threading.JoinableTask! diff --git a/src/Microsoft.VisualStudio.Threading/net6.0/PublicAPI.Unshipped.txt b/src/Microsoft.VisualStudio.Threading/net6.0/PublicAPI.Unshipped.txt index 753755055..4b6e53cb5 100644 --- a/src/Microsoft.VisualStudio.Threading/net6.0/PublicAPI.Unshipped.txt +++ b/src/Microsoft.VisualStudio.Threading/net6.0/PublicAPI.Unshipped.txt @@ -2,13 +2,23 @@ Microsoft.VisualStudio.Threading.JoinableTaskFactory.MainThreadAwaitable.NoThrow Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable.GetAwaiter() -> Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable.NoThrowValueTaskAwaitable(System.Threading.Tasks.ValueTask task, bool captureContext) -> void +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable.GetAwaiter() -> Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable.NoThrowValueTaskAwaitable(System.Threading.Tasks.ValueTask task, bool captureContext) -> void Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.GetResult() -> void Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.IsCompleted.get -> bool Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.NoThrowValueTaskAwaiter(System.Threading.Tasks.ValueTask task, bool captureContext) -> void Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.OnCompleted(System.Action! continuation) -> void Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.UnsafeOnCompleted(System.Action! continuation) -> void +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.GetResult() -> void +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.IsCompleted.get -> bool +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.NoThrowValueTaskAwaiter(System.Threading.Tasks.ValueTask task, bool captureContext) -> void +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.OnCompleted(System.Action! continuation) -> void +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.UnsafeOnCompleted(System.Action! continuation) -> void static Microsoft.VisualStudio.Threading.TplExtensions.NoThrowAwaitable(this System.Threading.Tasks.ValueTask task, bool captureContext = true) -> Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable +static Microsoft.VisualStudio.Threading.TplExtensions.NoThrowAwaitable(this System.Threading.Tasks.ValueTask task, bool captureContext = true) -> Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable virtual Microsoft.VisualStudio.Threading.AsyncReaderWriterResourceLock.GetTaskSchedulerToPrepareResourcesForConcurrentAccess(TResource! resource) -> System.Threading.Tasks.TaskScheduler! Microsoft.VisualStudio.Threading.JoinableTaskContext.Capture() -> string? Microsoft.VisualStudio.Threading.JoinableTaskFactory.RunAsync(System.Func! asyncMethod, string? parentToken, Microsoft.VisualStudio.Threading.JoinableTaskCreationOptions creationOptions) -> Microsoft.VisualStudio.Threading.JoinableTask! diff --git a/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt b/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt index 753755055..4b6e53cb5 100644 --- a/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt +++ b/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt @@ -2,13 +2,23 @@ Microsoft.VisualStudio.Threading.JoinableTaskFactory.MainThreadAwaitable.NoThrow Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable.GetAwaiter() -> Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable.NoThrowValueTaskAwaitable(System.Threading.Tasks.ValueTask task, bool captureContext) -> void +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable.GetAwaiter() -> Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable.NoThrowValueTaskAwaitable(System.Threading.Tasks.ValueTask task, bool captureContext) -> void Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.GetResult() -> void Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.IsCompleted.get -> bool Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.NoThrowValueTaskAwaiter(System.Threading.Tasks.ValueTask task, bool captureContext) -> void Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.OnCompleted(System.Action! continuation) -> void Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.UnsafeOnCompleted(System.Action! continuation) -> void +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.GetResult() -> void +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.IsCompleted.get -> bool +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.NoThrowValueTaskAwaiter(System.Threading.Tasks.ValueTask task, bool captureContext) -> void +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.OnCompleted(System.Action! continuation) -> void +Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaiter.UnsafeOnCompleted(System.Action! continuation) -> void static Microsoft.VisualStudio.Threading.TplExtensions.NoThrowAwaitable(this System.Threading.Tasks.ValueTask task, bool captureContext = true) -> Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable +static Microsoft.VisualStudio.Threading.TplExtensions.NoThrowAwaitable(this System.Threading.Tasks.ValueTask task, bool captureContext = true) -> Microsoft.VisualStudio.Threading.TplExtensions.NoThrowValueTaskAwaitable virtual Microsoft.VisualStudio.Threading.AsyncReaderWriterResourceLock.GetTaskSchedulerToPrepareResourcesForConcurrentAccess(TResource! resource) -> System.Threading.Tasks.TaskScheduler! Microsoft.VisualStudio.Threading.JoinableTaskContext.Capture() -> string? Microsoft.VisualStudio.Threading.JoinableTaskFactory.RunAsync(System.Func! asyncMethod, string? parentToken, Microsoft.VisualStudio.Threading.JoinableTaskCreationOptions creationOptions) -> Microsoft.VisualStudio.Threading.JoinableTask! diff --git a/test/Microsoft.VisualStudio.Threading.Tests/TplExtensionsTests.cs b/test/Microsoft.VisualStudio.Threading.Tests/TplExtensionsTests.cs index 5adb16b60..3dd1c13d8 100644 --- a/test/Microsoft.VisualStudio.Threading.Tests/TplExtensionsTests.cs +++ b/test/Microsoft.VisualStudio.Threading.Tests/TplExtensionsTests.cs @@ -529,6 +529,153 @@ public async Task NoThrowAwaitable_ValueTask_UnsafeOnCompleted_DoesNotCaptureExe await testResultTcs.Task.WithTimeout(UnexpectedTimeout); } + [Fact] + public async Task NoThrowAwaitable_ValueTaskT_Succeeds() + { + var barrier = new TaskCompletionSource(); + var result = new object(); + var tcs = new TaskCompletionSource(); + var test = Task.Run(async () => + { + ValueTask awaitable = MethodAsync(barrier, result).Preserve(); + await awaitable.NoThrowAwaitable(); + Assert.True(awaitable.IsCompletedSuccessfully); + Assert.Same(result, awaitable.Result); + }); + + barrier.SetResult(null); + await test; + + static async ValueTask MethodAsync(TaskCompletionSource barrier, object result) + { + await barrier.Task; + return result; + } + } + + [Fact] + public async Task NoThrowAwaitable_ValueTaskT_Fails() + { + var barrier = new TaskCompletionSource(); + var result = new InvalidOperationException(); + var tcs = new TaskCompletionSource(); + var test = Task.Run(async () => + { + ValueTask awaitable = MethodAsync(barrier, result).Preserve(); + await awaitable.NoThrowAwaitable(); + Assert.True(awaitable.IsFaulted); + Assert.Same(result, awaitable.AsTask().Exception!.InnerException); + }); + + barrier.SetResult(null); + await test; + + static async ValueTask MethodAsync(TaskCompletionSource barrier, Exception result) + { + await barrier.Task; + throw result; + } + } + + [Fact] + public async Task NoThrowAwaitable_ValueTaskT() + { + var tcs = new TaskCompletionSource(); + TplExtensions.NoThrowValueTaskAwaitable nothrowTask = new ValueTask(tcs.Task).NoThrowAwaitable(); + Assert.False(nothrowTask.GetAwaiter().IsCompleted); + tcs.SetException(new InvalidOperationException()); + await nothrowTask; + + tcs = new TaskCompletionSource(); + nothrowTask = new ValueTask(tcs.Task).NoThrowAwaitable(); + Assert.False(nothrowTask.GetAwaiter().IsCompleted); + tcs.SetCanceled(); + await nothrowTask; + } + + /// + /// Verifies that independent of whether the or + /// is captured and used to schedule the continuation, the is always captured and applied. + /// + [Theory] + [CombinatorialData] + public async Task NoThrowAwaitable_ValueTaskT_Await_CapturesExecutionContext(bool captureContext) + { + var awaitableTcs = new TaskCompletionSource(); + var asyncLocal = new System.Threading.AsyncLocal(); + asyncLocal.Value = "expected"; + var testResult = Task.Run(async delegate + { + await new ValueTask(awaitableTcs.Task).NoThrowAwaitable(captureContext); // uses UnsafeOnCompleted + Assert.Equal("expected", asyncLocal.Value); + }); + asyncLocal.Value = null; + await Task.Delay(AsyncDelay); // Make sure the delegate above has time to yield + awaitableTcs.SetResult(null); + + await testResult.WithTimeout(UnexpectedTimeout); + } + + /// + /// Verifies that independent of whether the or + /// is captured and used to schedule the continuation, the is always captured and applied. + /// + [Theory] + [CombinatorialData] + public async Task NoThrowAwaitable_ValueTaskT_OnCompleted_CapturesExecutionContext(bool captureContext) + { + var testResultTcs = new TaskCompletionSource(); + var awaitableTcs = new TaskCompletionSource(); + var asyncLocal = new System.Threading.AsyncLocal(); + asyncLocal.Value = "expected"; + TplExtensions.NoThrowValueTaskAwaiter awaiter = new ValueTask(awaitableTcs.Task).NoThrowAwaitable(captureContext).GetAwaiter(); + awaiter.OnCompleted(delegate + { + try + { + Assert.Equal("expected", asyncLocal.Value); + testResultTcs.SetResult(null); + } + catch (Exception ex) + { + testResultTcs.SetException(ex); + } + }); + asyncLocal.Value = null; + await Task.Yield(); + awaitableTcs.SetResult(null); + + await testResultTcs.Task.WithTimeout(UnexpectedTimeout); + } + + [Theory] + [CombinatorialData] + public async Task NoThrowAwaitable_ValueTaskT_UnsafeOnCompleted_DoesNotCaptureExecutionContext(bool captureContext) + { + var testResultTcs = new TaskCompletionSource(); + var awaitableTcs = new TaskCompletionSource(); + var asyncLocal = new System.Threading.AsyncLocal(); + asyncLocal.Value = "expected"; + TplExtensions.NoThrowValueTaskAwaiter awaiter = new ValueTask(awaitableTcs.Task).NoThrowAwaitable(captureContext).GetAwaiter(); + awaiter.UnsafeOnCompleted(delegate + { + try + { + Assert.Null(asyncLocal.Value); + testResultTcs.SetResult(null); + } + catch (Exception ex) + { + testResultTcs.SetException(ex); + } + }); + asyncLocal.Value = null; + await Task.Yield(); + awaitableTcs.SetResult(null); + + await testResultTcs.Task.WithTimeout(UnexpectedTimeout); + } + [Fact] public void InvokeAsyncNullEverything() {