diff --git a/src/Common/src/CoreLib/System/Runtime/CompilerServices/AsyncValueTaskMethodBuilder.cs b/src/Common/src/CoreLib/System/Runtime/CompilerServices/AsyncValueTaskMethodBuilder.cs index 49cdaccb0e4b..558b46775cf5 100644 --- a/src/Common/src/CoreLib/System/Runtime/CompilerServices/AsyncValueTaskMethodBuilder.cs +++ b/src/Common/src/CoreLib/System/Runtime/CompilerServices/AsyncValueTaskMethodBuilder.cs @@ -8,6 +8,108 @@ namespace System.Runtime.CompilerServices { + /// Represents a builder for asynchronous methods that return a . + [StructLayout(LayoutKind.Auto)] + public struct AsyncValueTaskMethodBuilder + { + /// The to which most operations are delegated. + private AsyncTaskMethodBuilder _methodBuilder; // mutable struct; do not make it readonly + /// true if completed synchronously and successfully; otherwise, false. + private bool _haveResult; + /// true if the builder should be used for setting/getting the result; otherwise, false. + private bool _useBuilder; + + /// Creates an instance of the struct. + /// The initialized instance. + public static AsyncValueTaskMethodBuilder Create() => +#if CORERT + // corert's AsyncTaskMethodBuilder.Create() currently does additional debugger-related + // work, so we need to delegate to it. + new AsyncValueTaskMethodBuilder() { _methodBuilder = AsyncTaskMethodBuilder.Create() }; +#else + // _methodBuilder should be initialized to AsyncTaskMethodBuilder.Create(), but on coreclr + // that Create() is a nop, so we can just return the default here. + default; +#endif + + /// Begins running the builder with the associated state machine. + /// The type of the state machine. + /// The state machine instance, passed by reference. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Start(ref TStateMachine stateMachine) where TStateMachine : IAsyncStateMachine => + // will provide the right ExecutionContext semantics +#if netstandard + _methodBuilder.Start(ref stateMachine); +#else + AsyncMethodBuilderCore.Start(ref stateMachine); +#endif + + /// Associates the builder with the specified state machine. + /// The state machine instance to associate with the builder. + public void SetStateMachine(IAsyncStateMachine stateMachine) => _methodBuilder.SetStateMachine(stateMachine); + + /// Marks the task as successfully completed. + public void SetResult() + { + if (_useBuilder) + { + _methodBuilder.SetResult(); + } + else + { + _haveResult = true; + } + } + + /// Marks the task as failed and binds the specified exception to the task. + /// The exception to bind to the task. + public void SetException(Exception exception) => _methodBuilder.SetException(exception); + + /// Gets the task for this builder. + public ValueTask Task + { + get + { + if (_haveResult) + { + return default; + } + else + { + _useBuilder = true; + return new ValueTask(_methodBuilder.Task); + } + } + } + + /// Schedules the state machine to proceed to the next action when the specified awaiter completes. + /// The type of the awaiter. + /// The type of the state machine. + /// the awaiter + /// The state machine. + public void AwaitOnCompleted(ref TAwaiter awaiter, ref TStateMachine stateMachine) + where TAwaiter : INotifyCompletion + where TStateMachine : IAsyncStateMachine + { + _useBuilder = true; + _methodBuilder.AwaitOnCompleted(ref awaiter, ref stateMachine); + } + + /// Schedules the state machine to proceed to the next action when the specified awaiter completes. + /// The type of the awaiter. + /// The type of the state machine. + /// the awaiter + /// The state machine. + [SecuritySafeCritical] + public void AwaitUnsafeOnCompleted(ref TAwaiter awaiter, ref TStateMachine stateMachine) + where TAwaiter : ICriticalNotifyCompletion + where TStateMachine : IAsyncStateMachine + { + _useBuilder = true; + _methodBuilder.AwaitUnsafeOnCompleted(ref awaiter, ref stateMachine); + } + } + /// Represents a builder for asynchronous methods that returns a . /// The type of the result. [StructLayout(LayoutKind.Auto)] @@ -32,7 +134,7 @@ public static AsyncValueTaskMethodBuilder Create() => #else // _methodBuilder should be initialized to AsyncTaskMethodBuilder.Create(), but on coreclr // that Create() is a nop, so we can just return the default here. - default(AsyncValueTaskMethodBuilder); + default; #endif /// Begins running the builder with the associated state machine. diff --git a/src/Common/src/CoreLib/System/Runtime/CompilerServices/ConfiguredValueTaskAwaitable.cs b/src/Common/src/CoreLib/System/Runtime/CompilerServices/ConfiguredValueTaskAwaitable.cs index f22b9d94bf65..0d7e3d11192c 100644 --- a/src/Common/src/CoreLib/System/Runtime/CompilerServices/ConfiguredValueTaskAwaitable.cs +++ b/src/Common/src/CoreLib/System/Runtime/CompilerServices/ConfiguredValueTaskAwaitable.cs @@ -5,9 +5,113 @@ using System.Diagnostics; using System.Runtime.InteropServices; using System.Threading.Tasks; +#if !netstandard +using Internal.Runtime.CompilerServices; +#endif namespace System.Runtime.CompilerServices { + /// Provides an awaitable type that enables configured awaits on a . + [StructLayout(LayoutKind.Auto)] + public readonly struct ConfiguredValueTaskAwaitable + { + /// The wrapped . + private readonly ValueTask _value; + + /// Initializes the awaitable. + /// The wrapped . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal ConfiguredValueTaskAwaitable(ValueTask value) => _value = value; + + /// Returns an awaiter for this instance. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public ConfiguredValueTaskAwaiter GetAwaiter() => new ConfiguredValueTaskAwaiter(_value); + + /// Provides an awaiter for a . + [StructLayout(LayoutKind.Auto)] + public readonly struct ConfiguredValueTaskAwaiter : ICriticalNotifyCompletion +#if CORECLR + , IValueTaskAwaiter +#endif + { + /// The value being awaited. + private readonly ValueTask _value; + + /// Initializes the awaiter. + /// The value to be awaited. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal ConfiguredValueTaskAwaiter(ValueTask value) => _value = value; + + /// Gets whether the has completed. + public bool IsCompleted + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => _value.IsCompleted; + } + + /// Gets the result of the ValueTask. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [StackTraceHidden] + public void GetResult() => _value.ThrowIfCompletedUnsuccessfully(); + + /// Schedules the continuation action for the . + public void OnCompleted(Action continuation) + { + if (_value.ObjectIsTask) + { + _value.UnsafeTask.ConfigureAwait(_value.ContinueOnCapturedContext).GetAwaiter().OnCompleted(continuation); + } + else if (_value._obj != null) + { + _value.UnsafeValueTaskSource.OnCompleted(ValueTaskAwaiter.s_invokeActionDelegate, continuation, _value._token, + ValueTaskSourceOnCompletedFlags.FlowExecutionContext | + (_value.ContinueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None)); + } + else + { + Task.CompletedTask.ConfigureAwait(_value.ContinueOnCapturedContext).GetAwaiter().OnCompleted(continuation); + } + } + + /// Schedules the continuation action for the . + public void UnsafeOnCompleted(Action continuation) + { + if (_value.ObjectIsTask) + { + _value.UnsafeTask.ConfigureAwait(_value.ContinueOnCapturedContext).GetAwaiter().UnsafeOnCompleted(continuation); + } + else if (_value._obj != null) + { + _value.UnsafeValueTaskSource.OnCompleted(ValueTaskAwaiter.s_invokeActionDelegate, continuation, _value._token, + _value.ContinueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None); + } + else + { + Task.CompletedTask.ConfigureAwait(_value.ContinueOnCapturedContext).GetAwaiter().UnsafeOnCompleted(continuation); + } + } + +#if CORECLR + void IValueTaskAwaiter.AwaitUnsafeOnCompleted(IAsyncStateMachineBox box) + { + if (_value.ObjectIsTask) + { + TaskAwaiter.UnsafeOnCompletedInternal(_value.UnsafeTask, box, _value.ContinueOnCapturedContext); + } + else if (_value._obj != null) + { + _value.UnsafeValueTaskSource.OnCompleted(ValueTaskAwaiter.s_invokeAsyncStateMachineBox, box, _value._token, + _value.ContinueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None); + } + else + { + TaskAwaiter.UnsafeOnCompletedInternal(Task.CompletedTask, box, _value.ContinueOnCapturedContext); + } + } +#endif + } + } + /// Provides an awaitable type that enables configured awaits on a . /// The type of the result produced. [StructLayout(LayoutKind.Auto)] @@ -15,78 +119,98 @@ public readonly struct ConfiguredValueTaskAwaitable { /// The wrapped . private readonly ValueTask _value; - /// true to attempt to marshal the continuation back to the original context captured; otherwise, false. - private readonly bool _continueOnCapturedContext; /// Initializes the awaitable. /// The wrapped . - /// - /// true to attempt to marshal the continuation back to the original synchronization context captured; otherwise, false. - /// - internal ConfiguredValueTaskAwaitable(ValueTask value, bool continueOnCapturedContext) - { - _value = value; - _continueOnCapturedContext = continueOnCapturedContext; - } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal ConfiguredValueTaskAwaitable(ValueTask value) => _value = value; /// Returns an awaiter for this instance. - public ConfiguredValueTaskAwaiter GetAwaiter() => - new ConfiguredValueTaskAwaiter(_value, _continueOnCapturedContext); + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public ConfiguredValueTaskAwaiter GetAwaiter() => new ConfiguredValueTaskAwaiter(_value); /// Provides an awaiter for a . [StructLayout(LayoutKind.Auto)] - public readonly struct ConfiguredValueTaskAwaiter : ICriticalNotifyCompletion, IConfiguredValueTaskAwaiter + public readonly struct ConfiguredValueTaskAwaiter : ICriticalNotifyCompletion +#if CORECLR + , IValueTaskAwaiter +#endif { /// The value being awaited. private readonly ValueTask _value; - /// The value to pass to ConfigureAwait. - internal readonly bool _continueOnCapturedContext; /// Initializes the awaiter. /// The value to be awaited. - /// The value to pass to ConfigureAwait. - internal ConfiguredValueTaskAwaiter(ValueTask value, bool continueOnCapturedContext) - { - _value = value; - _continueOnCapturedContext = continueOnCapturedContext; - } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal ConfiguredValueTaskAwaiter(ValueTask value) => _value = value; /// Gets whether the has completed. - public bool IsCompleted => _value.IsCompleted; + public bool IsCompleted + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => _value.IsCompleted; + } /// Gets the result of the ValueTask. + [MethodImpl(MethodImplOptions.AggressiveInlining)] [StackTraceHidden] - public TResult GetResult() => - _value._task == null ? - _value._result : - _value._task.GetAwaiter().GetResult(); + public TResult GetResult() => _value.Result; /// Schedules the continuation action for the . - public void OnCompleted(Action continuation) => - _value.AsTask().ConfigureAwait(_continueOnCapturedContext).GetAwaiter().OnCompleted(continuation); + public void OnCompleted(Action continuation) + { + if (_value.ObjectIsTask) + { + _value.UnsafeTask.ConfigureAwait(_value.ContinueOnCapturedContext).GetAwaiter().OnCompleted(continuation); + } + else if (_value._obj != null) + { + _value.UnsafeValueTaskSource.OnCompleted(ValueTaskAwaiter.s_invokeActionDelegate, continuation, _value._token, + ValueTaskSourceOnCompletedFlags.FlowExecutionContext | + (_value.ContinueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None)); + } + else + { + Task.CompletedTask.ConfigureAwait(_value.ContinueOnCapturedContext).GetAwaiter().OnCompleted(continuation); + } + } /// Schedules the continuation action for the . - public void UnsafeOnCompleted(Action continuation) => - _value.AsTask().ConfigureAwait(_continueOnCapturedContext).GetAwaiter().UnsafeOnCompleted(continuation); - - /// Gets the task underlying . - internal Task AsTask() => _value.AsTask(); + public void UnsafeOnCompleted(Action continuation) + { + if (_value.ObjectIsTask) + { + _value.UnsafeTask.ConfigureAwait(_value.ContinueOnCapturedContext).GetAwaiter().UnsafeOnCompleted(continuation); + } + else if (_value._obj != null) + { + _value.UnsafeValueTaskSource.OnCompleted(ValueTaskAwaiter.s_invokeActionDelegate, continuation, _value._token, + _value.ContinueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None); + } + else + { + Task.CompletedTask.ConfigureAwait(_value.ContinueOnCapturedContext).GetAwaiter().UnsafeOnCompleted(continuation); + } + } - /// Gets the task underlying the incomplete . - /// This method is used when awaiting and IsCompleted returned false; thus we expect the value task to be wrapping a non-null task. - Task IConfiguredValueTaskAwaiter.GetTask(out bool continueOnCapturedContext) +#if CORECLR + void IValueTaskAwaiter.AwaitUnsafeOnCompleted(IAsyncStateMachineBox box) { - continueOnCapturedContext = _continueOnCapturedContext; - return _value.AsTaskExpectNonNull(); + if (_value.ObjectIsTask) + { + TaskAwaiter.UnsafeOnCompletedInternal(_value.UnsafeTask, box, _value.ContinueOnCapturedContext); + } + else if (_value._obj != null) + { + _value.UnsafeValueTaskSource.OnCompleted(ValueTaskAwaiter.s_invokeAsyncStateMachineBox, box, _value._token, + _value.ContinueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None); + } + else + { + TaskAwaiter.UnsafeOnCompletedInternal(Task.CompletedTask, box, _value.ContinueOnCapturedContext); + } } +#endif } } - - /// - /// Internal interface used to enable extract the Task from arbitrary configured ValueTask awaiters. - /// - internal interface IConfiguredValueTaskAwaiter - { - Task GetTask(out bool continueOnCapturedContext); - } } diff --git a/src/Common/src/CoreLib/System/Runtime/CompilerServices/ValueTaskAwaiter.cs b/src/Common/src/CoreLib/System/Runtime/CompilerServices/ValueTaskAwaiter.cs index 3f212d8bf9b8..4b3df947adaa 100644 --- a/src/Common/src/CoreLib/System/Runtime/CompilerServices/ValueTaskAwaiter.cs +++ b/src/Common/src/CoreLib/System/Runtime/CompilerServices/ValueTaskAwaiter.cs @@ -7,47 +7,194 @@ namespace System.Runtime.CompilerServices { + /// Provides an awaiter for a . + public readonly struct ValueTaskAwaiter : ICriticalNotifyCompletion +#if CORECLR + , IValueTaskAwaiter +#endif + { + /// Shim used to invoke an passed as the state argument to a . + internal static readonly Action s_invokeActionDelegate = state => + { + if (!(state is Action action)) + { + ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.state); + return; + } + + action(); + }; + /// The value being awaited. + private readonly ValueTask _value; + + /// Initializes the awaiter. + /// The value to be awaited. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal ValueTaskAwaiter(ValueTask value) => _value = value; + + /// Gets whether the has completed. + public bool IsCompleted + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => _value.IsCompleted; + } + + /// Gets the result of the ValueTask. + [StackTraceHidden] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void GetResult() => _value.ThrowIfCompletedUnsuccessfully(); + + /// Schedules the continuation action for this ValueTask. + public void OnCompleted(Action continuation) + { + if (_value.ObjectIsTask) + { + _value.UnsafeTask.GetAwaiter().OnCompleted(continuation); + } + else if (_value._obj != null) + { + _value.UnsafeValueTaskSource.OnCompleted(s_invokeActionDelegate, continuation, _value._token, ValueTaskSourceOnCompletedFlags.UseSchedulingContext | ValueTaskSourceOnCompletedFlags.FlowExecutionContext); + } + else + { + Task.CompletedTask.GetAwaiter().OnCompleted(continuation); + } + } + + /// Schedules the continuation action for this ValueTask. + public void UnsafeOnCompleted(Action continuation) + { + if (_value.ObjectIsTask) + { + _value.UnsafeTask.GetAwaiter().UnsafeOnCompleted(continuation); + } + else if (_value._obj != null) + { + _value.UnsafeValueTaskSource.OnCompleted(s_invokeActionDelegate, continuation, _value._token, ValueTaskSourceOnCompletedFlags.UseSchedulingContext); + } + else + { + Task.CompletedTask.GetAwaiter().UnsafeOnCompleted(continuation); + } + } + +#if CORECLR + void IValueTaskAwaiter.AwaitUnsafeOnCompleted(IAsyncStateMachineBox box) + { + if (_value.ObjectIsTask) + { + TaskAwaiter.UnsafeOnCompletedInternal(_value.UnsafeTask, box, continueOnCapturedContext: true); + } + else if (_value._obj != null) + { + _value.UnsafeValueTaskSource.OnCompleted(s_invokeAsyncStateMachineBox, box, _value._token, ValueTaskSourceOnCompletedFlags.UseSchedulingContext); + } + else + { + TaskAwaiter.UnsafeOnCompletedInternal(Task.CompletedTask, box, continueOnCapturedContext: true); + } + } + + /// Shim used to invoke of the supplied . + internal static readonly Action s_invokeAsyncStateMachineBox = state => + { + if (!(state is IAsyncStateMachineBox box)) + { + ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.state); + return; + } + + box.Invoke(null); + }; +#endif + } + /// Provides an awaiter for a . - public readonly struct ValueTaskAwaiter : ICriticalNotifyCompletion, IValueTaskAwaiter + public readonly struct ValueTaskAwaiter : ICriticalNotifyCompletion +#if CORECLR + , IValueTaskAwaiter +#endif { /// The value being awaited. private readonly ValueTask _value; /// Initializes the awaiter. /// The value to be awaited. + [MethodImpl(MethodImplOptions.AggressiveInlining)] internal ValueTaskAwaiter(ValueTask value) => _value = value; /// Gets whether the has completed. - public bool IsCompleted => _value.IsCompleted; + public bool IsCompleted + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => _value.IsCompleted; + } /// Gets the result of the ValueTask. [StackTraceHidden] - public TResult GetResult() => - _value._task == null ? - _value._result : - _value._task.GetAwaiter().GetResult(); + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public TResult GetResult() => _value.Result; /// Schedules the continuation action for this ValueTask. - public void OnCompleted(Action continuation) => - _value.AsTask().ConfigureAwait(continueOnCapturedContext: true).GetAwaiter().OnCompleted(continuation); + public void OnCompleted(Action continuation) + { + if (_value.ObjectIsTask) + { + _value.UnsafeTask.GetAwaiter().OnCompleted(continuation); + } + else if (_value._obj != null) + { + _value.UnsafeValueTaskSource.OnCompleted(ValueTaskAwaiter.s_invokeActionDelegate, continuation, _value._token, ValueTaskSourceOnCompletedFlags.UseSchedulingContext | ValueTaskSourceOnCompletedFlags.FlowExecutionContext); + } + else + { + Task.CompletedTask.GetAwaiter().OnCompleted(continuation); + } + } /// Schedules the continuation action for this ValueTask. - public void UnsafeOnCompleted(Action continuation) => - _value.AsTask().ConfigureAwait(continueOnCapturedContext: true).GetAwaiter().UnsafeOnCompleted(continuation); - - /// Gets the task underlying . - internal Task AsTask() => _value.AsTask(); + public void UnsafeOnCompleted(Action continuation) + { + if (_value.ObjectIsTask) + { + _value.UnsafeTask.GetAwaiter().UnsafeOnCompleted(continuation); + } + else if (_value._obj != null) + { + _value.UnsafeValueTaskSource.OnCompleted(ValueTaskAwaiter.s_invokeActionDelegate, continuation, _value._token, ValueTaskSourceOnCompletedFlags.UseSchedulingContext); + } + else + { + Task.CompletedTask.GetAwaiter().UnsafeOnCompleted(continuation); + } + } - /// Gets the task underlying the incomplete . - /// This method is used when awaiting and IsCompleted returned false; thus we expect the value task to be wrapping a non-null task. - Task IValueTaskAwaiter.GetTask() => _value.AsTaskExpectNonNull(); +#if CORECLR + void IValueTaskAwaiter.AwaitUnsafeOnCompleted(IAsyncStateMachineBox box) + { + if (_value.ObjectIsTask) + { + TaskAwaiter.UnsafeOnCompletedInternal(_value.UnsafeTask, box, continueOnCapturedContext: true); + } + else if (_value._obj != null) + { + _value.UnsafeValueTaskSource.OnCompleted(ValueTaskAwaiter.s_invokeAsyncStateMachineBox, box, _value._token, ValueTaskSourceOnCompletedFlags.UseSchedulingContext); + } + else + { + TaskAwaiter.UnsafeOnCompletedInternal(Task.CompletedTask, box, continueOnCapturedContext: true); + } + } +#endif } - /// - /// Internal interface used to enable extract the Task from arbitrary ValueTask awaiters. - /// > +#if CORECLR + /// Internal interface used to enable optimizations from on .> internal interface IValueTaskAwaiter { - Task GetTask(); + /// Invoked to set of the as the awaiter's continuation. + /// The box object. + void AwaitUnsafeOnCompleted(IAsyncStateMachineBox box); } +#endif } diff --git a/src/Common/src/CoreLib/System/Threading/Tasks/IValueTaskSource.cs b/src/Common/src/CoreLib/System/Threading/Tasks/IValueTaskSource.cs new file mode 100644 index 000000000000..7c7312ac06a5 --- /dev/null +++ b/src/Common/src/CoreLib/System/Threading/Tasks/IValueTaskSource.cs @@ -0,0 +1,83 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Threading.Tasks +{ + /// + /// Flags passed from and to + /// and + /// to control behavior. + /// + [Flags] + public enum ValueTaskSourceOnCompletedFlags + { + /// + /// No requirements are placed on how the continuation is invoked. + /// + None, + /// + /// Set if OnCompleted should capture the current scheduling context (e.g. SynchronizationContext) + /// and use it when queueing the continuation for execution. If this is not set, the implementation + /// may choose to execute the continuation in an arbitrary location. + /// + UseSchedulingContext = 0x1, + /// + /// Set if OnCompleted should capture the current and use it to + /// the continuation. + /// + FlowExecutionContext = 0x2, + } + + /// Indicates the status of an or . + public enum ValueTaskSourceStatus + { + /// The operation has not yet completed. + Pending = 0, + /// The operation completed successfully. + Succeeded = 1, + /// The operation completed with an error. + Faulted = 2, + /// The operation completed due to cancellation. + Canceled = 3 + } + + /// Represents an object that can be wrapped by a . + public interface IValueTaskSource + { + /// Gets the status of the current operation. + /// Opaque value that was provided to the 's constructor. + ValueTaskSourceStatus GetStatus(short token); + + /// Schedules the continuation action for this . + /// The continuation to invoke when the operation has completed. + /// The state object to pass to when it's invoked. + /// Opaque value that was provided to the 's constructor. + /// The flags describing the behavior of the continuation. + void OnCompleted(Action continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags); + + /// Gets the result of the . + /// Opaque value that was provided to the 's constructor. + void GetResult(short token); + } + + /// Represents an object that can be wrapped by a . + /// Specifies the type of data returned from the object. + public interface IValueTaskSource + { + /// Gets the status of the current operation. + /// Opaque value that was provided to the 's constructor. + ValueTaskSourceStatus GetStatus(short token); + + /// Schedules the continuation action for this . + /// The continuation to invoke when the operation has completed. + /// The state object to pass to when it's invoked. + /// Opaque value that was provided to the 's constructor. + /// The flags describing the behavior of the continuation. + void OnCompleted(Action continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags); + + /// Gets the result of the . + /// Opaque value that was provided to the 's constructor. + TResult GetResult(short token); + } +} diff --git a/src/Common/src/CoreLib/System/Threading/Tasks/ValueTask.cs b/src/Common/src/CoreLib/System/Threading/Tasks/ValueTask.cs index 5edd8501b0c6..b6689c1d82fe 100644 --- a/src/Common/src/CoreLib/System/Threading/Tasks/ValueTask.cs +++ b/src/Common/src/CoreLib/System/Threading/Tasks/ValueTask.cs @@ -3,71 +3,413 @@ // See the LICENSE file in the project root for more information. using System.Collections.Generic; +using System.Diagnostics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; +#if !netstandard +using Internal.Runtime.CompilerServices; +#endif namespace System.Threading.Tasks { - /// - /// Provides a value type that wraps a and a , - /// only one of which is used. - /// - /// The type of the result. + /// Provides an awaitable result of an asynchronous operation. /// - /// - /// Methods may return an instance of this value type when it's likely that the result of their - /// operations will be available synchronously and when the method is expected to be invoked so - /// frequently that the cost of allocating a new for each call will - /// be prohibitive. - /// - /// - /// There are tradeoffs to using a instead of a . - /// For example, while a can help avoid an allocation in the case where the - /// successful result is available synchronously, it also contains two fields whereas a - /// as a reference type is a single field. This means that a method call ends up returning two fields worth of - /// data instead of one, which is more data to copy. It also means that if a method that returns one of these - /// is awaited within an async method, the state machine for that async method will be larger due to needing - /// to store the struct that's two fields instead of a single reference. - /// - /// - /// Further, for uses other than consuming the result of an asynchronous operation via await, - /// can lead to a more convoluted programming model, which can in turn actually - /// lead to more allocations. For example, consider a method that could return either a - /// with a cached task as a common result or a . If the consumer of the result - /// wants to use it as a , such as to use with in methods like Task.WhenAll and Task.WhenAny, - /// the would first need to be converted into a using - /// , which leads to an allocation that would have been avoided if a cached - /// had been used in the first place. - /// - /// - /// As such, the default choice for any asynchronous method should be to return a or - /// . Only if performance analysis proves it worthwhile should a - /// be used instead of . There is no non-generic version of - /// as the Task.CompletedTask property may be used to hand back a successfully completed singleton in the case where - /// a -returning method completes synchronously and successfully. - /// + /// s are meant to be directly awaited. To do more complicated operations with them, a + /// should be extracted using . Such operations might include caching an instance to be awaited later, + /// registering multiple continuations with a single operation, awaiting the same task multiple times, and using combinators over + /// multiple operations. + /// + [AsyncMethodBuilder(typeof(AsyncValueTaskMethodBuilder))] + [StructLayout(LayoutKind.Auto)] + public readonly struct ValueTask : IEquatable + { +#if netstandard + /// A successfully completed task. + private static readonly Task s_completedTask = Task.Delay(0); +#endif + + /// null if representing a successful synchronous completion, otherwise a or a . + internal readonly object _obj; + /// Flags providing additional details about the ValueTask's contents and behavior. + internal readonly ValueTaskFlags _flags; + /// Opaque value passed through to the . + internal readonly short _token; + + // An instance created with the default ctor (a zero init'd struct) represents a synchronously, successfully completed operation. + + /// Initialize the with a that represents the operation. + /// The task. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public ValueTask(Task task) + { + if (task == null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.task); + } + + _obj = task; + + _flags = ValueTaskFlags.ObjectIsTask; + _token = 0; + } + + /// Initialize the with a object that represents the operation. + /// The source. + /// Opaque value passed through to the . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public ValueTask(IValueTaskSource source, short token) + { + if (source == null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source); + } + + _obj = source; + _token = token; + + _flags = 0; + } + + /// Non-verified initialization of the struct to the specified values. + /// The object. + /// The token. + /// The flags. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private ValueTask(object obj, short token, ValueTaskFlags flags) + { + _obj = obj; + _token = token; + _flags = flags; + } + + /// Gets whether the contination should be scheduled to the current context. + internal bool ContinueOnCapturedContext + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => (_flags & ValueTaskFlags.AvoidCapturedContext) == 0; + } + + /// Gets whether the object in the field is a . + internal bool ObjectIsTask + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => (_flags & ValueTaskFlags.ObjectIsTask) != 0; + } + + /// Returns the stored in . This uses . + internal Task UnsafeTask + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get + { + Debug.Assert(ObjectIsTask); + Debug.Assert(_obj is Task); + return Unsafe.As(_obj); + } + } + + /// Returns the stored in . This uses . + internal IValueTaskSource UnsafeValueTaskSource + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get + { + Debug.Assert(!ObjectIsTask); + Debug.Assert(_obj is IValueTaskSource); + return Unsafe.As(_obj); + } + } + + /// Returns the hash code for this instance. + public override int GetHashCode() => _obj?.GetHashCode() ?? 0; + + /// Returns a value indicating whether this value is equal to a specified . + public override bool Equals(object obj) => + obj is ValueTask && + Equals((ValueTask)obj); + + /// Returns a value indicating whether this value is equal to a specified value. + public bool Equals(ValueTask other) => _obj == other._obj && _token == other._token; + + /// Returns a value indicating whether two values are equal. + public static bool operator ==(ValueTask left, ValueTask right) => + left.Equals(right); + + /// Returns a value indicating whether two values are not equal. + public static bool operator !=(ValueTask left, ValueTask right) => + !left.Equals(right); + + /// + /// Gets a object to represent this ValueTask. + /// + /// + /// It will either return the wrapped task object if one exists, or it'll + /// manufacture a new task object to represent the result. + /// + public Task AsTask() => + _obj == null ? +#if netstandard + s_completedTask : +#else + Task.CompletedTask : +#endif + ObjectIsTask ? UnsafeTask : + GetTaskForValueTaskSource(); + + /// Gets a that may be used at any point in the future. + public ValueTask Preserve() => _obj == null ? this : new ValueTask(AsTask()); + + /// Creates a to represent the . + private Task GetTaskForValueTaskSource() + { + IValueTaskSource t = UnsafeValueTaskSource; + ValueTaskSourceStatus status = t.GetStatus(_token); + if (status != ValueTaskSourceStatus.Pending) + { + try + { + // Propagate any exceptions that may have occurred, then return + // an already successfully completed task. + t.GetResult(_token); + return +#if netstandard + s_completedTask; +#else + Task.CompletedTask; +#endif + + // If status is Faulted or Canceled, GetResult should throw. But + // we can't guarantee every implementation will do the "right thing". + // If it doesn't throw, we just treat that as success and ignore + // the status. + } + catch (Exception exc) + { + if (status == ValueTaskSourceStatus.Canceled) + { +#if netstandard + var tcs = new TaskCompletionSource(); + tcs.TrySetCanceled(); + return tcs.Task; +#else + if (exc is OperationCanceledException oce) + { + var task = new Task(); + task.TrySetCanceled(oce.CancellationToken, oce); + return task; + } + else + { + return Task.FromCanceled(new CancellationToken(true)); + } +#endif + } + else + { +#if netstandard + var tcs = new TaskCompletionSource(); + tcs.TrySetException(exc); + return tcs.Task; +#else + return Task.FromException(exc); +#endif + } + } + } + + var m = new ValueTaskSourceTask(t, _token); + return +#if netstandard + m.Task; +#else + m; +#endif + } + + /// Type used to create a to represent a . + private sealed class ValueTaskSourceTask : +#if netstandard + TaskCompletionSource +#else + Task +#endif + { + private static readonly Action s_completionAction = state => + { + if (!(state is ValueTaskSourceTask vtst) || + !(vtst._source is IValueTaskSource source)) + { + // This could only happen if the IValueTaskSource passed the wrong state + // or if this callback were invoked multiple times such that the state + // was previously nulled out. + ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.state); + return; + } + + vtst._source = null; + ValueTaskSourceStatus status = source.GetStatus(vtst._token); + try + { + source.GetResult(vtst._token); + vtst.TrySetResult(default); + } + catch (Exception exc) + { + if (status == ValueTaskSourceStatus.Canceled) + { +#if netstandard + vtst.TrySetCanceled(); +#else + if (exc is OperationCanceledException oce) + { + vtst.TrySetCanceled(oce.CancellationToken, oce); + } + else + { + vtst.TrySetCanceled(new CancellationToken(true)); + } +#endif + } + else + { + vtst.TrySetException(exc); + } + } + }; + + /// The associated . + private IValueTaskSource _source; + /// The token to pass through to operations on + private readonly short _token; + + public ValueTaskSourceTask(IValueTaskSource source, short token) + { + _token = token; + _source = source; + source.OnCompleted(s_completionAction, this, token, ValueTaskSourceOnCompletedFlags.None); + } + } + + /// Gets whether the represents a completed operation. + public bool IsCompleted + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => _obj == null || (ObjectIsTask ? UnsafeTask.IsCompleted : UnsafeValueTaskSource.GetStatus(_token) != ValueTaskSourceStatus.Pending); + } + + /// Gets whether the represents a successfully completed operation. + public bool IsCompletedSuccessfully + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => + _obj == null || + (ObjectIsTask ? +#if netstandard + UnsafeTask.Status == TaskStatus.RanToCompletion : +#else + UnsafeTask.IsCompletedSuccessfully : +#endif + UnsafeValueTaskSource.GetStatus(_token) == ValueTaskSourceStatus.Succeeded); + } + + /// Gets whether the represents a failed operation. + public bool IsFaulted + { + get => + _obj != null && + (ObjectIsTask ? UnsafeTask.IsFaulted : UnsafeValueTaskSource.GetStatus(_token) == ValueTaskSourceStatus.Faulted); + } + + /// Gets whether the represents a canceled operation. + /// + /// If the is backed by a result or by a , + /// this will always return false. If it's backed by a , it'll return the + /// value of the task's property. + /// + public bool IsCanceled + { + get => + _obj != null && + (ObjectIsTask ? UnsafeTask.IsCanceled : UnsafeValueTaskSource.GetStatus(_token) == ValueTaskSourceStatus.Canceled); + } + + /// Throws the exception that caused the to fail. If it completed successfully, nothing is thrown. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [StackTraceHidden] + internal void ThrowIfCompletedUnsuccessfully() + { + if (_obj != null) + { + if (ObjectIsTask) + { +#if netstandard + UnsafeTask.GetAwaiter().GetResult(); +#else + TaskAwaiter.ValidateEnd(UnsafeTask); +#endif + } + else + { + UnsafeValueTaskSource.GetResult(_token); + } + } + } + + /// Gets an awaiter for this . + public ValueTaskAwaiter GetAwaiter() => new ValueTaskAwaiter(this); + + /// Configures an awaiter for this . + /// + /// true to attempt to marshal the continuation back to the captured context; otherwise, false. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public ConfiguredValueTaskAwaitable ConfigureAwait(bool continueOnCapturedContext) + { + // TODO: Simplify once https://github.com/dotnet/coreclr/pull/16138 is fixed. + bool avoidCapture = !continueOnCapturedContext; + return new ConfiguredValueTaskAwaitable(new ValueTask(_obj, _token, _flags | Unsafe.As(ref avoidCapture))); + } + } + + /// Provides a value type that can represent a synchronously available value or a task object. + /// Specifies the type of the result. + /// + /// s are meant to be directly awaited. To do more complicated operations with them, a + /// should be extracted using or . Such operations might include caching an instance to + /// be awaited later, registering multiple continuations with a single operation, awaiting the same task multiple times, and using + /// combinators over multiple operations. /// [AsyncMethodBuilder(typeof(AsyncValueTaskMethodBuilder<>))] [StructLayout(LayoutKind.Auto)] public readonly struct ValueTask : IEquatable> { - /// The task to be used if the operation completed asynchronously or if it completed synchronously but non-successfully. - internal readonly Task _task; + /// null if has the result, otherwise a or a . + internal readonly object _obj; /// The result to be used if the operation completed successfully synchronously. internal readonly TResult _result; + /// Flags providing additional details about the ValueTask's contents and behavior. + internal readonly ValueTaskFlags _flags; + /// Opaque value passed through to the . + internal readonly short _token; - /// Initialize the with the result of the successful operation. + // An instance created with the default ctor (a zero init'd struct) represents a synchronously, successfully completed operation + // with a result of default(TResult). + + /// Initialize the with a result value. /// The result. + [MethodImpl(MethodImplOptions.AggressiveInlining)] public ValueTask(TResult result) { - _task = null; _result = result; + + _obj = null; + _flags = 0; + _token = 0; } - /// - /// Initialize the with a that represents the operation. - /// + /// Initialize the with a that represents the operation. /// The task. + [MethodImpl(MethodImplOptions.AggressiveInlining)] public ValueTask(Task task) { if (task == null) @@ -75,95 +417,341 @@ public ValueTask(Task task) ThrowHelper.ThrowArgumentNullException(ExceptionArgument.task); } - _task = task; + _obj = task; + + _result = default; + _flags = ValueTaskFlags.ObjectIsTask; + _token = 0; + } + + /// Initialize the with a object that represents the operation. + /// The source. + /// Opaque value passed through to the . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public ValueTask(IValueTaskSource source, short token) + { + if (source == null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source); + } + + _obj = source; + _token = token; + _result = default; + _flags = 0; + } + + /// Non-verified initialization of the struct to the specified values. + /// The object. + /// The result. + /// The token. + /// The flags. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private ValueTask(object obj, TResult result, short token, ValueTaskFlags flags) + { + _obj = obj; + _result = result; + _token = token; + _flags = flags; + } + + /// Gets whether the contination should be scheduled to the current context. + internal bool ContinueOnCapturedContext + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => (_flags & ValueTaskFlags.AvoidCapturedContext) == 0; + } + + /// Gets whether the object in the field is a . + internal bool ObjectIsTask + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => (_flags & ValueTaskFlags.ObjectIsTask) != 0; + } + + /// Returns the stored in . This uses . + internal Task UnsafeTask + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get + { + Debug.Assert(ObjectIsTask); + Debug.Assert(_obj is Task); + return Unsafe.As>(_obj); + } + } + + /// Returns the stored in . This uses . + internal IValueTaskSource UnsafeValueTaskSource + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get + { + Debug.Assert(!ObjectIsTask); + Debug.Assert(_obj is IValueTaskSource); + return Unsafe.As>(_obj); + } } /// Returns the hash code for this instance. public override int GetHashCode() => - _task != null ? _task.GetHashCode() : + _obj != null ? _obj.GetHashCode() : _result != null ? _result.GetHashCode() : 0; /// Returns a value indicating whether this value is equal to a specified . public override bool Equals(object obj) => - obj is ValueTask && + obj is ValueTask && Equals((ValueTask)obj); /// Returns a value indicating whether this value is equal to a specified value. public bool Equals(ValueTask other) => - _task != null || other._task != null ? - _task == other._task : + _obj != null || other._obj != null ? + _obj == other._obj && _token == other._token : EqualityComparer.Default.Equals(_result, other._result); /// Returns a value indicating whether two values are equal. - public static bool operator==(ValueTask left, ValueTask right) => + public static bool operator ==(ValueTask left, ValueTask right) => left.Equals(right); /// Returns a value indicating whether two values are not equal. - public static bool operator!=(ValueTask left, ValueTask right) => + public static bool operator !=(ValueTask left, ValueTask right) => !left.Equals(right); /// - /// Gets a object to represent this ValueTask. It will - /// either return the wrapped task object if one exists, or it'll manufacture a new - /// task object to represent the result. + /// Gets a object to represent this ValueTask. /// + /// + /// It will either return the wrapped task object if one exists, or it'll + /// manufacture a new task object to represent the result. + /// public Task AsTask() => - // Return the task if we were constructed from one, otherwise manufacture one. We don't - // cache the generated task into _task as it would end up changing both equality comparison - // and the hash code we generate in GetHashCode. - _task ?? + _obj == null ? #if netstandard - Task.FromResult(_result); + Task.FromResult(_result) : #else - AsyncTaskMethodBuilder.GetTaskForResult(_result); + AsyncTaskMethodBuilder.GetTaskForResult(_result) : #endif + ObjectIsTask ? UnsafeTask : + GetTaskForValueTaskSource(); - internal Task AsTaskExpectNonNull() => - // Return the task if we were constructed from one, otherwise manufacture one. - // Unlike AsTask(), this method is called only when we expect _task to be non-null, - // and thus we don't want GetTaskForResult inlined. - _task ?? GetTaskForResultNoInlining(); + /// Gets a that may be used at any point in the future. + public ValueTask Preserve() => _obj == null ? this : new ValueTask(AsTask()); - [MethodImpl(MethodImplOptions.NoInlining)] - private Task GetTaskForResultNoInlining() => + /// Creates a to represent the . + private Task GetTaskForValueTaskSource() + { + IValueTaskSource t = UnsafeValueTaskSource; + ValueTaskSourceStatus status = t.GetStatus(_token); + if (status != ValueTaskSourceStatus.Pending) + { + try + { + // Get the result of the operation and return a task for it. + // If any exception occurred, propagate it + return #if netstandard - Task.FromResult(_result); + Task.FromResult(t.GetResult(_token)); #else - AsyncTaskMethodBuilder.GetTaskForResult(_result); + AsyncTaskMethodBuilder.GetTaskForResult(t.GetResult(_token)); #endif + // If status is Faulted or Canceled, GetResult should throw. But + // we can't guarantee every implementation will do the "right thing". + // If it doesn't throw, we just treat that as success and ignore + // the status. + } + catch (Exception exc) + { + if (status == ValueTaskSourceStatus.Canceled) + { +#if netstandard + var tcs = new TaskCompletionSource(); + tcs.TrySetCanceled(); + return tcs.Task; +#else + if (exc is OperationCanceledException oce) + { + var task = new Task(); + task.TrySetCanceled(oce.CancellationToken, oce); + return task; + } + else + { + return Task.FromCanceled(new CancellationToken(true)); + } +#endif + } + else + { +#if netstandard + var tcs = new TaskCompletionSource(); + tcs.TrySetException(exc); + return tcs.Task; +#else + return Task.FromException(exc); +#endif + } + } + } + + var m = new ValueTaskSourceTask(t, _token); + return +#if netstandard + m.Task; +#else + m; +#endif + } + + /// Type used to create a to represent a . + private sealed class ValueTaskSourceTask : +#if netstandard + TaskCompletionSource +#else + Task +#endif + { + private static readonly Action s_completionAction = state => + { + if (!(state is ValueTaskSourceTask vtst) || + !(vtst._source is IValueTaskSource source)) + { + // This could only happen if the IValueTaskSource passed the wrong state + // or if this callback were invoked multiple times such that the state + // was previously nulled out. + ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.state); + return; + } + + vtst._source = null; + ValueTaskSourceStatus status = source.GetStatus(vtst._token); + try + { + vtst.TrySetResult(source.GetResult(vtst._token)); + } + catch (Exception exc) + { + if (status == ValueTaskSourceStatus.Canceled) + { +#if netstandard + vtst.TrySetCanceled(); +#else + if (exc is OperationCanceledException oce) + { + vtst.TrySetCanceled(oce.CancellationToken, oce); + } + else + { + vtst.TrySetCanceled(new CancellationToken(true)); + } +#endif + } + else + { + vtst.TrySetException(exc); + } + } + }; + + /// The associated . + private IValueTaskSource _source; + /// The token to pass through to operations on + private readonly short _token; + + public ValueTaskSourceTask(IValueTaskSource source, short token) + { + _source = source; + _token = token; + source.OnCompleted(s_completionAction, this, token, ValueTaskSourceOnCompletedFlags.None); + } + } + /// Gets whether the represents a completed operation. - public bool IsCompleted => _task == null || _task.IsCompleted; + public bool IsCompleted + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => _obj == null || (ObjectIsTask ? UnsafeTask.IsCompleted : UnsafeValueTaskSource.GetStatus(_token) != ValueTaskSourceStatus.Pending); + } /// Gets whether the represents a successfully completed operation. - public bool IsCompletedSuccessfully => - _task == null || + public bool IsCompletedSuccessfully + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => + _obj == null || + (ObjectIsTask ? #if netstandard - _task.Status == TaskStatus.RanToCompletion; + UnsafeTask.Status == TaskStatus.RanToCompletion : #else - _task.IsCompletedSuccessfully; + UnsafeTask.IsCompletedSuccessfully : #endif + UnsafeValueTaskSource.GetStatus(_token) == ValueTaskSourceStatus.Succeeded); + } /// Gets whether the represents a failed operation. - public bool IsFaulted => _task != null && _task.IsFaulted; + public bool IsFaulted + { + get => + _obj != null && + (ObjectIsTask ? UnsafeTask.IsFaulted : UnsafeValueTaskSource.GetStatus(_token) == ValueTaskSourceStatus.Faulted); + } /// Gets whether the represents a canceled operation. - public bool IsCanceled => _task != null && _task.IsCanceled; + /// + /// If the is backed by a result or by a , + /// this will always return false. If it's backed by a , it'll return the + /// value of the task's property. + /// + public bool IsCanceled + { + get => + _obj != null && + (ObjectIsTask ? UnsafeTask.IsCanceled : UnsafeValueTaskSource.GetStatus(_token) == ValueTaskSourceStatus.Canceled); + } /// Gets the result. - public TResult Result => _task == null ? _result : _task.GetAwaiter().GetResult(); + public TResult Result + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get + { + if (_obj == null) + { + return _result; + } + + if (ObjectIsTask) + { +#if netstandard + return UnsafeTask.GetAwaiter().GetResult(); +#else + Task t = UnsafeTask; + TaskAwaiter.ValidateEnd(t); + return t.ResultOnSuccess; +#endif + } + + return UnsafeValueTaskSource.GetResult(_token); + } + } - /// Gets an awaiter for this value. + /// Gets an awaiter for this . + [MethodImpl(MethodImplOptions.AggressiveInlining)] public ValueTaskAwaiter GetAwaiter() => new ValueTaskAwaiter(this); - /// Configures an awaiter for this value. + /// Configures an awaiter for this . /// /// true to attempt to marshal the continuation back to the captured context; otherwise, false. /// - public ConfiguredValueTaskAwaitable ConfigureAwait(bool continueOnCapturedContext) => - new ConfiguredValueTaskAwaitable(this, continueOnCapturedContext); + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public ConfiguredValueTaskAwaitable ConfigureAwait(bool continueOnCapturedContext) + { + // TODO: Simplify once https://github.com/dotnet/coreclr/pull/16138 is fixed. + bool avoidCapture = !continueOnCapturedContext; + return new ConfiguredValueTaskAwaitable(new ValueTask(_obj, _result, _token, _flags | Unsafe.As(ref avoidCapture))); + } /// Gets a string-representation of this . public override string ToString() @@ -180,4 +768,26 @@ public override string ToString() return string.Empty; } } + + /// Internal flags used in the implementation of and . + [Flags] + internal enum ValueTaskFlags : byte + { + /// + /// Indicates that context (e.g. SynchronizationContext) should not be captured when adding + /// a continuation. + /// + /// + /// The value here must be 0x1, to match the value of a true Boolean reinterpreted as a byte. + /// This only has meaning when awaiting a ValueTask, with ConfigureAwait creating a new + /// ValueTask setting or not setting this flag appropriately. + /// + AvoidCapturedContext = 0x1, + + /// + /// Indicates that the ValueTask's object field stores a Task. This is used to avoid + /// a type check on whatever is stored in the object field. + /// + ObjectIsTask = 0x2 + } } diff --git a/src/System.Net.Sockets/src/Resources/Strings.resx b/src/System.Net.Sockets/src/Resources/Strings.resx index de08663dd3a9..b95d2522e719 100644 --- a/src/System.Net.Sockets/src/Resources/Strings.resx +++ b/src/System.Net.Sockets/src/Resources/Strings.resx @@ -238,4 +238,10 @@ This operation may only be performed when the buffer was set using the SetBuffer overload that accepts an array. + + The result of the operation was already consumed and may not be used again. + + + Another continuation was already registered. + diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs b/src/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs index 12c77e21cac4..bc79ac35b8a4 100644 --- a/src/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs +++ b/src/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs @@ -819,10 +819,25 @@ internal sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, IVal { Debug.Assert(c != s_availableSentinel, "The delegate should not have been the available sentinel."); Debug.Assert(c != s_completedSentinel, "The delegate should not have been the completed sentinel."); + object continuationState = ea.UserToken; ea.UserToken = null; ea._continuation = s_completedSentinel; // in case someone's polling IsCompleted - ea.InvokeContinuation(c, continuationState, forceAsync: false); + + ExecutionContext ec = ea._executionContext; + if (ec == null) + { + ea.InvokeContinuation(c, continuationState, forceAsync: false); + } + else + { + ea._executionContext = null; + ExecutionContext.Run(ec, runState => + { + var t = (Tuple, object>)runState; + t.Item1.InvokeContinuation(t.Item2, t.Item3, forceAsync: false); + }, Tuple.Create(ea, c, continuationState)); + } } }; /// @@ -835,6 +850,13 @@ internal sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, IVal private Action _continuation = s_availableSentinel; private ExecutionContext _executionContext; private object _scheduler; + /// Current token value given to a ValueTask and then verified against the value it passes back to us. + /// + /// This is not meant to be a completely reliable mechanism, doesn't require additional synchronization, etc. + /// It's purely a best effort attempt to catch misuse, including awaiting for a value task twice and after + /// it's already being reused by someone else. + /// + private short _token; /// Initializes the event args. /// The associated socket. @@ -845,7 +867,11 @@ internal sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, IVal public bool Reserve() => Interlocked.CompareExchange(ref _continuation, null, s_availableSentinel) == s_availableSentinel; - private void Release() => Volatile.Write(ref _continuation, s_availableSentinel); + private void Release() + { + _token++; + Volatile.Write(ref _continuation, s_availableSentinel); + } /// Initiates a receive operation on the associated socket. /// This instance. @@ -855,7 +881,7 @@ public ValueTask ReceiveAsync(Socket socket) if (socket.ReceiveAsync(this)) { - return new ValueTask(this); + return new ValueTask(this, _token); } int bytesTransferred = BytesTransferred; @@ -876,7 +902,7 @@ public ValueTask SendAsync(Socket socket) if (socket.SendAsync(this)) { - return new ValueTask(this); + return new ValueTask(this, _token); } int bytesTransferred = BytesTransferred; @@ -895,7 +921,7 @@ public ValueTask SendAsyncForNetworkStream(Socket socket) if (socket.SendAsync(this)) { - return new ValueTask(this); + return new ValueTask(this, _token); } SocketError error = SocketError; @@ -908,14 +934,27 @@ public ValueTask SendAsyncForNetworkStream(Socket socket) } /// Gets the status of the operation. - public ValueTaskSourceStatus Status => - _continuation != s_completedSentinel ? ValueTaskSourceStatus.Pending : - base.SocketError == SocketError.Success ? ValueTaskSourceStatus.Succeeded : - ValueTaskSourceStatus.Faulted; + public ValueTaskSourceStatus GetStatus(short token) + { + if (token != _token) + { + ThrowIncorrectTokenException(); + } + + return + _continuation != s_completedSentinel ? ValueTaskSourceStatus.Pending : + base.SocketError == SocketError.Success ? ValueTaskSourceStatus.Succeeded : + ValueTaskSourceStatus.Faulted; + } /// Queues the provided continuation to be executed once the operation has completed. - public void OnCompleted(Action continuation, object state, ValueTaskSourceOnCompletedFlags flags) + public void OnCompleted(Action continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags) { + if (token != _token) + { + ThrowIncorrectTokenException(); + } + if ((flags & ValueTaskSourceOnCompletedFlags.FlowExecutionContext) != 0) { _executionContext = ExecutionContext.Capture(); @@ -939,32 +978,27 @@ public void OnCompleted(Action continuation, object state, ValueTaskSour } UserToken = state; // Use UserToken to carry the continuation state around - if (ReferenceEquals(Interlocked.CompareExchange(ref _continuation, continuation, null), s_completedSentinel)) + Action prevContinuation = Interlocked.CompareExchange(ref _continuation, continuation, null); + if (ReferenceEquals(prevContinuation, s_completedSentinel)) { - UserToken = null; + // Lost the race condition and the operation has now already completed. + // We need to invoke the continuation, but it must be asynchronously to + // avoid a stack dive. However, since all of the queueing mechanisms flow + // ExecutionContext, and since we're still in the same context where we + // captured it, we can just ignore the one we captured. + _executionContext = null; + UserToken = null; // we have the state in "state"; no need for the one in UserToken InvokeContinuation(continuation, state, forceAsync: true); } - } - - private void InvokeContinuation(Action continuation, object state, bool forceAsync) - { - ExecutionContext ec = _executionContext; - if (ec == null) - { - InvokeContinuationCore(continuation, state, forceAsync); - } - else + else if (prevContinuation != null) { - _executionContext = null; - ExecutionContext.Run(ec, s => - { - var t = (Tuple, object, bool>)s; - t.Item1.InvokeContinuationCore(t.Item2, t.Item3, t.Item4); - }, Tuple.Create(this, continuation, state, forceAsync)); + // Flag errors with the continuation being hooked up multiple times. + // This is purely to help alert a developer to a bug they need to fix. + ThrowMultipleContinuationsException(); } } - private void InvokeContinuationCore(Action continuation, object state, bool forceAsync) + private void InvokeContinuation(Action continuation, object state, bool forceAsync) { object scheduler = _scheduler; _scheduler = null; @@ -981,13 +1015,13 @@ private void InvokeContinuationCore(Action continuation, object state, b } else { + Debug.Assert(scheduler is TaskScheduler, $"Expected TaskScheduler, got {scheduler}"); Task.Factory.StartNew(continuation, state, CancellationToken.None, TaskCreationOptions.DenyChildAttach, (TaskScheduler)scheduler); } } else if (forceAsync) { - // TODO #27464: Use QueueUserWorkItem when it has a compatible signature. - Task.Factory.StartNew(continuation, state, CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskScheduler.Default); + ThreadPool.QueueUserWorkItem(continuation, state, preferLocal: true); } else { @@ -1001,8 +1035,13 @@ private void InvokeContinuationCore(Action continuation, object state, b /// Unlike Task's awaiter's GetResult, this does not block until the operation completes: it must only /// be used once the operation has completed. This is handled implicitly by await. /// - public int GetResult() + public int GetResult(short token) { + if (token != _token) + { + ThrowIncorrectTokenException(); + } + SocketError error = SocketError; int bytes = BytesTransferred; @@ -1015,8 +1054,13 @@ public int GetResult() return bytes; } - void IValueTaskSource.GetResult() + void IValueTaskSource.GetResult(short token) { + if (token != _token) + { + ThrowIncorrectTokenException(); + } + SocketError error = SocketError; Release(); @@ -1027,6 +1071,10 @@ void IValueTaskSource.GetResult() } } + private void ThrowIncorrectTokenException() => throw new InvalidOperationException(SR.InvalidOperation_IncorrectToken); + + private void ThrowMultipleContinuationsException() => throw new InvalidOperationException(SR.InvalidOperation_MultipleContinuations); + private void ThrowException(SocketError error) => throw CreateException(error); private Exception CreateException(SocketError error) diff --git a/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.netcoreapp.cs b/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.netcoreapp.cs index 98a06da7e2b1..10ff4bbba642 100644 --- a/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.netcoreapp.cs +++ b/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.netcoreapp.cs @@ -88,6 +88,36 @@ await RunWithConnectedNetworkStreamsAsync(async (server, client) => }); } + [Fact] + public async Task ReadAsync_AwaitMultipleTimes_Throws() + { + await RunWithConnectedNetworkStreamsAsync(async (server, client) => + { + var b = new byte[1]; + ValueTask r = server.ReadAsync(b); + await client.WriteAsync(new byte[] { 42 }); + Assert.Equal(1, await r); + Assert.Equal(42, b[0]); + await Assert.ThrowsAsync(async () => await r); + Assert.Throws(() => r.GetAwaiter().IsCompleted); + Assert.Throws(() => r.GetAwaiter().OnCompleted(() => { })); + Assert.Throws(() => r.GetAwaiter().GetResult()); + }); + } + + [Fact] + public async Task ReadAsync_MultipleContinuations_Throws() + { + await RunWithConnectedNetworkStreamsAsync((server, client) => + { + var b = new byte[1]; + ValueTask r = server.ReadAsync(b); + r.GetAwaiter().OnCompleted(() => { }); + Assert.Throws(() => r.GetAwaiter().OnCompleted(() => { })); + return Task.CompletedTask; + }); + } + [Fact] public async Task ReadAsync_MultipleConcurrentValueTaskReads_Success() { diff --git a/src/System.Runtime/ref/System.Runtime.cs b/src/System.Runtime/ref/System.Runtime.cs index 9b1a68455354..9b12e201e572 100644 --- a/src/System.Runtime/ref/System.Runtime.cs +++ b/src/System.Runtime/ref/System.Runtime.cs @@ -8013,22 +8013,22 @@ public enum ValueTaskSourceStatus } public interface IValueTaskSource { - System.Threading.Tasks.ValueTaskSourceStatus Status { get; } - void OnCompleted(System.Action continuation, object state, System.Threading.Tasks.ValueTaskSourceOnCompletedFlags flags); - void GetResult(); + System.Threading.Tasks.ValueTaskSourceStatus GetStatus(short token); + void OnCompleted(System.Action continuation, object state, short token, System.Threading.Tasks.ValueTaskSourceOnCompletedFlags flags); + void GetResult(short token); } public interface IValueTaskSource { - System.Threading.Tasks.ValueTaskSourceStatus Status { get; } - void OnCompleted(System.Action continuation, object state, System.Threading.Tasks.ValueTaskSourceOnCompletedFlags flags); - TResult GetResult(); + System.Threading.Tasks.ValueTaskSourceStatus GetStatus(short token); + void OnCompleted(System.Action continuation, object state, short token, System.Threading.Tasks.ValueTaskSourceOnCompletedFlags flags); + TResult GetResult(short token); } [System.Runtime.CompilerServices.AsyncMethodBuilderAttribute(typeof(System.Runtime.CompilerServices.AsyncValueTaskMethodBuilder))] public readonly partial struct ValueTask : System.IEquatable { internal readonly object _dummy; public ValueTask(System.Threading.Tasks.Task task) { throw null; } - public ValueTask(System.Threading.Tasks.IValueTaskSource source) { throw null; } + public ValueTask(System.Threading.Tasks.IValueTaskSource source, short token) { throw null; } public bool IsCanceled { get { throw null; } } public bool IsCompleted { get { throw null; } } public bool IsCompletedSuccessfully { get { throw null; } } @@ -8048,7 +8048,7 @@ public interface IValueTaskSource { internal readonly TResult _result; public ValueTask(System.Threading.Tasks.Task task) { throw null; } - public ValueTask(System.Threading.Tasks.IValueTaskSource source) { throw null; } + public ValueTask(System.Threading.Tasks.IValueTaskSource source, short token) { throw null; } public ValueTask(TResult result) { throw null; } public bool IsCanceled { get { throw null; } } public bool IsCompleted { get { throw null; } } diff --git a/src/System.Threading.Channels/src/Resources/Strings.resx b/src/System.Threading.Channels/src/Resources/Strings.resx index 83acd241208f..01229ace1d75 100644 --- a/src/System.Threading.Channels/src/Resources/Strings.resx +++ b/src/System.Threading.Channels/src/Resources/Strings.resx @@ -124,6 +124,9 @@ The asynchronous operation has not completed. - OnCompleted has already been used to register another continuation. + Another continuation was already registered. - \ No newline at end of file + + The result of the operation was already consumed and may not be used again. + + diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/AsyncOperation.cs b/src/System.Threading.Channels/src/System/Threading/Channels/AsyncOperation.cs index acf48c02245d..189ca5e05045 100644 --- a/src/System.Threading.Channels/src/System/Threading/Channels/AsyncOperation.cs +++ b/src/System.Threading.Channels/src/System/Threading/Channels/AsyncOperation.cs @@ -10,19 +10,31 @@ namespace System.Threading.Channels { internal abstract class ResettableValueTaskSource { - protected static readonly Action s_completedSentinel = s => Debug.Fail($"{nameof(ResettableValueTaskSource)}.{nameof(s_completedSentinel)} invoked."); + /// Sentinel object used in a field to indicate the operation has completed. + protected static readonly Action s_completedSentinel = new Action(s => Debug.Fail($"{nameof(ResettableValueTaskSource)}.{nameof(s_completedSentinel)} invoked.")); + /// Throws an exception indicating that the operation's result was accessed before the operation completed. protected static void ThrowIncompleteOperationException() => throw new InvalidOperationException(SR.InvalidOperation_IncompleteAsyncOperation); + /// Throws an exception indicating that multiple continuations can't be set for the same operation. protected static void ThrowMultipleContinuations() => throw new InvalidOperationException(SR.InvalidOperation_MultipleContinuations); + /// Throws an exception indicating that the operation was used after it was supposed to be used. + protected static void ThrowIncorrectCurrentIdException() => + throw new InvalidOperationException(SR.InvalidOperation_IncorrectToken); + + /// Describes states the operation can be in. public enum States { + /// The operation has been assigned an owner. No one else can use it. Owned = 0, + /// Completion has been reserved. Only the reserver is allowed to complete it. CompletionReserved = 1, + /// The operation has completed and has had its result or error stored. CompletionSet = 2, + /// The operation's result/error has been retrieved. It's available for reuse. Released = 3 } } @@ -36,37 +48,48 @@ internal abstract class ResettableValueTaskSource : ResettableValueTaskSource private object _continuationState; private object _schedulingContext; private ExecutionContext _executionContext; + private short _currentId; + + public ValueTask CreateNonGenericValueTask() => new ValueTask(this, _currentId); + public ValueTask CreateGenericValueTask() => new ValueTask(this, _currentId); public bool RunContinutationsAsynchronously { get; protected set; } - public ValueTaskSourceStatus Status + public ValueTaskSourceStatus GetStatus(short token) { - get + if (_currentId != token) { - switch ((States)_state) - { - case States.Owned: - case States.CompletionReserved: - return ValueTaskSourceStatus.Pending; - - case States.CompletionSet: - case States.Released: - return - _error == null ? ValueTaskSourceStatus.Succeeded : - _error.SourceException is OperationCanceledException ? ValueTaskSourceStatus.Canceled : - ValueTaskSourceStatus.Faulted; - - default: - Debug.Fail($"Shouldn't be accessed in the '{(States)_state}' state."); - goto case States.CompletionSet; - } + throw new InvalidOperationException(); + } + + switch ((States)_state) + { + case States.Owned: + case States.CompletionReserved: + return ValueTaskSourceStatus.Pending; + + case States.CompletionSet: + case States.Released: + return + _error == null ? ValueTaskSourceStatus.Succeeded : + _error.SourceException is OperationCanceledException ? ValueTaskSourceStatus.Canceled : + ValueTaskSourceStatus.Faulted; + + default: + Debug.Fail($"Shouldn't be accessed in the '{(States)_state}' state."); + goto case States.CompletionSet; } } public bool IsCompleted => _state >= (int)States.CompletionSet; public States UnsafeState { get => (States)_state; set => _state = (int)value; } - public T GetResult() + public T GetResult(short token) { + if (_currentId != token) + { + ThrowIncorrectCurrentIdException(); + } + if (!IsCompleted) { ThrowIncompleteOperationException(); @@ -74,22 +97,27 @@ public T GetResult() ExceptionDispatchInfo error = _error; T result = _result; - + _currentId++; _state = (int)States.Released; // only after fetching all needed data error?.Throw(); return result; } - void IValueTaskSource.GetResult() + void IValueTaskSource.GetResult(short token) { + if (_currentId != token) + { + ThrowIncorrectCurrentIdException(); + } + if (!IsCompleted) { ThrowIncompleteOperationException(); } ExceptionDispatchInfo error = _error; - + _currentId++; _state = (int)States.Released; // only after fetching all needed data error?.Throw(); @@ -111,8 +139,13 @@ public bool TryOwnAndReset() return false; } - public void OnCompleted(Action continuation, object state, ValueTaskSourceOnCompletedFlags flags) + public void OnCompleted(Action continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags) { + if (_currentId != token) + { + ThrowIncorrectCurrentIdException(); + } + // We need to store the state before the CompareExchange, so that if it completes immediately // after the CompareExchange, it'll find the state already stored. If someone misuses this // and schedules multiple continuations erroneously, we could end up using the wrong state. @@ -166,15 +199,9 @@ public void OnCompleted(Action continuation, object state, ValueTaskSour t.Item1(t.Item2); }, Tuple.Create(continuation, state)); } - else if (ts != null) - { - Task.Factory.StartNew(continuation, state, CancellationToken.None, TaskCreationOptions.DenyChildAttach, ts); - } else { - // TODO #27464: Change this to use the new QueueUserWorkItem signature when it's available. - Debug.Assert(_schedulingContext == null, $"Expected null context, got {_schedulingContext}"); - Task.Factory.StartNew(continuation, state, CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskScheduler.Default); + Task.Factory.StartNew(continuation, state, CancellationToken.None, TaskCreationOptions.DenyChildAttach, ts ?? TaskScheduler.Default); } } } diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs b/src/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs index df81bb778835..7a595e48a959 100644 --- a/src/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs +++ b/src/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs @@ -107,7 +107,7 @@ public override ValueTask ReadAsync(CancellationToken cancellationToken) // Otherwise, queue the reader. var reader = new AsyncOperation(parent._runContinuationsAsynchronously, cancellationToken); parent._blockedReaders.EnqueueTail(reader); - return new ValueTask(reader); + return reader.CreateGenericValueTask(); } } @@ -141,7 +141,7 @@ public override ValueTask WaitToReadAsync(CancellationToken cancellationTo // there's a blocked reader task and return it. var waiter = new AsyncOperation(parent._runContinuationsAsynchronously, cancellationToken); ChannelUtilities.QueueWaiter(ref _parent._waitingReadersTail, waiter); - return new ValueTask(waiter); + return waiter.CreateGenericValueTask(); } } @@ -385,7 +385,7 @@ public override ValueTask WaitToWriteAsync(CancellationToken cancellationT // We're still allowed to write, but there's no space, so ensure a waiter is queued and return it. var waiter = new AsyncOperation(runContinuationsAsynchronously: true, cancellationToken); ChannelUtilities.QueueWaiter(ref parent._waitingWritersTail, waiter); - return new ValueTask(waiter); + return waiter.CreateGenericValueTask(); } } @@ -459,7 +459,7 @@ public override ValueTask WriteAsync(T item, CancellationToken cancellationToken var writer = new VoidAsyncOperationWithData(runContinuationsAsynchronously: true, cancellationToken); writer.Item = item; parent._blockedWriters.EnqueueTail(writer); - return new ValueTask(writer); + return writer.CreateNonGenericValueTask(); } else if (parent._mode == BoundedChannelFullMode.DropWrite) { diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/ChannelUtilities.cs b/src/System.Threading.Channels/src/System/Threading/Channels/ChannelUtilities.cs index 711e97a1285e..10414be72cb1 100644 --- a/src/System.Threading.Channels/src/System/Threading/Channels/ChannelUtilities.cs +++ b/src/System.Threading.Channels/src/System/Threading/Channels/ChannelUtilities.cs @@ -73,7 +73,7 @@ internal static ValueTask QueueWaiter(ref AsyncOperation tail, Async c.Next = waiter; } tail = waiter; - return new ValueTask(waiter); + return waiter.CreateGenericValueTask(); } internal static void WakeUpWaiters(ref AsyncOperation listTail, bool result, Exception error = null) diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/SingleConsumerUnboundedChannel.cs b/src/System.Threading.Channels/src/System/Threading/Channels/SingleConsumerUnboundedChannel.cs index 3a8760966b54..1720bd5541b3 100644 --- a/src/System.Threading.Channels/src/System/Threading/Channels/SingleConsumerUnboundedChannel.cs +++ b/src/System.Threading.Channels/src/System/Threading/Channels/SingleConsumerUnboundedChannel.cs @@ -115,7 +115,7 @@ public override ValueTask ReadAsync(CancellationToken cancellationToken) } oldBlockedReader?.TrySetCanceled(); - return new ValueTask(newBlockedReader); + return newBlockedReader.CreateGenericValueTask(); } public override bool TryRead(out T item) @@ -184,7 +184,7 @@ public override ValueTask WaitToReadAsync(CancellationToken cancellationTo } oldWaitingReader?.TrySetCanceled(); - return new ValueTask(newWaitingReader); + return newWaitingReader.CreateGenericValueTask(); } /// Gets the number of items in the channel. This should only be used by the debugger. diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs b/src/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs index e25a829d85c2..fd910ca29f30 100644 --- a/src/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs +++ b/src/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs @@ -97,14 +97,14 @@ private ValueTask ReadAsyncCore(CancellationToken cancellationToken) if (singleton.TryOwnAndReset()) { parent._blockedReaders.EnqueueTail(singleton); - return new ValueTask(singleton); + return singleton.CreateGenericValueTask(); } } // Otherwise, create and queue a reader. var reader = new AsyncOperation(parent._runContinuationsAsynchronously, cancellationToken); parent._blockedReaders.EnqueueTail(reader); - return new ValueTask(reader); + return reader.CreateGenericValueTask(); } } @@ -166,14 +166,14 @@ public override ValueTask WaitToReadAsync(CancellationToken cancellationTo if (singleton.TryOwnAndReset()) { ChannelUtilities.QueueWaiter(ref parent._waitingReadersTail, singleton); - return new ValueTask(singleton); + return singleton.CreateGenericValueTask(); } } // Otherwise, create and queue a waiter. var waiter = new AsyncOperation(parent._runContinuationsAsynchronously, cancellationToken); ChannelUtilities.QueueWaiter(ref parent._waitingReadersTail, waiter); - return new ValueTask(waiter); + return waiter.CreateGenericValueTask(); } } diff --git a/src/System.Threading.Channels/tests/ChannelTestBase.cs b/src/System.Threading.Channels/tests/ChannelTestBase.cs index 11879918004b..35a599b51088 100644 --- a/src/System.Threading.Channels/tests/ChannelTestBase.cs +++ b/src/System.Threading.Channels/tests/ChannelTestBase.cs @@ -729,6 +729,242 @@ public async Task WaitToReadAsync_ConsecutiveReadsSucceed() } } + [Theory] + [InlineData(false, null)] + [InlineData(false, false)] + [InlineData(false, true)] + [InlineData(true, null)] + [InlineData(true, false)] + [InlineData(true, true)] + public void WaitToReadAsync_MultipleContinuations_Throws(bool onCompleted, bool? continueOnCapturedContext) + { + Channel c = CreateChannel(); + + ValueTask read = c.Reader.WaitToReadAsync(); + switch (continueOnCapturedContext) + { + case null: + if (onCompleted) + { + read.GetAwaiter().OnCompleted(() => { }); + Assert.Throws(() => read.GetAwaiter().OnCompleted(() => { })); + } + else + { + read.GetAwaiter().UnsafeOnCompleted(() => { }); + Assert.Throws(() => read.GetAwaiter().UnsafeOnCompleted(() => { })); + } + break; + + default: + if (onCompleted) + { + read.ConfigureAwait(continueOnCapturedContext.Value).GetAwaiter().OnCompleted(() => { }); + Assert.Throws(() => read.ConfigureAwait(continueOnCapturedContext.Value).GetAwaiter().OnCompleted(() => { })); + } + else + { + read.ConfigureAwait(continueOnCapturedContext.Value).GetAwaiter().UnsafeOnCompleted(() => { }); + Assert.Throws(() => read.ConfigureAwait(continueOnCapturedContext.Value).GetAwaiter().UnsafeOnCompleted(() => { })); + } + break; + } + } + + [Theory] + [InlineData(false, null)] + [InlineData(false, false)] + [InlineData(false, true)] + [InlineData(true, null)] + [InlineData(true, false)] + [InlineData(true, true)] + public void ReadAsync_MultipleContinuations_Throws(bool onCompleted, bool? continueOnCapturedContext) + { + Channel c = CreateChannel(); + + ValueTask read = c.Reader.ReadAsync(); + switch (continueOnCapturedContext) + { + case null: + if (onCompleted) + { + read.GetAwaiter().OnCompleted(() => { }); + Assert.Throws(() => read.GetAwaiter().OnCompleted(() => { })); + } + else + { + read.GetAwaiter().UnsafeOnCompleted(() => { }); + Assert.Throws(() => read.GetAwaiter().UnsafeOnCompleted(() => { })); + } + break; + + default: + if (onCompleted) + { + read.ConfigureAwait(continueOnCapturedContext.Value).GetAwaiter().OnCompleted(() => { }); + Assert.Throws(() => read.ConfigureAwait(continueOnCapturedContext.Value).GetAwaiter().OnCompleted(() => { })); + } + else + { + read.ConfigureAwait(continueOnCapturedContext.Value).GetAwaiter().UnsafeOnCompleted(() => { }); + Assert.Throws(() => read.ConfigureAwait(continueOnCapturedContext.Value).GetAwaiter().UnsafeOnCompleted(() => { })); + } + break; + } + } + + [Fact] + public async Task WaitToReadAsync_AwaitThenGetResult_Throws() + { + Channel c = CreateChannel(); + + ValueTask read = c.Reader.WaitToReadAsync(); + Assert.True(c.Writer.TryWrite(42)); + Assert.True(await read); + Assert.Throws(() => read.GetAwaiter().IsCompleted); + Assert.Throws(() => read.GetAwaiter().OnCompleted(() => { })); + Assert.Throws(() => read.GetAwaiter().GetResult()); + } + + [Fact] + public async Task ReadAsync_AwaitThenGetResult_Throws() + { + Channel c = CreateChannel(); + + ValueTask read = c.Reader.ReadAsync(); + Assert.True(c.Writer.TryWrite(42)); + Assert.Equal(42, await read); + Assert.Throws(() => read.GetAwaiter().IsCompleted); + Assert.Throws(() => read.GetAwaiter().OnCompleted(() => { })); + Assert.Throws(() => read.GetAwaiter().GetResult()); + } + + [Fact] + public async Task WaitToWriteAsync_AwaitThenGetResult_Throws() + { + Channel c = CreateFullChannel(); + if (c == null) + { + return; + } + + ValueTask write = c.Writer.WaitToWriteAsync(); + await c.Reader.ReadAsync(); + Assert.True(await write); + Assert.Throws(() => write.GetAwaiter().IsCompleted); + Assert.Throws(() => write.GetAwaiter().OnCompleted(() => { })); + Assert.Throws(() => write.GetAwaiter().GetResult()); + } + + [Fact] + public async Task WriteAsync_AwaitThenGetResult_Throws() + { + Channel c = CreateFullChannel(); + if (c == null) + { + return; + } + + ValueTask write = c.Writer.WriteAsync(42); + await c.Reader.ReadAsync(); + await write; + Assert.Throws(() => write.GetAwaiter().IsCompleted); + Assert.Throws(() => write.GetAwaiter().OnCompleted(() => { })); + Assert.Throws(() => write.GetAwaiter().GetResult()); + } + + [Theory] + [InlineData(false, null)] + [InlineData(false, false)] + [InlineData(false, true)] + [InlineData(true, null)] + [InlineData(true, false)] + [InlineData(true, true)] + public void WaitToWriteAsync_MultipleContinuations_Throws(bool onCompleted, bool? continueOnCapturedContext) + { + Channel c = CreateFullChannel(); + if (c == null) + { + return; + } + + ValueTask write = c.Writer.WaitToWriteAsync(); + switch (continueOnCapturedContext) + { + case null: + if (onCompleted) + { + write.GetAwaiter().OnCompleted(() => { }); + Assert.Throws(() => write.GetAwaiter().OnCompleted(() => { })); + } + else + { + write.GetAwaiter().UnsafeOnCompleted(() => { }); + Assert.Throws(() => write.GetAwaiter().UnsafeOnCompleted(() => { })); + } + break; + + default: + if (onCompleted) + { + write.ConfigureAwait(continueOnCapturedContext.Value).GetAwaiter().OnCompleted(() => { }); + Assert.Throws(() => write.ConfigureAwait(continueOnCapturedContext.Value).GetAwaiter().OnCompleted(() => { })); + } + else + { + write.ConfigureAwait(continueOnCapturedContext.Value).GetAwaiter().UnsafeOnCompleted(() => { }); + Assert.Throws(() => write.ConfigureAwait(continueOnCapturedContext.Value).GetAwaiter().UnsafeOnCompleted(() => { })); + } + break; + } + } + + [Theory] + [InlineData(false, null)] + [InlineData(false, false)] + [InlineData(false, true)] + [InlineData(true, null)] + [InlineData(true, false)] + [InlineData(true, true)] + public void WriteAsync_MultipleContinuations_Throws(bool onCompleted, bool? continueOnCapturedContext) + { + Channel c = CreateFullChannel(); + if (c == null) + { + return; + } + + ValueTask write = c.Writer.WriteAsync(42); + switch (continueOnCapturedContext) + { + case null: + if (onCompleted) + { + write.GetAwaiter().OnCompleted(() => { }); + Assert.Throws(() => write.GetAwaiter().OnCompleted(() => { })); + } + else + { + write.GetAwaiter().UnsafeOnCompleted(() => { }); + Assert.Throws(() => write.GetAwaiter().UnsafeOnCompleted(() => { })); + } + break; + + default: + if (onCompleted) + { + write.ConfigureAwait(continueOnCapturedContext.Value).GetAwaiter().OnCompleted(() => { }); + Assert.Throws(() => write.ConfigureAwait(continueOnCapturedContext.Value).GetAwaiter().OnCompleted(() => { })); + } + else + { + write.ConfigureAwait(continueOnCapturedContext.Value).GetAwaiter().UnsafeOnCompleted(() => { }); + Assert.Throws(() => write.ConfigureAwait(continueOnCapturedContext.Value).GetAwaiter().UnsafeOnCompleted(() => { })); + } + break; + } + } + public static IEnumerable Reader_ContinuesOnCurrentContextIfDesired_MemberData() => from readOrWait in new[] { true, false } from completeBeforeOnCompleted in new[] { true, false } diff --git a/src/System.Threading.Channels/tests/TestBase.cs b/src/System.Threading.Channels/tests/TestBase.cs index 9a84a8ecd69b..e41a6139528e 100644 --- a/src/System.Threading.Channels/tests/TestBase.cs +++ b/src/System.Threading.Channels/tests/TestBase.cs @@ -15,7 +15,11 @@ protected void AssertSynchronouslyCanceled(Task task, CancellationToken token) { Assert.Equal(TaskStatus.Canceled, task.Status); OperationCanceledException oce = Assert.ThrowsAny(() => task.GetAwaiter().GetResult()); - Assert.Equal(token, oce.CancellationToken); + if (PlatformDetection.IsNetCore) + { + // Earlier netstandard versions didn't have the APIs to always make this possible. + Assert.Equal(token, oce.CancellationToken); + } } protected async Task AssertCanceled(Task task, CancellationToken token) diff --git a/src/System.Threading.Tasks.Extensions/ref/System.Threading.Tasks.Extensions.cs b/src/System.Threading.Tasks.Extensions/ref/System.Threading.Tasks.Extensions.cs index 36f4afcbf63f..fdb20dc25fd5 100644 --- a/src/System.Threading.Tasks.Extensions/ref/System.Threading.Tasks.Extensions.cs +++ b/src/System.Threading.Tasks.Extensions/ref/System.Threading.Tasks.Extensions.cs @@ -98,22 +98,22 @@ public enum ValueTaskSourceStatus } public interface IValueTaskSource { - System.Threading.Tasks.ValueTaskSourceStatus Status { get; } - void OnCompleted(System.Action continuation, object state, System.Threading.Tasks.ValueTaskSourceOnCompletedFlags flags); - void GetResult(); + System.Threading.Tasks.ValueTaskSourceStatus GetStatus(short token); + void OnCompleted(System.Action continuation, object state, short token, System.Threading.Tasks.ValueTaskSourceOnCompletedFlags flags); + void GetResult(short token); } public interface IValueTaskSource { - System.Threading.Tasks.ValueTaskSourceStatus Status { get; } - void OnCompleted(System.Action continuation, object state, System.Threading.Tasks.ValueTaskSourceOnCompletedFlags flags); - TResult GetResult(); + System.Threading.Tasks.ValueTaskSourceStatus GetStatus(short token); + void OnCompleted(System.Action continuation, object state, short token, System.Threading.Tasks.ValueTaskSourceOnCompletedFlags flags); + TResult GetResult(short token); } [System.Runtime.CompilerServices.AsyncMethodBuilderAttribute(typeof(System.Runtime.CompilerServices.AsyncValueTaskMethodBuilder))] public readonly partial struct ValueTask : System.IEquatable { internal readonly object _dummy; public ValueTask(System.Threading.Tasks.Task task) { throw null; } - public ValueTask(System.Threading.Tasks.IValueTaskSource source) { throw null; } + public ValueTask(System.Threading.Tasks.IValueTaskSource source, short token) { throw null; } public bool IsCanceled { get { throw null; } } public bool IsCompleted { get { throw null; } } public bool IsCompletedSuccessfully { get { throw null; } } @@ -133,7 +133,7 @@ public interface IValueTaskSource { internal readonly TResult _result; public ValueTask(System.Threading.Tasks.Task task) { throw null; } - public ValueTask(System.Threading.Tasks.IValueTaskSource source) { throw null; } + public ValueTask(System.Threading.Tasks.IValueTaskSource source, short token) { throw null; } public ValueTask(TResult result) { throw null; } public bool IsCanceled { get { throw null; } } public bool IsCompleted { get { throw null; } } diff --git a/src/System.Threading.Tasks.Extensions/tests/AsyncValueTaskMethodBuilderTests.cs b/src/System.Threading.Tasks.Extensions/tests/AsyncValueTaskMethodBuilderTests.cs index 096338523109..258404f92555 100644 --- a/src/System.Threading.Tasks.Extensions/tests/AsyncValueTaskMethodBuilderTests.cs +++ b/src/System.Threading.Tasks.Extensions/tests/AsyncValueTaskMethodBuilderTests.cs @@ -355,20 +355,20 @@ async Task TaskReturningMethod() await Task.FromResult(42); await new ValueTask(); await Assert.ThrowsAsync(async () => await new ValueTask(Task.FromException(new FormatException()))); - await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Completed(0, new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Completed(0, new FormatException()), 0)); Assert.Equal(42, await new ValueTask(42)); Assert.Equal(42, await new ValueTask(Task.FromResult(42))); - Assert.Equal(42, await new ValueTask(ManualResetValueTaskSource.Completed(42, null))); + Assert.Equal(42, await new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0)); await Assert.ThrowsAsync(async () => await new ValueTask(Task.FromException(new FormatException()))); - await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Completed(0, new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Completed(0, new FormatException()), 0)); // Incomplete await Assert.ThrowsAsync(async () => await new ValueTask(Task.Delay(1).ContinueWith(_ => throw new FormatException()))); - await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Delay(1, 0, new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Delay(1, 0, new FormatException()), 0)); Assert.Equal(42, await new ValueTask(Task.Delay(1).ContinueWith(_ => 42))); - Assert.Equal(42, await new ValueTask(ManualResetValueTaskSource.Delay(1, 42, null))); + Assert.Equal(42, await new ValueTask(ManualResetValueTaskSource.Delay(1, 42, null), 0)); await Assert.ThrowsAsync(async () => await new ValueTask(Task.Delay(1).ContinueWith(_ => throw new FormatException()))); - await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Delay(1, 0, new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Delay(1, 0, new FormatException()), 0)); await Task.Yield(); } } @@ -382,20 +382,20 @@ async Task TaskInt32ReturningMethod() await Task.FromResult(42); await new ValueTask(); await Assert.ThrowsAsync(async () => await new ValueTask(Task.FromException(new FormatException()))); - await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Completed(0, new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Completed(0, new FormatException()), 0)); Assert.Equal(42, await new ValueTask(42)); Assert.Equal(42, await new ValueTask(Task.FromResult(42))); - Assert.Equal(42, await new ValueTask(ManualResetValueTaskSource.Completed(42, null))); + Assert.Equal(42, await new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0)); await Assert.ThrowsAsync(async () => await new ValueTask(Task.FromException(new FormatException()))); - await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Completed(0, new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Completed(0, new FormatException()), 0)); // Incomplete await Assert.ThrowsAsync(async () => await new ValueTask(Task.Delay(1).ContinueWith(_ => throw new FormatException()))); - await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Delay(1, 0, new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Delay(1, 0, new FormatException()), 0)); Assert.Equal(42, await new ValueTask(Task.Delay(1).ContinueWith(_ => 42))); - Assert.Equal(42, await new ValueTask(ManualResetValueTaskSource.Delay(1, 42, null))); + Assert.Equal(42, await new ValueTask(ManualResetValueTaskSource.Delay(1, 42, null), 0)); await Assert.ThrowsAsync(async () => await new ValueTask(Task.Delay(1).ContinueWith(_ => throw new FormatException()))); - await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Delay(1, 0, new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Delay(1, 0, new FormatException()), 0)); await Task.Yield(); } return 17; @@ -410,20 +410,20 @@ async ValueTask ValueTaskReturningMethod() await Task.FromResult(42); await new ValueTask(); await Assert.ThrowsAsync(async () => await new ValueTask(Task.FromException(new FormatException()))); - await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Completed(0, new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Completed(0, new FormatException()), 0)); Assert.Equal(42, await new ValueTask(42)); Assert.Equal(42, await new ValueTask(Task.FromResult(42))); - Assert.Equal(42, await new ValueTask(ManualResetValueTaskSource.Completed(42, null))); + Assert.Equal(42, await new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0)); await Assert.ThrowsAsync(async () => await new ValueTask(Task.FromException(new FormatException()))); - await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Completed(0, new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Completed(0, new FormatException()), 0)); // Incomplete await Assert.ThrowsAsync(async () => await new ValueTask(Task.Delay(1).ContinueWith(_ => throw new FormatException()))); - await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Delay(1, 0, new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Delay(1, 0, new FormatException()), 0)); Assert.Equal(42, await new ValueTask(Task.Delay(1).ContinueWith(_ => 42))); - Assert.Equal(42, await new ValueTask(ManualResetValueTaskSource.Delay(1, 42, null))); + Assert.Equal(42, await new ValueTask(ManualResetValueTaskSource.Delay(1, 42, null), 0)); await Assert.ThrowsAsync(async () => await new ValueTask(Task.Delay(1).ContinueWith(_ => throw new FormatException()))); - await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Delay(1, 0, new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Delay(1, 0, new FormatException()), 0)); await Task.Yield(); } } @@ -437,20 +437,20 @@ async ValueTask ValueTaskInt32ReturningMethod() await Task.FromResult(42); await new ValueTask(); await Assert.ThrowsAsync(async () => await new ValueTask(Task.FromException(new FormatException()))); - await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Completed(0, new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Completed(0, new FormatException()), 0)); Assert.Equal(42, await new ValueTask(42)); Assert.Equal(42, await new ValueTask(Task.FromResult(42))); - Assert.Equal(42, await new ValueTask(ManualResetValueTaskSource.Completed(42, null))); + Assert.Equal(42, await new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0)); await Assert.ThrowsAsync(async () => await new ValueTask(Task.FromException(new FormatException()))); - await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Completed(0, new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Completed(0, new FormatException()), 0)); // Incomplete await Assert.ThrowsAsync(async () => await new ValueTask(Task.Delay(1).ContinueWith(_ => throw new FormatException()))); - await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Delay(1, 0, new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Delay(1, 0, new FormatException()), 0)); Assert.Equal(42, await new ValueTask(Task.Delay(1).ContinueWith(_ => 42))); - Assert.Equal(42, await new ValueTask(ManualResetValueTaskSource.Delay(1, 42, null))); + Assert.Equal(42, await new ValueTask(ManualResetValueTaskSource.Delay(1, 42, null), 0)); await Assert.ThrowsAsync(async () => await new ValueTask(Task.Delay(1).ContinueWith(_ => throw new FormatException()))); - await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Delay(1, 0, new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Delay(1, 0, new FormatException()), 0)); await Task.Yield(); } return 18; diff --git a/src/System.Threading.Tasks.Extensions/tests/ManualResetValueTaskSource.cs b/src/System.Threading.Tasks.Extensions/tests/ManualResetValueTaskSource.cs index e49be1918ba5..446a42f1752d 100644 --- a/src/System.Threading.Tasks.Extensions/tests/ManualResetValueTaskSource.cs +++ b/src/System.Threading.Tasks.Extensions/tests/ManualResetValueTaskSource.cs @@ -51,13 +51,13 @@ internal sealed class ManualResetValueTaskSource : IValueTaskSource, IValu private T _result; private ExceptionDispatchInfo _error; - public ValueTaskSourceStatus Status => + public ValueTaskSourceStatus GetStatus(short token) => !_completed ? ValueTaskSourceStatus.Pending : _error == null ? ValueTaskSourceStatus.Succeeded : _error.SourceException is OperationCanceledException ? ValueTaskSourceStatus.Canceled : ValueTaskSourceStatus.Faulted; - public T GetResult() + public T GetResult(short token) { if (!_completed) { @@ -75,9 +75,9 @@ public T GetResult() return _result; } - void IValueTaskSource.GetResult() + void IValueTaskSource.GetResult(short token) { - GetResult(); + GetResult(token); } public void Reset() @@ -89,7 +89,7 @@ public void Reset() _error = null; } - public void OnCompleted(Action continuation, object state, ValueTaskSourceOnCompletedFlags flags) + public void OnCompleted(Action continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags) { if ((flags & ValueTaskSourceOnCompletedFlags.FlowExecutionContext) != 0) { diff --git a/src/System.Threading.Tasks.Extensions/tests/ValueTaskTests.cs b/src/System.Threading.Tasks.Extensions/tests/ValueTaskTests.cs index 82fdfa09be87..0e870d486916 100644 --- a/src/System.Threading.Tasks.Extensions/tests/ValueTaskTests.cs +++ b/src/System.Threading.Tasks.Extensions/tests/ValueTaskTests.cs @@ -52,7 +52,7 @@ public void NonGeneric_CreateFromSuccessfullyCompleted_IsCompletedSuccessfully(C ValueTask t = mode == CtorMode.Result ? default : mode == CtorMode.Task ? new ValueTask(Task.CompletedTask) : - new ValueTask(ManualResetValueTaskSource.Completed(0, null)); + new ValueTask(ManualResetValueTaskSource.Completed(0, null), 0); Assert.True(t.IsCompleted); Assert.True(t.IsCompletedSuccessfully); Assert.False(t.IsFaulted); @@ -68,7 +68,7 @@ public void Generic_CreateFromSuccessfullyCompleted_IsCompletedSuccessfully(Ctor ValueTask t = mode == CtorMode.Result ? new ValueTask(42) : mode == CtorMode.Task ? new ValueTask(Task.FromResult(42)) : - new ValueTask(ManualResetValueTaskSource.Completed(42, null)); + new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0); Assert.True(t.IsCompleted); Assert.True(t.IsCompletedSuccessfully); Assert.False(t.IsFaulted); @@ -93,7 +93,7 @@ public void NonGeneric_CreateFromNotCompleted_ThenCompleteSuccessfully(CtorMode case CtorMode.ValueTaskSource: var mre = new ManualResetValueTaskSource(); - t = new ValueTask(mre); + t = new ValueTask(mre, 0); completer = mre; break; } @@ -137,7 +137,7 @@ public void Generic_CreateFromNotCompleted_ThenCompleteSuccessfully(CtorMode mod case CtorMode.ValueTaskSource: var mre = new ManualResetValueTaskSource(); - t = new ValueTask(mre); + t = new ValueTask(mre, 0); completer = mre; break; } @@ -182,7 +182,7 @@ public void NonGeneric_CreateFromNotCompleted_ThenFault(CtorMode mode) case CtorMode.ValueTaskSource: var mre = new ManualResetValueTaskSource(); - t = new ValueTask(mre); + t = new ValueTask(mre, 0); completer = mre; break; } @@ -230,7 +230,7 @@ public void Generic_CreateFromNotCompleted_ThenFault(CtorMode mode) case CtorMode.ValueTaskSource: var mre = new ManualResetValueTaskSource(); - t = new ValueTask(mre); + t = new ValueTask(mre, 0); completer = mre; break; } @@ -268,7 +268,7 @@ public void Generic_CreateFromNotCompleted_ThenFault(CtorMode mode) public void NonGeneric_CreateFromFaulted_IsFaulted(CtorMode mode) { InvalidOperationException e = new InvalidOperationException(); - ValueTask t = mode == CtorMode.Task ? new ValueTask(Task.FromException(e)) : new ValueTask(ManualResetValueTaskSource.Completed(0, e)); + ValueTask t = mode == CtorMode.Task ? new ValueTask(Task.FromException(e)) : new ValueTask(ManualResetValueTaskSource.Completed(0, e), 0); Assert.True(t.IsCompleted); Assert.False(t.IsCompletedSuccessfully); @@ -284,7 +284,7 @@ public void NonGeneric_CreateFromFaulted_IsFaulted(CtorMode mode) public void Generic_CreateFromFaulted_IsFaulted(CtorMode mode) { InvalidOperationException e = new InvalidOperationException(); - ValueTask t = mode == CtorMode.Task ? new ValueTask(Task.FromException(e)) : new ValueTask(ManualResetValueTaskSource.Completed(0, e)); + ValueTask t = mode == CtorMode.Task ? new ValueTask(Task.FromException(e)) : new ValueTask(ManualResetValueTaskSource.Completed(0, e), 0); Assert.True(t.IsCompleted); Assert.False(t.IsCompletedSuccessfully); @@ -299,7 +299,7 @@ public void Generic_CreateFromFaulted_IsFaulted(CtorMode mode) public void NonGeneric_CreateFromNullTask_Throws() { AssertExtensions.Throws("task", () => new ValueTask((Task)null)); - AssertExtensions.Throws("source", () => new ValueTask((IValueTaskSource)null)); + AssertExtensions.Throws("source", () => new ValueTask((IValueTaskSource)null, 0)); } [Fact] @@ -308,8 +308,8 @@ public void Generic_CreateFromNullTask_Throws() AssertExtensions.Throws("task", () => new ValueTask((Task)null)); AssertExtensions.Throws("task", () => new ValueTask((Task)null)); - AssertExtensions.Throws("source", () => new ValueTask((IValueTaskSource)null)); - AssertExtensions.Throws("source", () => new ValueTask((IValueTaskSource)null)); + AssertExtensions.Throws("source", () => new ValueTask((IValueTaskSource)null, 0)); + AssertExtensions.Throws("source", () => new ValueTask((IValueTaskSource)null, 0)); } [Fact] @@ -348,7 +348,7 @@ public void Generic_CreateFromValue_AsTaskNotIdempotent() [Fact] public void NonGeneric_CreateFromValueTaskSource_AsTaskIdempotent() // validates unsupported behavior specific to the backing IValueTaskSource { - ValueTask vt = new ValueTask(ManualResetValueTaskSource.Completed(42, null)); + ValueTask vt = new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0); Task t = vt.AsTask(); Assert.NotNull(t); Assert.Same(t, vt.AsTask()); @@ -358,7 +358,7 @@ public void NonGeneric_CreateFromValueTaskSource_AsTaskIdempotent() // validates [Fact] public void Generic_CreateFromValueTaskSource_AsTaskNotIdempotent() // validates unsupported behavior specific to the backing IValueTaskSource { - ValueTask t = new ValueTask(ManualResetValueTaskSource.Completed(42, null)); + ValueTask t = new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0); Assert.NotSame(Task.FromResult(42), t.AsTask()); Assert.NotSame(t.AsTask(), t.AsTask()); } @@ -368,7 +368,7 @@ public void Generic_CreateFromValueTaskSource_AsTaskNotIdempotent() // validates [InlineData(true)] public async Task NonGeneric_CreateFromValueTaskSource_Success(bool sync) { - ValueTask vt = new ValueTask(sync ? ManualResetValueTaskSource.Completed(0) : ManualResetValueTaskSource.Delay(1, 0)); + ValueTask vt = new ValueTask(sync ? ManualResetValueTaskSource.Completed(0) : ManualResetValueTaskSource.Delay(1, 0), 0); Task t = vt.AsTask(); if (sync) { @@ -382,7 +382,7 @@ public async Task NonGeneric_CreateFromValueTaskSource_Success(bool sync) [InlineData(true)] public async Task Generic_CreateFromValueTaskSource_Success(bool sync) { - ValueTask vt = new ValueTask(sync ? ManualResetValueTaskSource.Completed(42) : ManualResetValueTaskSource.Delay(1, 42)); + ValueTask vt = new ValueTask(sync ? ManualResetValueTaskSource.Completed(42) : ManualResetValueTaskSource.Delay(1, 42), 0); Task t = vt.AsTask(); if (sync) { @@ -396,7 +396,7 @@ public async Task Generic_CreateFromValueTaskSource_Success(bool sync) [InlineData(true)] public async Task NonGeneric_CreateFromValueTaskSource_Faulted(bool sync) { - ValueTask vt = new ValueTask(sync ? ManualResetValueTaskSource.Completed(0, new FormatException()) : ManualResetValueTaskSource.Delay(1, 0, new FormatException())); + ValueTask vt = new ValueTask(sync ? ManualResetValueTaskSource.Completed(0, new FormatException()) : ManualResetValueTaskSource.Delay(1, 0, new FormatException()), 0); Task t = vt.AsTask(); if (sync) { @@ -414,7 +414,7 @@ public async Task NonGeneric_CreateFromValueTaskSource_Faulted(bool sync) [InlineData(true)] public async Task Generic_CreateFromValueTaskSource_Faulted(bool sync) { - ValueTask vt = new ValueTask(sync ? ManualResetValueTaskSource.Completed(0, new FormatException()) : ManualResetValueTaskSource.Delay(1, 0, new FormatException())); + ValueTask vt = new ValueTask(sync ? ManualResetValueTaskSource.Completed(0, new FormatException()) : ManualResetValueTaskSource.Delay(1, 0, new FormatException()), 0); Task t = vt.AsTask(); if (sync) { @@ -432,7 +432,7 @@ public async Task Generic_CreateFromValueTaskSource_Faulted(bool sync) [InlineData(true)] public async Task NonGeneric_CreateFromValueTaskSource_Canceled(bool sync) { - ValueTask vt = new ValueTask(sync ? ManualResetValueTaskSource.Completed(0, new OperationCanceledException()) : ManualResetValueTaskSource.Delay(1, 0, new OperationCanceledException())); + ValueTask vt = new ValueTask(sync ? ManualResetValueTaskSource.Completed(0, new OperationCanceledException()) : ManualResetValueTaskSource.Delay(1, 0, new OperationCanceledException()), 0); Task t = vt.AsTask(); if (sync) { @@ -450,7 +450,7 @@ public async Task NonGeneric_CreateFromValueTaskSource_Canceled(bool sync) [InlineData(true)] public async Task Generic_CreateFromValueTaskSource_Canceled(bool sync) { - ValueTask vt = new ValueTask(sync ? ManualResetValueTaskSource.Completed(0, new OperationCanceledException()) : ManualResetValueTaskSource.Delay(1, 0, new OperationCanceledException())); + ValueTask vt = new ValueTask(sync ? ManualResetValueTaskSource.Completed(0, new OperationCanceledException()) : ManualResetValueTaskSource.Delay(1, 0, new OperationCanceledException()), 0); Task t = vt.AsTask(); if (sync) { @@ -482,7 +482,7 @@ public void NonGeneric_Preserve_FromTask_EqualityMaintained() [Fact] public void NonGeneric_Preserve_FromValueTaskSource_TransitionedToTask() { - ValueTask vt1 = new ValueTask(ManualResetValueTaskSource.Completed(42)); + ValueTask vt1 = new ValueTask(ManualResetValueTaskSource.Completed(42), 0); ValueTask vt2 = vt1.Preserve(); ValueTask vt3 = vt2.Preserve(); Assert.True(vt1 != vt2); @@ -509,7 +509,7 @@ public void Generic_Preserve_FromTask_EqualityMaintained() [Fact] public void Generic_Preserve_FromValueTaskSource_TransitionedToTask() { - ValueTask vt1 = new ValueTask(ManualResetValueTaskSource.Completed(42)); + ValueTask vt1 = new ValueTask(ManualResetValueTaskSource.Completed(42), 0); ValueTask vt2 = vt1.Preserve(); ValueTask vt3 = vt2.Preserve(); Assert.True(vt1 != vt2); @@ -526,7 +526,7 @@ public async Task NonGeneric_CreateFromCompleted_Await(CtorMode mode) ValueTask Create() => mode == CtorMode.Result ? new ValueTask() : mode == CtorMode.Task ? new ValueTask(Task.FromResult(42)) : - new ValueTask(ManualResetValueTaskSource.Completed(0, null)); + new ValueTask(ManualResetValueTaskSource.Completed(0, null), 0); int thread = Environment.CurrentManagedThreadId; @@ -549,7 +549,7 @@ public async Task Generic_CreateFromCompleted_Await(CtorMode mode) ValueTask Create() => mode == CtorMode.Result ? new ValueTask(42) : mode == CtorMode.Task ? new ValueTask(Task.FromResult(42)) : - new ValueTask(ManualResetValueTaskSource.Completed(42, null)); + new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0); int thread = Environment.CurrentManagedThreadId; @@ -598,7 +598,7 @@ public async Task Generic_CreateFromTask_Await_Normal(bool? continueOnCapturedCo public async Task CreateFromValueTaskSource_Await_Normal(bool? continueOnCapturedContext) { var mre = new ManualResetValueTaskSource(); - ValueTask t = new ValueTask(mre); + ValueTask t = new ValueTask(mre, 0); var ignored = Task.Delay(1).ContinueWith(_ => mre.SetResult(42)); switch (continueOnCapturedContext) { @@ -614,7 +614,7 @@ public async Task CreateFromValueTaskSource_Await_Normal(bool? continueOnCapture public async Task Generic_CreateFromValueTaskSource_Await_Normal(bool? continueOnCapturedContext) { var mre = new ManualResetValueTaskSource(); - ValueTask t = new ValueTask(mre); + ValueTask t = new ValueTask(mre, 0); var ignored = Task.Delay(1).ContinueWith(_ => mre.SetResult(42)); switch (continueOnCapturedContext) { @@ -632,7 +632,7 @@ public async Task NonGeneric_Awaiter_OnCompleted(CtorMode mode) ValueTask t = mode == CtorMode.Result ? new ValueTask() : mode == CtorMode.Task ? new ValueTask(Task.CompletedTask) : - new ValueTask(ManualResetValueTaskSource.Completed(0, null)); + new ValueTask(ManualResetValueTaskSource.Completed(0, null), 0); var tcs = new TaskCompletionSource(); t.GetAwaiter().OnCompleted(() => tcs.SetResult(true)); @@ -648,7 +648,7 @@ public async Task NonGeneric_Awaiter_UnsafeOnCompleted(CtorMode mode) ValueTask t = mode == CtorMode.Result ? new ValueTask() : mode == CtorMode.Task ? new ValueTask(Task.CompletedTask) : - new ValueTask(ManualResetValueTaskSource.Completed(0, null)); + new ValueTask(ManualResetValueTaskSource.Completed(0, null), 0); var tcs = new TaskCompletionSource(); t.GetAwaiter().UnsafeOnCompleted(() => tcs.SetResult(true)); @@ -664,7 +664,7 @@ public async Task Generic_Awaiter_OnCompleted(CtorMode mode) ValueTask t = mode == CtorMode.Result ? new ValueTask(42) : mode == CtorMode.Task ? new ValueTask(Task.FromResult(42)) : - new ValueTask(ManualResetValueTaskSource.Completed(42, null)); + new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0); var tcs = new TaskCompletionSource(); t.GetAwaiter().OnCompleted(() => tcs.SetResult(true)); @@ -680,7 +680,7 @@ public async Task Generic_Awaiter_UnsafeOnCompleted(CtorMode mode) ValueTask t = mode == CtorMode.Result ? new ValueTask(42) : mode == CtorMode.Task ? new ValueTask(Task.FromResult(42)) : - new ValueTask(ManualResetValueTaskSource.Completed(42, null)); + new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0); var tcs = new TaskCompletionSource(); t.GetAwaiter().UnsafeOnCompleted(() => tcs.SetResult(true)); @@ -699,7 +699,7 @@ public async Task NonGeneric_ConfiguredAwaiter_OnCompleted(CtorMode mode, bool c ValueTask t = mode == CtorMode.Result ? new ValueTask() : mode == CtorMode.Task ? new ValueTask(Task.CompletedTask) : - new ValueTask(ManualResetValueTaskSource.Completed(0, null)); + new ValueTask(ManualResetValueTaskSource.Completed(0, null), 0); var tcs = new TaskCompletionSource(); t.ConfigureAwait(continueOnCapturedContext).GetAwaiter().OnCompleted(() => tcs.SetResult(true)); @@ -718,7 +718,7 @@ public async Task NonGeneric_ConfiguredAwaiter_UnsafeOnCompleted(CtorMode mode, ValueTask t = mode == CtorMode.Result ? new ValueTask() : mode == CtorMode.Task ? new ValueTask(Task.CompletedTask) : - new ValueTask(ManualResetValueTaskSource.Completed(0, null)); + new ValueTask(ManualResetValueTaskSource.Completed(0, null), 0); var tcs = new TaskCompletionSource(); t.ConfigureAwait(continueOnCapturedContext).GetAwaiter().UnsafeOnCompleted(() => tcs.SetResult(true)); @@ -737,7 +737,7 @@ public async Task Generic_ConfiguredAwaiter_OnCompleted(CtorMode mode, bool cont ValueTask t = mode == CtorMode.Result ? new ValueTask(42) : mode == CtorMode.Task ? new ValueTask(Task.FromResult(42)) : - new ValueTask(ManualResetValueTaskSource.Completed(42, null)); + new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0); var tcs = new TaskCompletionSource(); t.ConfigureAwait(continueOnCapturedContext).GetAwaiter().OnCompleted(() => tcs.SetResult(true)); @@ -756,7 +756,7 @@ public async Task Generic_ConfiguredAwaiter_UnsafeOnCompleted(CtorMode mode, boo ValueTask t = mode == CtorMode.Result ? new ValueTask(42) : mode == CtorMode.Task ? new ValueTask(Task.FromResult(42)) : - new ValueTask(ManualResetValueTaskSource.Completed(42, null)); + new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0); var tcs = new TaskCompletionSource(); t.ConfigureAwait(continueOnCapturedContext).GetAwaiter().UnsafeOnCompleted(() => tcs.SetResult(true)); @@ -778,7 +778,7 @@ await Task.Run(() => ValueTask t = mode == CtorMode.Result ? new ValueTask() : mode == CtorMode.Task ? new ValueTask(Task.CompletedTask) : - new ValueTask(ManualResetValueTaskSource.Completed(0, null)); + new ValueTask(ManualResetValueTaskSource.Completed(0, null), 0); var mres = new ManualResetEventSlim(); t.GetAwaiter().OnCompleted(() => mres.Set()); @@ -809,7 +809,7 @@ await Task.Run(() => ValueTask t = mode == CtorMode.Result ? new ValueTask(42) : mode == CtorMode.Task ? new ValueTask(sync ? Task.FromResult(42) : Task.Delay(1).ContinueWith(_ => 42)) : - new ValueTask(sync ? ManualResetValueTaskSource.Completed(42, null) : ManualResetValueTaskSource.Delay(1, 42, null)); + new ValueTask(sync ? ManualResetValueTaskSource.Completed(42, null) : ManualResetValueTaskSource.Delay(1, 42, null), 0); var mres = new ManualResetEventSlim(); t.GetAwaiter().OnCompleted(() => mres.Set()); @@ -845,7 +845,7 @@ await Task.Run(() => ValueTask t = mode == CtorMode.Result ? new ValueTask() : mode == CtorMode.Task ? new ValueTask(sync ? Task.CompletedTask : Task.Delay(1)) : - new ValueTask(sync ? ManualResetValueTaskSource.Completed(0, null) : ManualResetValueTaskSource.Delay(42, 0, null)); + new ValueTask(sync ? ManualResetValueTaskSource.Completed(0, null) : ManualResetValueTaskSource.Delay(42, 0, null), 0); var mres = new ManualResetEventSlim(); t.ConfigureAwait(continueOnCapturedContext).GetAwaiter().OnCompleted(() => mres.Set()); @@ -881,7 +881,7 @@ await Task.Run(() => ValueTask t = mode == CtorMode.Result ? new ValueTask(42) : mode == CtorMode.Task ? new ValueTask(sync ? Task.FromResult(42) : Task.Delay(1).ContinueWith(_ => 42)) : - new ValueTask(sync ? ManualResetValueTaskSource.Completed(42, null) : ManualResetValueTaskSource.Delay(1, 42, null)); + new ValueTask(sync ? ManualResetValueTaskSource.Completed(42, null) : ManualResetValueTaskSource.Delay(1, 42, null), 0); var mres = new ManualResetEventSlim(); t.ConfigureAwait(continueOnCapturedContext).GetAwaiter().OnCompleted(() => mres.Set()); @@ -929,7 +929,7 @@ public void NonGeneric_GetHashCode_FromObject_MatchesObjectHashCode(CtorMode mod else { var t = ManualResetValueTaskSource.Completed(42, null); - vt = new ValueTask(t); + vt = new ValueTask(t, 0); obj = t; } @@ -952,7 +952,7 @@ public void Generic_GetHashCode_FromObject_MatchesObjectHashCode(CtorMode mode) else { ManualResetValueTaskSource t = ManualResetValueTaskSource.Completed(42, null); - vt = new ValueTask(t); + vt = new ValueTask(t, 0); obj = t; } @@ -970,11 +970,12 @@ public void NonGeneric_OperatorEquals() Assert.True(new ValueTask() == new ValueTask()); Assert.True(new ValueTask(Task.CompletedTask) == new ValueTask(Task.CompletedTask)); Assert.True(new ValueTask(completedTcs.Task) == new ValueTask(completedTcs.Task)); - Assert.True(new ValueTask(completedVts) == new ValueTask(completedVts)); + Assert.True(new ValueTask(completedVts, 0) == new ValueTask(completedVts, 0)); Assert.False(new ValueTask(Task.CompletedTask) == new ValueTask(completedTcs.Task)); - Assert.False(new ValueTask(Task.CompletedTask) == new ValueTask(completedVts)); - Assert.False(new ValueTask(completedTcs.Task) == new ValueTask(completedVts)); + Assert.False(new ValueTask(Task.CompletedTask) == new ValueTask(completedVts, 0)); + Assert.False(new ValueTask(completedTcs.Task) == new ValueTask(completedVts, 0)); + Assert.False(new ValueTask(completedVts, 17) == new ValueTask(completedVts, 18)); } [Fact] @@ -985,7 +986,7 @@ public void Generic_OperatorEquals() Assert.True(new ValueTask(42) == new ValueTask(42)); Assert.True(new ValueTask(completedTask) == new ValueTask(completedTask)); - Assert.True(new ValueTask(completedVts) == new ValueTask(completedVts)); + Assert.True(new ValueTask(completedVts, 17) == new ValueTask(completedVts, 17)); Assert.True(new ValueTask("42") == new ValueTask("42")); Assert.True(new ValueTask((string)null) == new ValueTask((string)null)); @@ -996,8 +997,9 @@ public void Generic_OperatorEquals() Assert.False(new ValueTask(42) == new ValueTask(Task.FromResult(42))); Assert.False(new ValueTask(Task.FromResult(42)) == new ValueTask(42)); - Assert.False(new ValueTask(ManualResetValueTaskSource.Completed(42, null)) == new ValueTask(42)); - Assert.False(new ValueTask(completedTask) == new ValueTask(completedVts)); + Assert.False(new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0) == new ValueTask(42)); + Assert.False(new ValueTask(completedTask) == new ValueTask(completedVts, 0)); + Assert.False(new ValueTask(completedVts, 17) == new ValueTask(completedVts, 18)); } [Fact] @@ -1011,11 +1013,12 @@ public void NonGeneric_OperatorNotEquals() Assert.False(new ValueTask() != new ValueTask()); Assert.False(new ValueTask(Task.CompletedTask) != new ValueTask(Task.CompletedTask)); Assert.False(new ValueTask(completedTcs.Task) != new ValueTask(completedTcs.Task)); - Assert.False(new ValueTask(completedVts) != new ValueTask(completedVts)); + Assert.False(new ValueTask(completedVts, 0) != new ValueTask(completedVts, 0)); Assert.True(new ValueTask(Task.CompletedTask) != new ValueTask(completedTcs.Task)); - Assert.True(new ValueTask(Task.CompletedTask) != new ValueTask(completedVts)); - Assert.True(new ValueTask(completedTcs.Task) != new ValueTask(completedVts)); + Assert.True(new ValueTask(Task.CompletedTask) != new ValueTask(completedVts, 0)); + Assert.True(new ValueTask(completedTcs.Task) != new ValueTask(completedVts, 0)); + Assert.True(new ValueTask(completedVts, 17) != new ValueTask(completedVts, 18)); } [Fact] @@ -1026,7 +1029,7 @@ public void Generic_OperatorNotEquals() Assert.False(new ValueTask(42) != new ValueTask(42)); Assert.False(new ValueTask(completedTask) != new ValueTask(completedTask)); - Assert.False(new ValueTask(completedVts) != new ValueTask(completedVts)); + Assert.False(new ValueTask(completedVts, 0) != new ValueTask(completedVts, 0)); Assert.False(new ValueTask("42") != new ValueTask("42")); Assert.False(new ValueTask((string)null) != new ValueTask((string)null)); @@ -1037,8 +1040,9 @@ public void Generic_OperatorNotEquals() Assert.True(new ValueTask(42) != new ValueTask(Task.FromResult(42))); Assert.True(new ValueTask(Task.FromResult(42)) != new ValueTask(42)); - Assert.True(new ValueTask(ManualResetValueTaskSource.Completed(42, null)) != new ValueTask(42)); - Assert.True(new ValueTask(completedTask) != new ValueTask(completedVts)); + Assert.True(new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0) != new ValueTask(42)); + Assert.True(new ValueTask(completedTask) != new ValueTask(completedVts, 0)); + Assert.True(new ValueTask(completedVts, 17) != new ValueTask(completedVts, 18)); } [Fact] @@ -1048,10 +1052,10 @@ public void NonGeneric_Equals_ValueTask() Assert.False(new ValueTask().Equals(new ValueTask(Task.CompletedTask))); Assert.False(new ValueTask(Task.CompletedTask).Equals(new ValueTask())); - Assert.False(new ValueTask(ManualResetValueTaskSource.Completed(42, null)).Equals(new ValueTask())); - Assert.False(new ValueTask().Equals(new ValueTask(ManualResetValueTaskSource.Completed(42, null)))); - Assert.False(new ValueTask(Task.CompletedTask).Equals(new ValueTask(ManualResetValueTaskSource.Completed(42, null)))); - Assert.False(new ValueTask(ManualResetValueTaskSource.Completed(42, null)).Equals(new ValueTask(Task.CompletedTask))); + Assert.False(new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0).Equals(new ValueTask())); + Assert.False(new ValueTask().Equals(new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0))); + Assert.False(new ValueTask(Task.CompletedTask).Equals(new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0))); + Assert.False(new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0).Equals(new ValueTask(Task.CompletedTask))); } [Fact] @@ -1068,7 +1072,7 @@ public void Generic_Equals_ValueTask() Assert.False(new ValueTask(42).Equals(new ValueTask(Task.FromResult(42)))); Assert.False(new ValueTask(Task.FromResult(42)).Equals(new ValueTask(42))); - Assert.False(new ValueTask(ManualResetValueTaskSource.Completed(42, null)).Equals(new ValueTask(42))); + Assert.False(new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0).Equals(new ValueTask(42))); } [Fact] @@ -1078,15 +1082,15 @@ public void NonGeneric_Equals_Object() Assert.False(new ValueTask().Equals((object)new ValueTask(Task.CompletedTask))); Assert.False(new ValueTask(Task.CompletedTask).Equals((object)new ValueTask())); - Assert.False(new ValueTask(ManualResetValueTaskSource.Completed(42, null)).Equals((object)new ValueTask())); - Assert.False(new ValueTask().Equals((object)new ValueTask(ManualResetValueTaskSource.Completed(42, null)))); - Assert.False(new ValueTask(Task.CompletedTask).Equals((object)new ValueTask(ManualResetValueTaskSource.Completed(42, null)))); - Assert.False(new ValueTask(ManualResetValueTaskSource.Completed(42, null)).Equals((object)new ValueTask(Task.CompletedTask))); + Assert.False(new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0).Equals((object)new ValueTask())); + Assert.False(new ValueTask().Equals((object)new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0))); + Assert.False(new ValueTask(Task.CompletedTask).Equals((object)new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0))); + Assert.False(new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0).Equals((object)new ValueTask(Task.CompletedTask))); Assert.False(new ValueTask().Equals(null)); Assert.False(new ValueTask().Equals("12345")); Assert.False(new ValueTask(Task.CompletedTask).Equals("12345")); - Assert.False(new ValueTask(ManualResetValueTaskSource.Completed(42, null)).Equals("12345")); + Assert.False(new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0).Equals("12345")); } [Fact] @@ -1103,7 +1107,7 @@ public void Generic_Equals_Object() Assert.False(new ValueTask(42).Equals((object)new ValueTask(Task.FromResult(42)))); Assert.False(new ValueTask(Task.FromResult(42)).Equals((object)new ValueTask(42))); - Assert.False(new ValueTask(ManualResetValueTaskSource.Completed(42, null)).Equals((object)new ValueTask(42))); + Assert.False(new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0).Equals((object)new ValueTask(42))); Assert.False(new ValueTask(42).Equals((object)null)); Assert.False(new ValueTask(42).Equals(new object())); @@ -1115,7 +1119,7 @@ public void NonGeneric_ToString_Success() { Assert.Equal("System.Threading.Tasks.ValueTask", new ValueTask().ToString()); Assert.Equal("System.Threading.Tasks.ValueTask", new ValueTask(Task.CompletedTask).ToString()); - Assert.Equal("System.Threading.Tasks.ValueTask", new ValueTask(ManualResetValueTaskSource.Completed(42, null)).ToString()); + Assert.Equal("System.Threading.Tasks.ValueTask", new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0).ToString()); } [Fact] @@ -1123,19 +1127,19 @@ public void Generic_ToString_Success() { Assert.Equal("Hello", new ValueTask("Hello").ToString()); Assert.Equal("Hello", new ValueTask(Task.FromResult("Hello")).ToString()); - Assert.Equal("Hello", new ValueTask(ManualResetValueTaskSource.Completed("Hello", null)).ToString()); + Assert.Equal("Hello", new ValueTask(ManualResetValueTaskSource.Completed("Hello", null), 0).ToString()); Assert.Equal("42", new ValueTask(42).ToString()); Assert.Equal("42", new ValueTask(Task.FromResult(42)).ToString()); - Assert.Equal("42", new ValueTask(ManualResetValueTaskSource.Completed(42, null)).ToString()); + Assert.Equal("42", new ValueTask(ManualResetValueTaskSource.Completed(42, null), 0).ToString()); Assert.Same(string.Empty, new ValueTask(string.Empty).ToString()); Assert.Same(string.Empty, new ValueTask(Task.FromResult(string.Empty)).ToString()); - Assert.Same(string.Empty, new ValueTask(ManualResetValueTaskSource.Completed(string.Empty, null)).ToString()); + Assert.Same(string.Empty, new ValueTask(ManualResetValueTaskSource.Completed(string.Empty, null), 0).ToString()); Assert.Same(string.Empty, new ValueTask(Task.FromException(new InvalidOperationException())).ToString()); Assert.Same(string.Empty, new ValueTask(Task.FromException(new OperationCanceledException())).ToString()); - Assert.Same(string.Empty, new ValueTask(ManualResetValueTaskSource.Completed(null, new InvalidOperationException())).ToString()); + Assert.Same(string.Empty, new ValueTask(ManualResetValueTaskSource.Completed(null, new InvalidOperationException()), 0).ToString()); Assert.Same(string.Empty, new ValueTask(Task.FromCanceled(new CancellationToken(true))).ToString()); @@ -1143,7 +1147,7 @@ public void Generic_ToString_Success() Assert.Same(string.Empty, default(ValueTask).ToString()); Assert.Same(string.Empty, new ValueTask((string)null).ToString()); Assert.Same(string.Empty, new ValueTask(Task.FromResult(null)).ToString()); - Assert.Same(string.Empty, new ValueTask(ManualResetValueTaskSource.Completed(null, null)).ToString()); + Assert.Same(string.Empty, new ValueTask(ManualResetValueTaskSource.Completed(null, null), 0).ToString()); Assert.Same(string.Empty, new ValueTask(new TaskCompletionSource().Task).ToString()); } @@ -1179,13 +1183,13 @@ public void NonGeneric_AsTask_ValueTaskSourcePassesInvalidStateToOnCompleted_Thr { void Validate(IValueTaskSource vts) { - var vt = new ValueTask(vts); + var vt = new ValueTask(vts, 0); Assert.Throws(() => { vt.AsTask(); }); } - Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, flags) => continuation(null) }); - Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, flags) => continuation(new object()) }); - Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, flags) => { continuation(state); continuation(state); } }); + Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, token, flags) => continuation(null) }); + Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, token, flags) => continuation(new object()) }); + Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, token, flags) => { continuation(state); continuation(state); } }); } [Fact] @@ -1193,13 +1197,13 @@ public void Generic_AsTask_ValueTaskSourcePassesInvalidStateToOnCompleted_Throws { void Validate(IValueTaskSource vts) { - var vt = new ValueTask(vts); + var vt = new ValueTask(vts, 0); Assert.Throws(() => { vt.AsTask(); }); } - Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, flags) => continuation(null) }); - Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, flags) => continuation(new object()) }); - Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, flags) => { continuation(state); continuation(state); } }); + Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, token, flags) => continuation(null) }); + Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, token, flags) => continuation(new object()) }); + Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, token, flags) => { continuation(state); continuation(state); } }); } [Fact] @@ -1207,7 +1211,7 @@ public void NonGeneric_OnCompleted_ValueTaskSourcePassesInvalidStateToOnComplete { void Validate(IValueTaskSource vts) { - var vt = new ValueTask(vts); + var vt = new ValueTask(vts, 0); Assert.Throws(() => vt.GetAwaiter().OnCompleted(() => { })); Assert.Throws(() => vt.GetAwaiter().UnsafeOnCompleted(() => { })); foreach (bool continueOnCapturedContext in new[] { true, false }) @@ -1217,8 +1221,8 @@ void Validate(IValueTaskSource vts) } } - Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, flags) => continuation(null) }); - Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, flags) => continuation(new object()) }); + Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, token, flags) => continuation(null) }); + Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, token, flags) => continuation(new object()) }); } [Fact] @@ -1226,7 +1230,7 @@ public void Generic_OnCompleted_ValueTaskSourcePassesInvalidStateToOnCompleted_T { void Validate(IValueTaskSource vts) { - var vt = new ValueTask(vts); + var vt = new ValueTask(vts, 0); Assert.Throws(() => vt.GetAwaiter().OnCompleted(() => { })); Assert.Throws(() => vt.GetAwaiter().UnsafeOnCompleted(() => { })); foreach (bool continueOnCapturedContext in new[] { true, false }) @@ -1236,24 +1240,24 @@ void Validate(IValueTaskSource vts) } } - Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, flags) => continuation(null) }); - Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, flags) => continuation(new object()) }); + Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, token, flags) => continuation(null) }); + Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, token, flags) => continuation(new object()) }); } private sealed class DelegateValueTaskSource : IValueTaskSource, IValueTaskSource { - public Func StatusFunc = null; - public Action GetResultAction = null; - public Func GetResultFunc = null; - public Action, object, ValueTaskSourceOnCompletedFlags> OnCompletedFunc; + public Func GetStatusFunc = null; + public Action GetResultAction = null; + public Func GetResultFunc = null; + public Action, object, short, ValueTaskSourceOnCompletedFlags> OnCompletedFunc; - public ValueTaskSourceStatus Status => StatusFunc?.Invoke() ?? ValueTaskSourceStatus.Pending; + public ValueTaskSourceStatus GetStatus(short token) => GetStatusFunc?.Invoke(token) ?? ValueTaskSourceStatus.Pending; - public void GetResult() => GetResultAction?.Invoke(); - T IValueTaskSource.GetResult() => GetResultFunc != null ? GetResultFunc() : default; + public void GetResult(short token) => GetResultAction?.Invoke(token); + T IValueTaskSource.GetResult(short token) => GetResultFunc != null ? GetResultFunc(token) : default; - public void OnCompleted(Action continuation, object state, ValueTaskSourceOnCompletedFlags flags) => - OnCompletedFunc?.Invoke(continuation, state, flags); + public void OnCompleted(Action continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags) => + OnCompletedFunc?.Invoke(continuation, state, token, flags); } private sealed class TrackingSynchronizationContext : SynchronizationContext