From ff93f2dc3619d28b87404d1f86439f4b1d6e8d02 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Sat, 20 Jan 2024 06:30:34 -0500 Subject: [PATCH] Add TaskCompletionSource.SetFromTask (#97077) --- .../src/System/Net/TaskExtensions.cs | 15 +-- .../src/Resources/Strings.resx | 5 +- .../Threading/Tasks/TaskCompletionSource.cs | 70 +++++++++++ .../Threading/Tasks/TaskCompletionSource_T.cs | 71 ++++++++++++ .../System.Runtime/ref/System.Runtime.cs | 4 + .../Task/TaskCompletionSourceTResultTests.cs | 109 ++++++++++++++++++ .../Task/TaskCompletionSourceTests.cs | 107 +++++++++++++++++ 7 files changed, 366 insertions(+), 15 deletions(-) diff --git a/src/libraries/System.Net.Requests/src/System/Net/TaskExtensions.cs b/src/libraries/System.Net.Requests/src/System/Net/TaskExtensions.cs index 7732cf149317b..bb0d98c40e6e1 100644 --- a/src/libraries/System.Net.Requests/src/System/Net/TaskExtensions.cs +++ b/src/libraries/System.Net.Requests/src/System/Net/TaskExtensions.cs @@ -18,20 +18,7 @@ public static TaskCompletionSource ToApm( task.ContinueWith(completedTask => { - bool shouldInvokeCallback = false; - - if (completedTask.IsFaulted) - { - shouldInvokeCallback = tcs.TrySetException(completedTask.Exception!.InnerExceptions); - } - else if (completedTask.IsCanceled) - { - shouldInvokeCallback = tcs.TrySetCanceled(); - } - else - { - shouldInvokeCallback = tcs.TrySetResult(completedTask.Result); - } + bool shouldInvokeCallback = tcs.TrySetFromTask(completedTask); // Only invoke the callback if it exists AND we were able to transition the TCS // to the terminal state. If we couldn't transition the task it is because it was diff --git a/src/libraries/System.Private.CoreLib/src/Resources/Strings.resx b/src/libraries/System.Private.CoreLib/src/Resources/Strings.resx index ac57b0760021b..d2bc611dbbbee 100644 --- a/src/libraries/System.Private.CoreLib/src/Resources/Strings.resx +++ b/src/libraries/System.Private.CoreLib/src/Resources/Strings.resx @@ -3527,6 +3527,9 @@ The tasks array included at least one null element. + + The provided task must have already completed. + Task<TResult>.ConfigureAwait does not support ConfigureAwaitOptions.SuppressThrowing. To suppress throwing, instead cast the Task<TResult> to its base class Task and await that with SuppressThrowing. @@ -4286,4 +4289,4 @@ This operation is not available because the reflection support was disabled at compile time. - \ No newline at end of file + diff --git a/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/TaskCompletionSource.cs b/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/TaskCompletionSource.cs index ee096f6816559..f7ec3ea6efa65 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/TaskCompletionSource.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/TaskCompletionSource.cs @@ -285,5 +285,75 @@ public bool TrySetCanceled(CancellationToken cancellationToken) return rval; } + + /// + /// Transition the underlying into the same completion state as the specified . + /// + /// The completed task whose completion status (including exception or cancellation information) should be copied to the underlying task. + /// is . + /// is not completed. + /// + /// The underlying is already in one of the three final states: + /// , , or . + /// + /// + /// This operation will return false if the is already in one of the three final states: + /// , , or . + /// + public void SetFromTask(Task completedTask) + { + if (!TrySetFromTask(completedTask)) + { + ThrowHelper.ThrowInvalidOperationException(ExceptionResource.TaskT_TransitionToFinal_AlreadyCompleted); + } + } + + /// + /// Attempts to transition the underlying into the same completion state as the specified . + /// + /// The completed task whose completion status (including exception or cancellation information) should be copied to the underlying task. + /// if the operation was successful; otherwise, . + /// is . + /// is not completed. + /// + /// This operation will return false if the is already in one of the three final states: + /// , , or . + /// + public bool TrySetFromTask(Task completedTask) + { + ArgumentNullException.ThrowIfNull(completedTask); + if (!completedTask.IsCompleted) + { + throw new ArgumentException(SR.Task_MustBeCompleted, nameof(completedTask)); + } + + // Try to transition to the appropriate final state based on the state of completedTask. + bool result = false; + switch (completedTask.Status) + { + case TaskStatus.RanToCompletion: + result = _task.TrySetResult(); + break; + + case TaskStatus.Canceled: + result = _task.TrySetCanceled(completedTask.CancellationToken, completedTask.GetCancellationExceptionDispatchInfo()); + break; + + case TaskStatus.Faulted: + result = _task.TrySetException(completedTask.GetExceptionDispatchInfos()); + break; + } + + // If we successfully transitioned to a final state, we're done. If we didn't, it's possible a concurrent operation + // is still in the process of completing the task, and callers of this method expect the task to already be fully + // completed when this method returns. As such, we spin until the task is completed, and then return whether this + // call successfully did the transition. + if (!result && !_task.IsCompleted) + { + _task.SpinUntilCompleted(); + } + + return result; + } } } diff --git a/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/TaskCompletionSource_T.cs b/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/TaskCompletionSource_T.cs index 977db4b94b25d..2a7f76029d119 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/TaskCompletionSource_T.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/TaskCompletionSource_T.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.Runtime.ExceptionServices; namespace System.Threading.Tasks { @@ -286,5 +287,75 @@ public bool TrySetCanceled(CancellationToken cancellationToken) return rval; } + + /// + /// Transition the underlying into the same completion state as the specified . + /// + /// The completed task whose completion status (including result, exception, or cancellation information) should be copied to the underlying task. + /// is . + /// is not completed. + /// + /// The underlying is already in one of the three final states: + /// , , or . + /// + /// + /// This operation will return false if the is already in one of the three final states: + /// , , or . + /// + public void SetFromTask(Task completedTask) + { + if (!TrySetFromTask(completedTask)) + { + ThrowHelper.ThrowInvalidOperationException(ExceptionResource.TaskT_TransitionToFinal_AlreadyCompleted); + } + } + + /// + /// Attempts to transition the underlying into the same completion state as the specified . + /// + /// The completed task whose completion status (including result, exception, or cancellation information) should be copied to the underlying task. + /// if the operation was successful; otherwise, . + /// is . + /// is not completed. + /// + /// This operation will return false if the is already in one of the three final states: + /// , , or . + /// + public bool TrySetFromTask(Task completedTask) + { + ArgumentNullException.ThrowIfNull(completedTask); + if (!completedTask.IsCompleted) + { + throw new ArgumentException(SR.Task_MustBeCompleted, nameof(completedTask)); + } + + // Try to transition to the appropriate final state based on the state of completedTask. + bool result = false; + switch (completedTask.Status) + { + case TaskStatus.RanToCompletion: + result = _task.TrySetResult(completedTask.Result); + break; + + case TaskStatus.Canceled: + result = _task.TrySetCanceled(completedTask.CancellationToken, completedTask.GetCancellationExceptionDispatchInfo()); + break; + + case TaskStatus.Faulted: + result = _task.TrySetException(completedTask.GetExceptionDispatchInfos()); + break; + } + + // If we successfully transitioned to a final state, we're done. If we didn't, it's possible a concurrent operation + // is still in the process of completing the task, and callers of this method expect the task to already be fully + // completed when this method returns. As such, we spin until the task is completed, and then return whether this + // call successfully did the transition. + if (!result && !_task.IsCompleted) + { + _task.SpinUntilCompleted(); + } + + return result; + } } } diff --git a/src/libraries/System.Runtime/ref/System.Runtime.cs b/src/libraries/System.Runtime/ref/System.Runtime.cs index d6442299edfcd..0349694ed7024 100644 --- a/src/libraries/System.Runtime/ref/System.Runtime.cs +++ b/src/libraries/System.Runtime/ref/System.Runtime.cs @@ -15352,6 +15352,7 @@ public TaskCompletionSource(System.Threading.Tasks.TaskCreationOptions creationO public System.Threading.Tasks.Task Task { get { throw null; } } public void SetCanceled() { } public void SetCanceled(System.Threading.CancellationToken cancellationToken) { } + public void SetFromTask(System.Threading.Tasks.Task completedTask) { throw null; } public void SetException(System.Collections.Generic.IEnumerable exceptions) { } public void SetException(System.Exception exception) { } public void SetResult() { } @@ -15359,6 +15360,7 @@ public void SetResult() { } public bool TrySetCanceled(System.Threading.CancellationToken cancellationToken) { throw null; } public bool TrySetException(System.Collections.Generic.IEnumerable exceptions) { throw null; } public bool TrySetException(System.Exception exception) { throw null; } + public bool TrySetFromTask(System.Threading.Tasks.Task completedTask) { throw null; } public bool TrySetResult() { throw null; } } public partial class TaskCompletionSource @@ -15370,11 +15372,13 @@ public TaskCompletionSource(System.Threading.Tasks.TaskCreationOptions creationO public System.Threading.Tasks.Task Task { get { throw null; } } public void SetCanceled() { } public void SetCanceled(System.Threading.CancellationToken cancellationToken) { } + public void SetFromTask(System.Threading.Tasks.Task completedTask) { throw null; } public void SetException(System.Collections.Generic.IEnumerable exceptions) { } public void SetException(System.Exception exception) { } public void SetResult(TResult result) { } public bool TrySetCanceled() { throw null; } public bool TrySetCanceled(System.Threading.CancellationToken cancellationToken) { throw null; } + public bool TrySetFromTask(System.Threading.Tasks.Task completedTask) { throw null; } public bool TrySetException(System.Collections.Generic.IEnumerable exceptions) { throw null; } public bool TrySetException(System.Exception exception) { throw null; } public bool TrySetResult(TResult result) { throw null; } diff --git a/src/libraries/System.Runtime/tests/System.Threading.Tasks.Tests/Task/TaskCompletionSourceTResultTests.cs b/src/libraries/System.Runtime/tests/System.Threading.Tasks.Tests/Task/TaskCompletionSourceTResultTests.cs index 9a08e529abf93..9bc44b683530e 100644 --- a/src/libraries/System.Runtime/tests/System.Threading.Tasks.Tests/Task/TaskCompletionSourceTResultTests.cs +++ b/src/libraries/System.Runtime/tests/System.Threading.Tasks.Tests/Task/TaskCompletionSourceTResultTests.cs @@ -202,5 +202,114 @@ private static void AssertCompletedTcsFailsToCompleteAgain(TaskCompletionSour Assert.False(tcs.TrySetCanceled()); Assert.False(tcs.TrySetCanceled(default)); } + + [Fact] + public void SetFromTask_InvalidArgument_Throws() + { + TaskCompletionSource tcs = new(); + AssertExtensions.Throws("completedTask", () => tcs.SetFromTask(null)); + AssertExtensions.Throws("completedTask", () => tcs.SetFromTask(new TaskCompletionSource().Task)); + Assert.False(tcs.Task.IsCompleted); + + tcs.SetResult(null); + Assert.True(tcs.Task.IsCompletedSuccessfully); + + AssertExtensions.Throws("completedTask", () => tcs.SetFromTask(null)); + AssertExtensions.Throws("completedTask", () => tcs.SetFromTask(new TaskCompletionSource().Task)); + Assert.True(tcs.Task.IsCompletedSuccessfully); + } + + [Fact] + public void SetFromTask_AlreadyCompleted_ReturnsFalseOrThrows() + { + object result = new(); + TaskCompletionSource tcs = new(); + tcs.SetResult(result); + + Assert.False(tcs.TrySetFromTask(Task.FromResult(new object()))); + Assert.False(tcs.TrySetFromTask(Task.FromException(new Exception()))); + Assert.False(tcs.TrySetFromTask(Task.FromCanceled(new CancellationToken(canceled: true)))); + + Assert.Throws(() => tcs.SetFromTask(Task.FromResult(new object()))); + Assert.Throws(() => tcs.SetFromTask(Task.FromException(new Exception()))); + Assert.Throws(() => tcs.SetFromTask(Task.FromCanceled(new CancellationToken(canceled: true)))); + + Assert.True(tcs.Task.IsCompletedSuccessfully); + Assert.Same(result, tcs.Task.Result); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public void SetFromTask_CompletedSuccessfully(bool tryMethod) + { + TaskCompletionSource tcs = new(); + Task source = Task.FromResult(new object()); + + if (tryMethod) + { + Assert.True(tcs.TrySetFromTask(source)); + } + else + { + tcs.SetFromTask(source); + } + + Assert.Same(source.Result, tcs.Task.Result); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public void SetFromTask_Faulted(bool tryMethod) + { + TaskCompletionSource tcs = new(); + + var source = new TaskCompletionSource(); + source.SetException([new FormatException(), new DivideByZeroException()]); + + if (tryMethod) + { + Assert.True(tcs.TrySetFromTask(source.Task)); + } + else + { + tcs.SetFromTask(source.Task); + } + + Assert.True(tcs.Task.IsFaulted); + Assert.True(tcs.Task.Exception.InnerExceptions.Count == 2); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public void SetFromTask_Canceled(bool tryMethod) + { + TaskCompletionSource tcs = new(); + + var cts = new CancellationTokenSource(); + cts.Cancel(); + Task source = Task.FromCanceled(cts.Token); + + if (tryMethod) + { + Assert.True(tcs.TrySetFromTask(source)); + } + else + { + tcs.SetFromTask(source); + } + + Assert.True(tcs.Task.IsCanceled); + try + { + tcs.Task.GetAwaiter().GetResult(); + } + catch (OperationCanceledException oce) + { + Assert.Equal(cts.Token, oce.CancellationToken); + } + } } } diff --git a/src/libraries/System.Runtime/tests/System.Threading.Tasks.Tests/Task/TaskCompletionSourceTests.cs b/src/libraries/System.Runtime/tests/System.Threading.Tasks.Tests/Task/TaskCompletionSourceTests.cs index d97a473763964..2cb19b41486f2 100644 --- a/src/libraries/System.Runtime/tests/System.Threading.Tasks.Tests/Task/TaskCompletionSourceTests.cs +++ b/src/libraries/System.Runtime/tests/System.Threading.Tasks.Tests/Task/TaskCompletionSourceTests.cs @@ -200,5 +200,112 @@ private static void AssertCompletedTcsFailsToCompleteAgain(TaskCompletionSource Assert.False(tcs.TrySetCanceled()); Assert.False(tcs.TrySetCanceled(default)); } + + [Fact] + public void SetFromTask_InvalidArgument_Throws() + { + TaskCompletionSource tcs = new(); + AssertExtensions.Throws("completedTask", () => tcs.SetFromTask(null)); + AssertExtensions.Throws("completedTask", () => tcs.SetFromTask(new TaskCompletionSource().Task)); + Assert.False(tcs.Task.IsCompleted); + + tcs.SetResult(); + Assert.True(tcs.Task.IsCompletedSuccessfully); + + AssertExtensions.Throws("completedTask", () => tcs.SetFromTask(null)); + AssertExtensions.Throws("completedTask", () => tcs.SetFromTask(new TaskCompletionSource().Task)); + Assert.True(tcs.Task.IsCompletedSuccessfully); + } + + [Fact] + public void SetFromTask_AlreadyCompleted_ReturnsFalseOrThrows() + { + TaskCompletionSource tcs = new(); + tcs.SetResult(); + + Assert.False(tcs.TrySetFromTask(Task.CompletedTask)); + Assert.False(tcs.TrySetFromTask(Task.FromException(new Exception()))); + Assert.False(tcs.TrySetFromTask(Task.FromCanceled(new CancellationToken(canceled: true)))); + + Assert.Throws(() => tcs.SetFromTask(Task.CompletedTask)); + Assert.Throws(() => tcs.SetFromTask(Task.FromException(new Exception()))); + Assert.Throws(() => tcs.SetFromTask(Task.FromCanceled(new CancellationToken(canceled: true)))); + + Assert.True(tcs.Task.IsCompletedSuccessfully); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public void SetFromTask_CompletedSuccessfully(bool tryMethod) + { + TaskCompletionSource tcs = new(); + Task source = Task.CompletedTask; + + if (tryMethod) + { + Assert.True(tcs.TrySetFromTask(source)); + } + else + { + tcs.SetFromTask(source); + } + + Assert.True(tcs.Task.IsCompletedSuccessfully); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public void SetFromTask_Faulted(bool tryMethod) + { + TaskCompletionSource tcs = new(); + + var source = new TaskCompletionSource(); + source.SetException([new FormatException(), new DivideByZeroException()]); + + if (tryMethod) + { + Assert.True(tcs.TrySetFromTask(source.Task)); + } + else + { + tcs.SetFromTask(source.Task); + } + + Assert.True(tcs.Task.IsFaulted); + Assert.True(tcs.Task.Exception.InnerExceptions.Count == 2); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public void SetFromTask_Canceled(bool tryMethod) + { + TaskCompletionSource tcs = new(); + + var cts = new CancellationTokenSource(); + cts.Cancel(); + Task source = Task.FromCanceled(cts.Token); + + if (tryMethod) + { + Assert.True(tcs.TrySetFromTask(source)); + } + else + { + tcs.SetFromTask(source); + } + + Assert.True(tcs.Task.IsCanceled); + try + { + tcs.Task.GetAwaiter().GetResult(); + } + catch (OperationCanceledException oce) + { + Assert.Equal(cts.Token, oce.CancellationToken); + } + } } }