From b52bb273af944af68f41d052a764fa939c1115fe Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Wed, 28 Feb 2018 01:21:02 -0500 Subject: [PATCH] Expose and roll out ValueTask extensibility This commit does several things: - Exposes the new `ValueTask` extensibility model being added in coreclr. The ValueTask-related files will separately be mirrored over to corefx to enable the netstandard build of System.Threading.Tasks.Extensions. - Adapts all `Stream`-derived types to return `ValueTask` instead of `Task` from `WriteAsync`. - Changes the new `WebSocket` `SendAsync` method to return `ValueTask` instead of `Task`, and updates the `ManagedWebSocket` implementation accordingly. Most `SendAsync`s on `ManagedWebSocket` should now return a `ValueTask` that's either completed synchronously (no allocation) or using a pooled object. It now uses the underlying transport's new `WriteAsync` overload that returns `ValueTask`. - Switches more uses of `ReadAsync` and `WriteAsync` over to the new overloads, including in Process, DeflateStream, BrotliStream, File, HttpClient, SslStream, WebClient, BufferedStream, CryptoStream, - Removed some unnecessary array clearing from various routines using ArrayPool (after the clearing was added we changed our minds and decided clearing was only necessary in very specific circumstances) - Implements a custom `IValueTaskSource` in Socket, such that async receives and sends become allocation-free (ammortized). `NetworkStream` then inherits this functionality, such that its new `ReadAsync` and `WriteAsync` are also allocation-free (in the unbounded channel implementations; we can subsequently add it in for bounded). - Implements a custom `IValueTaskSource` in System.Threading.Channels, such that reading and writing are ammortized allocation-free up to one concurrent reader and writer. - A few random things I noticed as I was going through, e.g. missing ConfigureAwait, some incorrect synchronization in tests, etc. - Adds a ton of new tests, mainly in System.Threading.Tasks.Extensions, System.Threading.Channels, and System.Net.Sockets. --- src/Common/src/System/IO/DelegatingStream.cs | 2 +- .../src/System/IO/ReadOnlyMemoryStream.cs | 2 +- .../System/Net/WebSockets/ManagedWebSocket.cs | 99 +- .../CompressionStreamUnitTestBase.cs | 2 +- .../System/Net/Configuration.Certificates.cs | 5 +- .../System/Diagnostics/AsyncStreamReader.cs | 2 +- .../dec/BrotliStream.Decompress.cs | 2 +- .../Compression/enc/BrotliStream.Compress.cs | 10 +- .../Compression/DeflateZLib/DeflateStream.cs | 24 +- .../src/System/IO/Compression/GZipStream.cs | 2 +- ...ervingWriteOnlyStreamWrapper.netcoreapp.cs | 2 +- .../CompressionStreamUnitTests.Deflate.cs | 2 +- .../tests/CompressionStreamUnitTests.Gzip.cs | 2 +- .../src/System/IO/File.cs | 10 +- .../IsolatedStorageFileStream.cs | 10 + .../src/System/IO/Pipes/PipeStream.cs | 8 +- .../tests/PipeTest.Read.netcoreapp.cs | 4 +- .../BufferedStreamTests.netcoreapp.cs | 2 +- .../src/System/Net/Http/HttpContent.cs | 20 +- .../src/System/Net/Http/MultipartContent.cs | 6 +- .../System/Net/Http/ReadOnlyMemoryContent.cs | 4 +- .../ChunkedEncodingReadStream.cs | 2 +- .../ChunkedEncodingWriteStream.cs | 11 +- .../ContentLengthWriteStream.cs | 4 +- .../Http/SocketsHttpHandler/HttpConnection.cs | 24 +- .../HttpContentDuplexStream.cs | 2 +- .../HttpContentReadStream.cs | 2 +- .../HttpContentWriteStream.cs | 4 +- .../SocketsHttpHandler/RawConnectionStream.cs | 18 +- .../src/System/Net/Http/StreamContent.cs | 2 +- .../FunctionalTests/HttpClientHandlerTest.cs | 4 +- .../ReadOnlyMemoryContentTest.cs | 4 +- .../src/System/Net/FixedSizeReader.cs | 2 +- .../src/System/Net/Security/SslStream.cs | 2 +- .../Security/SslStreamInternal.Adapters.cs | 8 +- .../System/Net/Security/SslStreamInternal.cs | 34 +- .../Fakes/FakeAuthenticatedStream.cs | 2 +- .../tests/UnitTests/Fakes/FakeSslState.cs | 2 +- .../src/System.Net.Sockets.csproj | 1 + .../src/System/Net/Sockets/NetworkStream.cs | 230 +--- .../src/System/Net/Sockets/Socket.Tasks.cs | 434 +++++-- .../Net/Sockets/SocketTaskExtensions.cs | 4 +- .../FunctionalTests/NetworkStreamTest.cs | 15 + .../NetworkStreamTest.netcoreapp.cs | 297 ++++- .../FunctionalTests/SendReceive.netcoreapp.cs | 50 +- .../System.Net.Sockets.Tests.csproj | 5 +- .../FunctionalTests/UnixDomainSocketTest.cs | 1 + .../src/System/Net/WebClient.cs | 14 +- .../System/Net/WebSockets/ClientWebSocket.cs | 2 +- .../Net/WebSockets/WebSocketHandle.Managed.cs | 2 +- .../Net/WebSockets/WebSocketHandle.Windows.cs | 4 +- .../tests/SendReceiveTest.netcoreapp.cs | 7 +- .../WebSockets/ManagedWebSocketExtensions.cs | 32 +- .../ref/System.Net.WebSockets.cs | 2 +- .../WebSockets/ManagedWebSocket.netcoreapp.cs | 4 +- .../src/System/Net/WebSockets/WebSocket.cs | 6 +- .../ref/System.Runtime.Extensions.cs | 2 +- .../src/System/IO/BufferedStream.cs | 22 +- src/System.Runtime/ref/System.Runtime.cs | 83 +- .../Security/Cryptography/CryptoStream.cs | 14 +- .../ref/System.Threading.Channels.cs | 6 +- .../src/Configurations.props | 1 + .../src/Resources/Strings.resx | 6 + .../src/System.Threading.Channels.csproj | 3 +- .../Threading/Channels/AsyncOperation.cs | 342 ++++++ .../Threading/Channels/BoundedChannel.cs | 115 +- .../Threading/Channels/ChannelReader.cs | 2 +- .../Threading/Channels/ChannelUtilities.cs | 85 +- .../Threading/Channels/ChannelWriter.cs | 28 +- .../System/Threading/Channels/Interactor.cs | 149 --- .../SingleConsumerUnboundedChannel.cs | 174 +-- .../Threading/Channels/UnboundedChannel.cs | 130 +- .../tests/BoundedChannelTests.cs | 28 +- .../tests/ChannelTestBase.cs | 376 +++++- .../tests/ChannelTests.cs | 26 +- .../tests/Performance/Perf.Channel.cs | 28 +- .../System.Threading.Channels.Tests.csproj | 3 + .../tests/TestBase.cs | 8 + .../tests/UnboundedChannelTests.cs | 26 +- .../ref/System.Threading.Tasks.Extensions.cs | 81 ++ .../System.Threading.Tasks.Extensions.csproj | 4 + .../src/System/ThrowHelper.cs | 10 +- .../tests/AsyncMethodBuilderAttributeTests.cs | 1 + .../tests/AsyncValueTaskMethodBuilderTests.cs | 310 ++++- .../tests/Configurations.props | 1 + .../tests/ManualResetValueTaskSource.cs | 170 +++ ...em.Threading.Tasks.Extensions.Tests.csproj | 3 +- .../tests/ValueTaskTests.cs | 1068 +++++++++++++++-- 88 files changed, 3633 insertions(+), 1129 deletions(-) create mode 100644 src/System.Threading.Channels/src/System/Threading/Channels/AsyncOperation.cs delete mode 100644 src/System.Threading.Channels/src/System/Threading/Channels/Interactor.cs create mode 100644 src/System.Threading.Tasks.Extensions/tests/ManualResetValueTaskSource.cs diff --git a/src/Common/src/System/IO/DelegatingStream.cs b/src/Common/src/System/IO/DelegatingStream.cs index 3bb864887fa8..39d7aab060c8 100644 --- a/src/Common/src/System/IO/DelegatingStream.cs +++ b/src/Common/src/System/IO/DelegatingStream.cs @@ -157,7 +157,7 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati return _innerStream.WriteAsync(buffer, offset, count, cancellationToken); } - public override Task WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) + public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) { return _innerStream.WriteAsync(source, cancellationToken); } diff --git a/src/Common/src/System/IO/ReadOnlyMemoryStream.cs b/src/Common/src/System/IO/ReadOnlyMemoryStream.cs index 9c240b8e5648..026286e87fff 100644 --- a/src/Common/src/System/IO/ReadOnlyMemoryStream.cs +++ b/src/Common/src/System/IO/ReadOnlyMemoryStream.cs @@ -124,7 +124,7 @@ public override Task CopyToAsync(Stream destination, int bufferSize, Cancellatio { StreamHelpers.ValidateCopyToArgs(this, destination, bufferSize); return _content.Length > _position ? - destination.WriteAsync(_content.Slice(_position), cancellationToken) : + destination.WriteAsync(_content.Slice(_position), cancellationToken).AsTask() : Task.CompletedTask; } diff --git a/src/Common/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/Common/src/System/Net/WebSockets/ManagedWebSocket.cs index 402dd7088e74..6603a793083d 100644 --- a/src/Common/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/Common/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -88,10 +88,7 @@ public static ManagedWebSocket CreateFromConnectedStream( /// private readonly Utf8MessageState _utf8TextState = new Utf8MessageState(); /// - /// Semaphore used to ensure that calls to SendFrameAsync don't run concurrently. While - /// is used to fail if a caller tries to issue another SendAsync while a previous one is running, internally - /// we use SendFrameAsync as an implementation detail, and it should not cause user requests to SendAsync to fail, - /// nor should such internal usage be allowed to run concurrently with other internal usage or with SendAsync. + /// Semaphore used to ensure that calls to SendFrameAsync don't run concurrently. /// private readonly SemaphoreSlim _sendFrameAsyncLock = new SemaphoreSlim(1, 1); @@ -145,15 +142,10 @@ public static ManagedWebSocket CreateFromConnectedStream( /// private bool _lastSendWasFragment; /// - /// The task returned from the last SendAsync operation to not complete synchronously. - /// If this is not null and not completed when a subsequent SendAsync is issued, an exception occurs. - /// - private Task _lastSendAsync; - /// - /// The task returned from the last ReceiveAsync operation to not complete synchronously. + /// The task returned from the last ReceiveAsync(ArraySegment, ...) operation to not complete synchronously. /// If this is not null and not completed when a subsequent ReceiveAsync is issued, an exception occurs. /// - private Task _lastReceiveAsync; + private Task _lastReceiveAsync = Task.CompletedTask; /// Lock used to protect update and check-and-update operations on _state. private object StateUpdateLock => _abortSource; @@ -262,10 +254,10 @@ public override Task SendAsync(ArraySegment buffer, WebSocketMessageType m WebSocketValidate.ValidateArraySegment(buffer, nameof(buffer)); - return SendPrivateAsync((ReadOnlyMemory)buffer, messageType, endOfMessage, cancellationToken); + return SendPrivateAsync((ReadOnlyMemory)buffer, messageType, endOfMessage, cancellationToken).AsTask(); } - private Task SendPrivateAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) + private ValueTask SendPrivateAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) { if (messageType != WebSocketMessageType.Text && messageType != WebSocketMessageType.Binary) { @@ -278,11 +270,10 @@ private Task SendPrivateAsync(ReadOnlyMemory buffer, WebSocketMessageType try { WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validSendStates); - ThrowIfOperationInProgress(_lastSendAsync); } catch (Exception exc) { - return Task.FromException(exc); + return new ValueTask(Task.FromException(exc)); } MessageOpcode opcode = @@ -290,9 +281,8 @@ private Task SendPrivateAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType == WebSocketMessageType.Binary ? MessageOpcode.Binary : MessageOpcode.Text; - Task t = SendFrameAsync(opcode, endOfMessage, buffer, cancellationToken); + ValueTask t = SendFrameAsync(opcode, endOfMessage, buffer, cancellationToken); _lastSendWasFragment = !endOfMessage; - _lastSendAsync = t; return t; } @@ -307,7 +297,7 @@ public override Task ReceiveAsync(ArraySegment buf Debug.Assert(!Monitor.IsEntered(StateUpdateLock), $"{nameof(StateUpdateLock)} must never be held when acquiring {nameof(ReceiveAsyncLock)}"); lock (ReceiveAsyncLock) // synchronize with receives in CloseAsync { - ThrowIfOperationInProgress(_lastReceiveAsync); + ThrowIfOperationInProgress(_lastReceiveAsync.IsCompleted); Task t = ReceiveAsyncPrivate(buffer, cancellationToken).AsTask(); _lastReceiveAsync = t; return t; @@ -362,23 +352,14 @@ public override void Abort() /// The value of the FIN bit for the message. /// The buffer containing the payload data fro the message. /// The CancellationToken to use to cancel the websocket. - private Task SendFrameAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemory payloadBuffer, CancellationToken cancellationToken) + private ValueTask SendFrameAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemory payloadBuffer, CancellationToken cancellationToken) { - // TODO: #4900 SendFrameAsync should in theory typically complete synchronously, making it fast and allocation free. - // However, due to #4900, it almost always yields, resulting in all of the allocations involved in an async method - // yielding, e.g. the boxed state machine, the Action delegate, the MoveNextRunner, and the resulting Task, plus it's - // common that the awaited operation completes so fast after the await that we may end up allocating an AwaitTaskContinuation - // inside of the TaskAwaiter. Since SendFrameAsync is such a core code path, until that can be fixed, we put some - // optimizations in place to avoid a few of those expenses, at the expense of more complicated code; for the common case, - // this code has fewer than half the number and size of allocations. If/when that issue is fixed, this method should be deleted - // and replaced by SendFrameFallbackAsync, which is the same logic but in a much more easily understand flow. - // If a cancelable cancellation token was provided, that would require registering with it, which means more state we have to // pass around (the CancellationTokenRegistration), so if it is cancelable, just immediately go to the fallback path. // Similarly, it should be rare that there are multiple outstanding calls to SendFrameAsync, but if there are, again // fall back to the fallback path. return cancellationToken.CanBeCanceled || !_sendFrameAsyncLock.Wait(0) ? - SendFrameFallbackAsync(opcode, endOfMessage, payloadBuffer, cancellationToken) : + new ValueTask(SendFrameFallbackAsync(opcode, endOfMessage, payloadBuffer, cancellationToken)) : SendFrameLockAcquiredNonCancelableAsync(opcode, endOfMessage, payloadBuffer); } @@ -386,19 +367,19 @@ private Task SendFrameAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMem /// The opcode for the message. /// The value of the FIN bit for the message. /// The buffer containing the payload data fro the message. - private Task SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemory payloadBuffer) + private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemory payloadBuffer) { Debug.Assert(_sendFrameAsyncLock.CurrentCount == 0, "Caller should hold the _sendFrameAsyncLock"); // If we get here, the cancellation token is not cancelable so we don't have to worry about it, // and we own the semaphore, so we don't need to asynchronously wait for it. - Task writeTask = null; + ValueTask writeTask = default; bool releaseSemaphoreAndSendBuffer = true; try { // Write the payload synchronously to the buffer, then write that buffer out to the network. int sendBytes = WriteFrameToSendBuffer(opcode, endOfMessage, payloadBuffer.Span); - writeTask = _stream.WriteAsync(_sendBuffer, 0, sendBytes, CancellationToken.None); + writeTask = _stream.WriteAsync(new ReadOnlyMemory(_sendBuffer, 0, sendBytes)); // If the operation happens to complete synchronously (or, more specifically, by // the time we get from the previous line to here), release the semaphore, return @@ -415,10 +396,10 @@ private Task SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, bool } catch (Exception exc) { - return Task.FromException( + return new ValueTask(Task.FromException( exc is OperationCanceledException ? exc : _state == WebSocketState.Aborted ? CreateOperationCanceledException(exc) : - new WebSocketException(WebSocketError.ConnectionClosedPrematurely, exc)); + new WebSocketException(WebSocketError.ConnectionClosedPrematurely, exc))); } finally { @@ -429,22 +410,26 @@ private Task SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, bool } } - // The write was not yet completed. Create and return a continuation that will - // release the semaphore and translate any exception that occurred. - return writeTask.ContinueWith((t, s) => - { - var thisRef = (ManagedWebSocket)s; - thisRef._sendFrameAsyncLock.Release(); - thisRef.ReleaseSendBuffer(); + return new ValueTask(WaitForWriteTaskAsync(writeTask)); + } - try { t.GetAwaiter().GetResult(); } - catch (Exception exc) when (!(exc is OperationCanceledException)) - { - throw thisRef._state == WebSocketState.Aborted ? - CreateOperationCanceledException(exc) : - new WebSocketException(WebSocketError.ConnectionClosedPrematurely, exc); - } - }, this, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); + private async Task WaitForWriteTaskAsync(ValueTask writeTask) + { + try + { + await writeTask.ConfigureAwait(false); + } + catch (Exception exc) when (!(exc is OperationCanceledException)) + { + throw _state == WebSocketState.Aborted ? + CreateOperationCanceledException(exc) : + new WebSocketException(WebSocketError.ConnectionClosedPrematurely, exc); + } + finally + { + _sendFrameAsyncLock.Release(); + ReleaseSendBuffer(); + } } private async Task SendFrameFallbackAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemory payloadBuffer, CancellationToken cancellationToken) @@ -455,7 +440,7 @@ private async Task SendFrameFallbackAsync(MessageOpcode opcode, bool endOfMessag int sendBytes = WriteFrameToSendBuffer(opcode, endOfMessage, payloadBuffer.Span); using (cancellationToken.Register(s => ((ManagedWebSocket)s).Abort(), this)) { - await _stream.WriteAsync(_sendBuffer, 0, sendBytes, cancellationToken).ConfigureAwait(false); + await _stream.WriteAsync(new ReadOnlyMemory(_sendBuffer, 0, sendBytes), cancellationToken).ConfigureAwait(false); } } catch (Exception exc) when (!(exc is OperationCanceledException)) @@ -518,12 +503,12 @@ private void SendKeepAliveFrameAsync() { // This exists purely to keep the connection alive; don't wait for the result, and ignore any failures. // The call will handle releasing the lock. - Task t = SendFrameLockAcquiredNonCancelableAsync(MessageOpcode.Ping, true, Memory.Empty); + ValueTask t = SendFrameLockAcquiredNonCancelableAsync(MessageOpcode.Ping, true, Memory.Empty); // "Observe" any exception, ignoring it to prevent the unobserved exception event from being raised. - if (t.Status != TaskStatus.RanToCompletion) + if (!t.IsCompletedSuccessfully) { - t.ContinueWith(p => { Exception ignored = p.Exception; }, + t.AsTask().ContinueWith(p => { Exception ignored = p.Exception; }, CancellationToken.None, TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); @@ -1270,15 +1255,17 @@ private static unsafe int ApplyMask(Span toMask, int mask, int maskIndex) } /// Aborts the websocket and throws an exception if an existing operation is in progress. - private void ThrowIfOperationInProgress(Task operationTask, [CallerMemberName] string methodName = null) + private void ThrowIfOperationInProgress(bool operationCompleted, [CallerMemberName] string methodName = null) { - if (operationTask != null && !operationTask.IsCompleted) + if (!operationCompleted) { Abort(); - throw new InvalidOperationException(SR.Format(SR.net_Websockets_AlreadyOneOutstandingOperation, methodName)); + ThrowOperationInProgress(methodName); } } + private void ThrowOperationInProgress(string methodName) => throw new InvalidOperationException(SR.Format(SR.net_Websockets_AlreadyOneOutstandingOperation, methodName)); + /// Creates an OperationCanceledException instance, using a default message and the specified inner exception and token. private static Exception CreateOperationCanceledException(Exception innerException, CancellationToken cancellationToken = default(CancellationToken)) { diff --git a/src/Common/tests/System/IO/Compression/CompressionStreamUnitTestBase.cs b/src/Common/tests/System/IO/Compression/CompressionStreamUnitTestBase.cs index 9014d94c5c69..0ca2ec0c55e9 100644 --- a/src/Common/tests/System/IO/Compression/CompressionStreamUnitTestBase.cs +++ b/src/Common/tests/System/IO/Compression/CompressionStreamUnitTestBase.cs @@ -1333,7 +1333,7 @@ public override async ValueTask ReadAsync(Memory destination, Cancell return await base.ReadAsync(destination, cancellationToken); } - public override async Task WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken) + public override async ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken) { WriteHit = true; diff --git a/src/Common/tests/System/Net/Configuration.Certificates.cs b/src/Common/tests/System/Net/Configuration.Certificates.cs index 85cd06e97115..340792634202 100644 --- a/src/Common/tests/System/Net/Configuration.Certificates.cs +++ b/src/Common/tests/System/Net/Configuration.Certificates.cs @@ -19,7 +19,7 @@ public static partial class Certificates private const string CertificatePassword = "testcertificate"; private const string TestDataFolder = "TestData"; - private static Mutex m; + private static readonly Mutex m; private const int MutexTimeout = 120 * 1000; static Certificates() @@ -59,10 +59,9 @@ private static X509Certificate2Collection GetCertificateCollection(string certif { // On Windows, .NET Core applications should not import PFX files in parallel to avoid a known system-level race condition. // This bug results in corrupting the X509Certificate2 certificate state. + Assert.True(m.WaitOne(MutexTimeout), "Cannot acquire the global certificate mutex."); try { - Assert.True(m.WaitOne(MutexTimeout), "Cannot acquire the global certificate mutex."); - var certCollection = new X509Certificate2Collection(); certCollection.Import(Path.Combine(TestDataFolder, certificateFileName), CertificatePassword, X509KeyStorageFlags.DefaultKeySet); diff --git a/src/System.Diagnostics.Process/src/System/Diagnostics/AsyncStreamReader.cs b/src/System.Diagnostics.Process/src/System/Diagnostics/AsyncStreamReader.cs index 3e5ba4604484..a86002e2d1c2 100644 --- a/src/System.Diagnostics.Process/src/System/Diagnostics/AsyncStreamReader.cs +++ b/src/System.Diagnostics.Process/src/System/Diagnostics/AsyncStreamReader.cs @@ -90,7 +90,7 @@ private async Task ReadBufferAsync() { try { - int bytesRead = await _stream.ReadAsync(_byteBuffer, 0, _byteBuffer.Length, _cts.Token).ConfigureAwait(false); + int bytesRead = await _stream.ReadAsync(new Memory(_byteBuffer), _cts.Token).ConfigureAwait(false); if (bytesRead == 0) break; diff --git a/src/System.IO.Compression.Brotli/src/System/IO/Compression/dec/BrotliStream.Decompress.cs b/src/System.IO.Compression.Brotli/src/System/IO/Compression/dec/BrotliStream.Decompress.cs index 4e15f1c62531..3f20e0c30d1f 100644 --- a/src/System.IO.Compression.Brotli/src/System/IO/Compression/dec/BrotliStream.Decompress.cs +++ b/src/System.IO.Compression.Brotli/src/System/IO/Compression/dec/BrotliStream.Decompress.cs @@ -113,7 +113,7 @@ private async ValueTask FinishReadAsyncMemory(Memory destination, Can { int readBytes = 0; int iter = 0; - while (readBytes < _buffer.Length && ((iter = await _stream.ReadAsync(_buffer, readBytes, _buffer.Length - readBytes, cancellationToken).ConfigureAwait(false)) > 0)) + while (readBytes < _buffer.Length && ((iter = await _stream.ReadAsync(new Memory(_buffer, readBytes, _buffer.Length - readBytes), cancellationToken).ConfigureAwait(false)) > 0)) { readBytes += iter; if (readBytes > _buffer.Length) diff --git a/src/System.IO.Compression.Brotli/src/System/IO/Compression/enc/BrotliStream.Compress.cs b/src/System.IO.Compression.Brotli/src/System/IO/Compression/enc/BrotliStream.Compress.cs index 5f2ecd1b6b7f..7fdc5f5b41b3 100644 --- a/src/System.IO.Compression.Brotli/src/System/IO/Compression/enc/BrotliStream.Compress.cs +++ b/src/System.IO.Compression.Brotli/src/System/IO/Compression/enc/BrotliStream.Compress.cs @@ -60,19 +60,19 @@ public override void EndWrite(IAsyncResult asyncResult) => public override Task WriteAsync(byte[] array, int offset, int count, CancellationToken cancellationToken) { ValidateParameters(array, offset, count); - return WriteAsync(new ReadOnlyMemory(array, offset, count), cancellationToken); + return WriteAsync(new ReadOnlyMemory(array, offset, count), cancellationToken).AsTask(); } - public override Task WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default(CancellationToken)) + public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default(CancellationToken)) { if (_mode != CompressionMode.Compress) throw new InvalidOperationException(SR.BrotliStream_Decompress_UnsupportedOperation); EnsureNoActiveAsyncOperation(); EnsureNotDisposed(); - return cancellationToken.IsCancellationRequested ? + return new ValueTask(cancellationToken.IsCancellationRequested ? Task.FromCanceled(cancellationToken) : - WriteAsyncMemoryCore(source, cancellationToken); + WriteAsyncMemoryCore(source, cancellationToken)); } private async Task WriteAsyncMemoryCore(ReadOnlyMemory source, CancellationToken cancellationToken) @@ -92,7 +92,7 @@ private async Task WriteAsyncMemoryCore(ReadOnlyMemory source, Cancellatio if (bytesConsumed > 0) source = source.Slice(bytesConsumed); if (bytesWritten > 0) - await _stream.WriteAsync(_buffer, 0, bytesWritten, cancellationToken).ConfigureAwait(false); + await _stream.WriteAsync(new ReadOnlyMemory(_buffer, 0, bytesWritten), cancellationToken).ConfigureAwait(false); } } finally diff --git a/src/System.IO.Compression/src/System/IO/Compression/DeflateZLib/DeflateStream.cs b/src/System.IO.Compression/src/System/IO/Compression/DeflateZLib/DeflateStream.cs index b0ec44a643ce..8c75cf0640b1 100644 --- a/src/System.IO.Compression/src/System/IO/Compression/DeflateZLib/DeflateStream.cs +++ b/src/System.IO.Compression/src/System/IO/Compression/DeflateZLib/DeflateStream.cs @@ -199,7 +199,7 @@ private async Task FlushAsyncCore(CancellationToken cancellationToken) flushSuccessful = _deflater.Flush(_buffer, out compressedBytes); if (flushSuccessful) { - await _stream.WriteAsync(_buffer, 0, compressedBytes, cancellationToken).ConfigureAwait(false); + await _stream.WriteAsync(new ReadOnlyMemory(_buffer, 0, compressedBytes), cancellationToken).ConfigureAwait(false); } Debug.Assert(flushSuccessful == (compressedBytes > 0)); } while (flushSuccessful); @@ -643,10 +643,10 @@ public override void EndWrite(IAsyncResult asyncResult) => public override Task WriteAsync(byte[] array, int offset, int count, CancellationToken cancellationToken) { ValidateParameters(array, offset, count); - return WriteAsyncMemory(new ReadOnlyMemory(array, offset, count), cancellationToken); + return WriteAsyncMemory(new ReadOnlyMemory(array, offset, count), cancellationToken).AsTask(); } - public override Task WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken) + public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken) { if (GetType() != typeof(DeflateStream)) { @@ -660,15 +660,15 @@ public override Task WriteAsync(ReadOnlyMemory source, CancellationToken c } } - internal Task WriteAsyncMemory(ReadOnlyMemory source, CancellationToken cancellationToken) + internal ValueTask WriteAsyncMemory(ReadOnlyMemory source, CancellationToken cancellationToken) { EnsureCompressionMode(); EnsureNoActiveAsyncOperation(); EnsureNotDisposed(); - return cancellationToken.IsCancellationRequested ? + return new ValueTask(cancellationToken.IsCancellationRequested ? Task.FromCanceled(cancellationToken) : - WriteAsyncMemoryCore(source, cancellationToken); + WriteAsyncMemoryCore(source, cancellationToken)); } private async Task WriteAsyncMemoryCore(ReadOnlyMemory source, CancellationToken cancellationToken) @@ -701,7 +701,7 @@ private async Task WriteDeflaterOutputAsync(CancellationToken cancellationToken) int compressedBytes = _deflater.GetDeflateOutput(_buffer); if (compressedBytes > 0) { - await _stream.WriteAsync(_buffer, 0, compressedBytes, cancellationToken).ConfigureAwait(false); + await _stream.WriteAsync(new ReadOnlyMemory(_buffer, 0, compressedBytes), cancellationToken).ConfigureAwait(false); } } } @@ -732,7 +732,6 @@ private sealed class CopyToAsyncStream : Stream private readonly Stream _destination; private readonly CancellationToken _cancellationToken; private byte[] _arrayPoolBuffer; - private int _arrayPoolBufferHighWaterMark; public CopyToAsyncStream(DeflateStream deflateStream, Stream destination, int bufferSize, CancellationToken cancellationToken) { @@ -757,8 +756,7 @@ public async Task CopyFromSourceToDestination() int bytesRead = _deflateStream._inflater.Inflate(_arrayPoolBuffer, 0, _arrayPoolBuffer.Length); if (bytesRead > 0) { - if (bytesRead > _arrayPoolBufferHighWaterMark) _arrayPoolBufferHighWaterMark = bytesRead; - await _destination.WriteAsync(_arrayPoolBuffer, 0, bytesRead, _cancellationToken).ConfigureAwait(false); + await _destination.WriteAsync(new ReadOnlyMemory(_arrayPoolBuffer, 0, bytesRead), _cancellationToken).ConfigureAwait(false); } else break; } @@ -770,8 +768,7 @@ public async Task CopyFromSourceToDestination() { _deflateStream.AsyncOperationCompleting(); - Array.Clear(_arrayPoolBuffer, 0, _arrayPoolBufferHighWaterMark); // clear only the most we used - ArrayPool.Shared.Return(_arrayPoolBuffer, clearArray: false); + ArrayPool.Shared.Return(_arrayPoolBuffer); _arrayPoolBuffer = null; } } @@ -801,8 +798,7 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc int bytesRead = _deflateStream._inflater.Inflate(_arrayPoolBuffer, 0, _arrayPoolBuffer.Length); if (bytesRead > 0) { - if (bytesRead > _arrayPoolBufferHighWaterMark) _arrayPoolBufferHighWaterMark = bytesRead; - await _destination.WriteAsync(_arrayPoolBuffer, 0, bytesRead, cancellationToken).ConfigureAwait(false); + await _destination.WriteAsync(new ReadOnlyMemory(_arrayPoolBuffer, 0, bytesRead), cancellationToken).ConfigureAwait(false); } else break; } diff --git a/src/System.IO.Compression/src/System/IO/Compression/GZipStream.cs b/src/System.IO.Compression/src/System/IO/Compression/GZipStream.cs index 4796b9b39ba2..097777373a67 100644 --- a/src/System.IO.Compression/src/System/IO/Compression/GZipStream.cs +++ b/src/System.IO.Compression/src/System/IO/Compression/GZipStream.cs @@ -180,7 +180,7 @@ public override Task WriteAsync(byte[] array, int offset, int count, Cancellatio return _deflateStream.WriteAsync(array, offset, count, cancellationToken); } - public override Task WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default(CancellationToken)) + public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default(CancellationToken)) { if (GetType() != typeof(GZipStream)) { diff --git a/src/System.IO.Compression/src/System/IO/Compression/PositionPreservingWriteOnlyStreamWrapper.netcoreapp.cs b/src/System.IO.Compression/src/System/IO/Compression/PositionPreservingWriteOnlyStreamWrapper.netcoreapp.cs index 1c5aa7b6e25f..2303ef65ca23 100644 --- a/src/System.IO.Compression/src/System/IO/Compression/PositionPreservingWriteOnlyStreamWrapper.netcoreapp.cs +++ b/src/System.IO.Compression/src/System/IO/Compression/PositionPreservingWriteOnlyStreamWrapper.netcoreapp.cs @@ -15,7 +15,7 @@ public override void Write(ReadOnlySpan source) _stream.Write(source); } - public override Task WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default(CancellationToken)) + public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default(CancellationToken)) { _position += source.Length; return _stream.WriteAsync(source, cancellationToken); diff --git a/src/System.IO.Compression/tests/CompressionStreamUnitTests.Deflate.cs b/src/System.IO.Compression/tests/CompressionStreamUnitTests.Deflate.cs index 39e713de65c9..4fa26680b826 100644 --- a/src/System.IO.Compression/tests/CompressionStreamUnitTests.Deflate.cs +++ b/src/System.IO.Compression/tests/CompressionStreamUnitTests.Deflate.cs @@ -72,7 +72,7 @@ public void DerivedStream_ReadWriteSpan_UsesReadWriteArray() ms.Position = 0; using (var compressor = new DerivedDeflateStream(ms, CompressionMode.Compress, leaveOpen: true)) { - compressor.WriteAsync(new ReadOnlyMemory(new byte[1])).Wait(); + compressor.WriteAsync(new ReadOnlyMemory(new byte[1])).AsTask().Wait(); Assert.True(compressor.WriteArrayInvoked); } } diff --git a/src/System.IO.Compression/tests/CompressionStreamUnitTests.Gzip.cs b/src/System.IO.Compression/tests/CompressionStreamUnitTests.Gzip.cs index 002c83cc877e..387b1a87fcb9 100644 --- a/src/System.IO.Compression/tests/CompressionStreamUnitTests.Gzip.cs +++ b/src/System.IO.Compression/tests/CompressionStreamUnitTests.Gzip.cs @@ -57,7 +57,7 @@ public void DerivedStream_ReadWriteSpan_UsesReadWriteArray() ms.Position = 0; using (var compressor = new DerivedGZipStream(ms, CompressionMode.Compress, leaveOpen: true)) { - compressor.WriteAsync(new ReadOnlyMemory(new byte[1])).Wait(); + compressor.WriteAsync(new ReadOnlyMemory(new byte[1])).AsTask().Wait(); Assert.True(compressor.WriteArrayInvoked); } } diff --git a/src/System.IO.FileSystem/src/System/IO/File.cs b/src/System.IO.FileSystem/src/System/IO/File.cs index 112117a2435b..ea9dbdb155d4 100644 --- a/src/System.IO.FileSystem/src/System/IO/File.cs +++ b/src/System.IO.FileSystem/src/System/IO/File.cs @@ -722,14 +722,13 @@ private static async Task InternalReadAllTextAsync(string path, Encoding buffer = ArrayPool.Shared.Rent(sr.CurrentEncoding.GetMaxCharCount(DefaultBufferSize)); for (;;) { - int read = await sr.ReadAsync(buffer, 0, buffer.Length).ConfigureAwait(false); + int read = await sr.ReadAsync(new Memory(buffer), cancellationToken).ConfigureAwait(false); if (read == 0) { return sb.ToString(); } sb.Append(buffer, 0, read); - cancellationToken.ThrowIfCancellationRequested(); } } finally @@ -823,7 +822,7 @@ private static async Task InternalReadAllBytesAsync(FileStream fs, int c byte[] bytes = new byte[count]; do { - int n = await fs.ReadAsync(bytes, index, count - index, cancellationToken).ConfigureAwait(false); + int n = await fs.ReadAsync(new Memory(bytes, index, count - index), cancellationToken).ConfigureAwait(false); if (n == 0) { throw Error.GetEndOfFile(); @@ -857,7 +856,7 @@ private static async Task InternalWriteAllBytesAsync(string path, byte[] bytes, using (FileStream fs = new FileStream(path, FileMode.Create, FileAccess.Write, FileShare.Read, DefaultBufferSize, FileOptions.Asynchronous | FileOptions.SequentialScan)) { - await fs.WriteAsync(bytes, 0, bytes.Length, cancellationToken).ConfigureAwait(false); + await fs.WriteAsync(new ReadOnlyMemory(bytes, 0, bytes.Length), cancellationToken).ConfigureAwait(false); await fs.FlushAsync(cancellationToken).ConfigureAwait(false); } } @@ -950,8 +949,7 @@ private static async Task InternalWriteAllTextAsync(StreamWriter sw, string cont { int batchSize = Math.Min(DefaultBufferSize, count - index); contents.CopyTo(index, buffer, 0, batchSize); - cancellationToken.ThrowIfCancellationRequested(); - await sw.WriteAsync(buffer, 0, batchSize).ConfigureAwait(false); + await sw.WriteAsync(new ReadOnlyMemory(buffer, 0, batchSize), cancellationToken).ConfigureAwait(false); index += batchSize; } diff --git a/src/System.IO.IsolatedStorage/src/System/IO/IsolatedStorage/IsolatedStorageFileStream.cs b/src/System.IO.IsolatedStorage/src/System/IO/IsolatedStorage/IsolatedStorageFileStream.cs index f68a2273b1f8..d2dd8c2e7eb0 100644 --- a/src/System.IO.IsolatedStorage/src/System/IO/IsolatedStorage/IsolatedStorageFileStream.cs +++ b/src/System.IO.IsolatedStorage/src/System/IO/IsolatedStorage/IsolatedStorageFileStream.cs @@ -247,6 +247,11 @@ public override Task ReadAsync(byte[] buffer, int offset, int count, Thread return _fs.ReadAsync(buffer, offset, count, cancellationToken); } + public override ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken) + { + return _fs.ReadAsync(destination, cancellationToken); + } + public override int ReadByte() { return _fs.ReadByte(); @@ -269,6 +274,11 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati return _fs.WriteAsync(buffer, offset, count, cancellationToken); } + public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken) + { + return _fs.WriteAsync(source, cancellationToken); + } + public override void WriteByte(byte value) { _fs.WriteByte(value); diff --git a/src/System.IO.Pipes/src/System/IO/Pipes/PipeStream.cs b/src/System.IO.Pipes/src/System/IO/Pipes/PipeStream.cs index 3788b43abfb5..446d11e9bb5b 100644 --- a/src/System.IO.Pipes/src/System/IO/Pipes/PipeStream.cs +++ b/src/System.IO.Pipes/src/System/IO/Pipes/PipeStream.cs @@ -278,7 +278,7 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati return WriteAsyncCore(new ReadOnlyMemory(buffer, offset, count), cancellationToken); } - public override Task WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default(CancellationToken)) + public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default(CancellationToken)) { if (!_isAsync) { @@ -292,17 +292,17 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati if (cancellationToken.IsCancellationRequested) { - return Task.FromCanceled(cancellationToken); + return new ValueTask(Task.FromCanceled(cancellationToken)); } CheckWriteOperations(); if (source.Length == 0) { - return Task.CompletedTask; + return default; } - return WriteAsyncCore(source, cancellationToken); + return new ValueTask(WriteAsyncCore(source, cancellationToken)); } public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) diff --git a/src/System.IO.Pipes/tests/PipeTest.Read.netcoreapp.cs b/src/System.IO.Pipes/tests/PipeTest.Read.netcoreapp.cs index ecaedc9c735f..0b1bf32e1b47 100644 --- a/src/System.IO.Pipes/tests/PipeTest.Read.netcoreapp.cs +++ b/src/System.IO.Pipes/tests/PipeTest.Read.netcoreapp.cs @@ -88,7 +88,7 @@ public async Task ValidWriteAsync_Span_ValidReadAsync() byte[] sent = new byte[] { 123, 0, 5 }; byte[] received = new byte[] { 0, 0, 0 }; - Task write = pair.writeablePipe.WriteAsync(new ReadOnlyMemory(sent)); + ValueTask write = pair.writeablePipe.WriteAsync(new ReadOnlyMemory(sent)); Assert.Equal(sent.Length, await pair.readablePipe.ReadAsync(new Memory(received, 0, sent.Length))); Assert.Equal(sent, received); await write; @@ -112,7 +112,7 @@ public async Task AsyncReadWriteChain_Span_ReadWrite(int iterations, int writeBu for (int iter = 0; iter < iterations; iter++) { rand.NextBytes(writeBuffer); - Task writerTask = pair.writeablePipe.WriteAsync(new ReadOnlyMemory(writeBuffer), cancellationToken); + ValueTask writerTask = pair.writeablePipe.WriteAsync(new ReadOnlyMemory(writeBuffer), cancellationToken); int totalRead = 0; while (totalRead < writeBuffer.Length) diff --git a/src/System.IO/tests/BufferedStream/BufferedStreamTests.netcoreapp.cs b/src/System.IO/tests/BufferedStream/BufferedStreamTests.netcoreapp.cs index 7ead9aa31e12..51a1be459f53 100644 --- a/src/System.IO/tests/BufferedStream/BufferedStreamTests.netcoreapp.cs +++ b/src/System.IO/tests/BufferedStream/BufferedStreamTests.netcoreapp.cs @@ -83,7 +83,7 @@ public void ReadWriteMemory_Precanceled_Throws() using (var bs = new BufferedStream(new MemoryStream())) { Assert.Equal(TaskStatus.Canceled, bs.ReadAsync(new byte[1], new CancellationToken(true)).AsTask().Status); - Assert.Equal(TaskStatus.Canceled, bs.WriteAsync(new byte[1], new CancellationToken(true)).Status); + Assert.Equal(TaskStatus.Canceled, bs.WriteAsync(new byte[1], new CancellationToken(true)).AsTask().Status); } } } diff --git a/src/System.Net.Http/src/System/Net/Http/HttpContent.cs b/src/System.Net.Http/src/System/Net/Http/HttpContent.cs index f1c15e3cb774..c457deb149b8 100644 --- a/src/System.Net.Http/src/System/Net/Http/HttpContent.cs +++ b/src/System.Net.Http/src/System/Net/Http/HttpContent.cs @@ -320,19 +320,17 @@ internal Task CopyToAsync(Stream stream, TransportContext context, CancellationT try { - Task task = null; ArraySegment buffer; if (TryGetBuffer(out buffer)) { - task = stream.WriteAsync(buffer.Array, buffer.Offset, buffer.Count, cancellationToken); + return CopyToAsyncCore(stream.WriteAsync(new ReadOnlyMemory(buffer.Array, buffer.Offset, buffer.Count), cancellationToken)); } else { - task = SerializeToStreamAsync(stream, context, cancellationToken); + Task task = SerializeToStreamAsync(stream, context, cancellationToken); CheckTaskNotNull(task); + return CopyToAsyncCore(new ValueTask(task)); } - - return CopyToAsyncCore(task); } catch (Exception e) when (StreamCopyExceptionNeedsWrapping(e)) { @@ -340,7 +338,7 @@ internal Task CopyToAsync(Stream stream, TransportContext context, CancellationT } } - private static async Task CopyToAsyncCore(Task copyTask) + private static async Task CopyToAsyncCore(ValueTask copyTask) { try { @@ -733,6 +731,12 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati return base.WriteAsync(buffer, offset, count, cancellationToken); } + public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken) + { + CheckSize(source.Length); + return base.WriteAsync(source, cancellationToken); + } + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) { CheckSize(count); @@ -878,10 +882,10 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati return Task.CompletedTask; } - public override Task WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) + public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) { Write(source.Span); - return Task.CompletedTask; + return default; } public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback asyncCallback, object asyncState) => diff --git a/src/System.Net.Http/src/System/Net/Http/MultipartContent.cs b/src/System.Net.Http/src/System/Net/Http/MultipartContent.cs index 40216b5a2761..47bd79ddb084 100644 --- a/src/System.Net.Http/src/System/Net/Http/MultipartContent.cs +++ b/src/System.Net.Http/src/System/Net/Http/MultipartContent.cs @@ -275,10 +275,10 @@ private string SerializeHeadersToString(StringBuilder scratch, int contentIndex, return scratch.ToString(); } - private static Task EncodeStringToStreamAsync(Stream stream, string input) + private static ValueTask EncodeStringToStreamAsync(Stream stream, string input) { byte[] buffer = HttpRuleParser.DefaultHttpEncoding.GetBytes(input); - return stream.WriteAsync(buffer, 0, buffer.Length); + return stream.WriteAsync(new ReadOnlyMemory(buffer)); } private static Stream EncodeStringToNewStream(string input) @@ -583,7 +583,7 @@ public override void Flush() { } public override void Write(byte[] buffer, int offset, int count) { throw new NotSupportedException(); } public override void Write(ReadOnlySpan source) { throw new NotSupportedException(); } public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { throw new NotSupportedException(); } - public override Task WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) { throw new NotSupportedException(); } + public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) { throw new NotSupportedException(); } } #endregion Serialization } diff --git a/src/System.Net.Http/src/System/Net/Http/ReadOnlyMemoryContent.cs b/src/System.Net.Http/src/System/Net/Http/ReadOnlyMemoryContent.cs index 11b98776e67f..77f713593b79 100644 --- a/src/System.Net.Http/src/System/Net/Http/ReadOnlyMemoryContent.cs +++ b/src/System.Net.Http/src/System/Net/Http/ReadOnlyMemoryContent.cs @@ -25,10 +25,10 @@ public ReadOnlyMemoryContent(ReadOnlyMemory content) } protected override Task SerializeToStreamAsync(Stream stream, TransportContext context) => - stream.WriteAsync(_content); + stream.WriteAsync(_content).AsTask(); internal override Task SerializeToStreamAsync(Stream stream, TransportContext context, CancellationToken cancellationToken) => - stream.WriteAsync(_content, cancellationToken); + stream.WriteAsync(_content, cancellationToken).AsTask(); protected internal override bool TryComputeLength(out long length) { diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingReadStream.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingReadStream.cs index 481a6782b8b0..4ae0fe87d0d0 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingReadStream.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingReadStream.cs @@ -90,7 +90,7 @@ private async ValueTask ReadAsyncCore(Memory destination, Cancellatio // as the connection buffer. That avoids an unnecessary copy while still reading // the maximum amount we'd otherwise read at a time. Debug.Assert(_connection.RemainingBuffer.Length == 0); - int bytesRead = await _connection.ReadAsync(destination.Slice(0, (int)Math.Min((ulong)destination.Length, _chunkBytesRemaining))); + int bytesRead = await _connection.ReadAsync(destination.Slice(0, (int)Math.Min((ulong)destination.Length, _chunkBytesRemaining))).ConfigureAwait(false); if (bytesRead == 0) { throw new IOException(SR.net_http_invalid_response); diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingWriteStream.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingWriteStream.cs index 9755fb40b641..18f301c0fef3 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingWriteStream.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ChunkedEncodingWriteStream.cs @@ -18,7 +18,7 @@ public ChunkedEncodingWriteStream(HttpConnection connection) : base(connection) { } - public override Task WriteAsync(ReadOnlyMemory source, CancellationToken ignored) + public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken ignored) { Debug.Assert(_connection._currentRequest != null); @@ -26,14 +26,13 @@ public override Task WriteAsync(ReadOnlyMemory source, CancellationToken i // here are those that are already covered by the token having been registered with // to close the connection. - if (source.Length == 0) - { + ValueTask task = source.Length == 0 ? // Don't write if nothing was given, especially since we don't want to accidentally send a 0 chunk, // which would indicate end of body. Instead, just ensure no content is stuck in the buffer. - return _connection.FlushAsync(); - } + _connection.FlushAsync() : + new ValueTask(WriteChunkAsync(source)); - return WriteChunkAsync(source); + return task; } private async Task WriteChunkAsync(ReadOnlyMemory source) diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ContentLengthWriteStream.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ContentLengthWriteStream.cs index 30da8143b2d2..411472175efd 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ContentLengthWriteStream.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ContentLengthWriteStream.cs @@ -16,14 +16,14 @@ public ContentLengthWriteStream(HttpConnection connection) : base(connection) { } - public override Task WriteAsync(ReadOnlyMemory source, CancellationToken ignored) // token ignored as it comes from SendAsync + public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken ignored) // token ignored as it comes from SendAsync { Debug.Assert(_connection._currentRequest != null); // Have the connection write the data, skipping the buffer. Importantly, this will // force a flush of anything already in the buffer, i.e. any remaining request headers // that are still buffered. - return _connection.WriteAsync(source); + return new ValueTask(_connection.WriteAsync(source)); } public override Task FinishAsync() diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs index f5d191630dbd..a613a365d45e 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs @@ -54,7 +54,7 @@ internal partial class HttpConnection : IDisposable private int _writeOffset; private int _allowedReadLineBytes; - private Task _readAheadTask; + private ValueTask? _readAheadTask; private byte[] _readBuffer; private int _readOffset; private int _readLength; @@ -125,7 +125,7 @@ public bool ReadAheadCompleted get { Debug.Assert(_readAheadTask != null, $"{nameof(_readAheadTask)} should have been initialized"); - return _readAheadTask.IsCompleted; + return _readAheadTask.GetValueOrDefault().IsCompleted; } } @@ -384,12 +384,12 @@ public async Task SendAsync(HttpRequestMessage request, Can // connection again, as that would mean the connection was either closed or had // erroneous data sent on it by the server in response to no request from us. // We need to consume that read prior to issuing another read request. - Task t = _readAheadTask; + ValueTask? t = _readAheadTask; if (t != null) { _readAheadTask = null; - int bytesRead = await t.ConfigureAwait(false); + int bytesRead = await t.GetValueOrDefault().ConfigureAwait(false); if (NetEventSource.IsEnabled) Trace($"Received {bytesRead} bytes."); if (bytesRead == 0) @@ -838,7 +838,7 @@ private async Task WriteAsync(ReadOnlyMemory source) } } - private Task WriteWithoutBufferingAsync(ReadOnlyMemory source) + private ValueTask WriteWithoutBufferingAsync(ReadOnlyMemory source) { if (_writeOffset == 0) { @@ -860,7 +860,7 @@ private Task WriteWithoutBufferingAsync(ReadOnlyMemory source) // There's data in the write buffer and the data we're writing doesn't fit after it. // Do two writes, one to flush the buffer and then another to write the supplied content. - return FlushThenWriteWithoutBufferingAsync(source); + return new ValueTask(FlushThenWriteWithoutBufferingAsync(source)); } private async Task FlushThenWriteWithoutBufferingAsync(ReadOnlyMemory source) @@ -1000,18 +1000,18 @@ private async Task WriteStringAsyncSlow(string s) } } - private Task FlushAsync() + private ValueTask FlushAsync() { if (_writeOffset > 0) { - Task t = WriteToStreamAsync(new ReadOnlyMemory(_writeBuffer, 0, _writeOffset)); + ValueTask t = WriteToStreamAsync(new ReadOnlyMemory(_writeBuffer, 0, _writeOffset)); _writeOffset = 0; return t; } - return Task.CompletedTask; + return default; } - private Task WriteToStreamAsync(ReadOnlyMemory source) + private ValueTask WriteToStreamAsync(ReadOnlyMemory source) { if (NetEventSource.IsEnabled) Trace($"Writing {source.Length} bytes."); return _stream.WriteAsync(source); @@ -1171,7 +1171,7 @@ private async Task CopyFromBufferAsync(Stream destination, int count, Cancellati Debug.Assert(count <= _readLength - _readOffset); if (NetEventSource.IsEnabled) Trace($"Copying {count} bytes to stream."); - await destination.WriteAsync(_readBuffer, _readOffset, count, cancellationToken).ConfigureAwait(false); + await destination.WriteAsync(new ReadOnlyMemory(_readBuffer, _readOffset, count), cancellationToken).ConfigureAwait(false); _readOffset += count; } @@ -1267,7 +1267,7 @@ private void ReturnConnectionToPool() // at any point to understand if the connection has been closed or if errant data // has been sent on the connection by the server, either of which would mean we // should close the connection and not use it for subsequent requests. - _readAheadTask = _stream.ReadAsync(_readBuffer, 0, _readBuffer.Length); + _readAheadTask = _stream.ReadAsync(new Memory(_readBuffer)); // Put connection back in the pool. _pool.ReturnConnection(this); diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentDuplexStream.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentDuplexStream.cs index 199383bf73ed..1229b0a626a0 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentDuplexStream.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentDuplexStream.cs @@ -40,7 +40,7 @@ public sealed override void Write(byte[] buffer, int offset, int count) public sealed override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { ValidateBufferArgs(buffer, offset, count); - return WriteAsync(new ReadOnlyMemory(buffer, offset, count), cancellationToken); + return WriteAsync(new ReadOnlyMemory(buffer, offset, count), cancellationToken).AsTask(); } public sealed override void CopyTo(Stream destination, int bufferSize) => diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentReadStream.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentReadStream.cs index e1cb3cec6bb0..d1c44340d821 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentReadStream.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentReadStream.cs @@ -32,7 +32,7 @@ public sealed override Task FlushAsync(CancellationToken cancellationToken) => public sealed override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException(); public sealed override void Write(ReadOnlySpan source) => throw new NotSupportedException(); public sealed override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) => throw new NotSupportedException(); - public sealed override Task WriteAsync(ReadOnlyMemory destination, CancellationToken cancellationToken) => throw new NotSupportedException(); + public sealed override ValueTask WriteAsync(ReadOnlyMemory destination, CancellationToken cancellationToken) => throw new NotSupportedException(); public sealed override int Read(byte[] buffer, int offset, int count) { diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentWriteStream.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentWriteStream.cs index cf4c439b5fd0..6b7993950a41 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentWriteStream.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpContentWriteStream.cs @@ -32,11 +32,11 @@ public sealed override void Write(byte[] buffer, int offset, int count) => public sealed override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken ignored) { ValidateBufferArgs(buffer, offset, count); - return WriteAsync(new ReadOnlyMemory(buffer, offset, count), ignored); + return WriteAsync(new ReadOnlyMemory(buffer, offset, count), ignored).AsTask(); } public sealed override Task FlushAsync(CancellationToken ignored) => - _connection.FlushAsync(); + _connection.FlushAsync().AsTask(); public abstract Task FinishAsync(); } diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/RawConnectionStream.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/RawConnectionStream.cs index 354965174cb4..e2a09ee9e89b 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/RawConnectionStream.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/RawConnectionStream.cs @@ -120,27 +120,27 @@ private void Finish() _connection = null; } - public override Task WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken) + public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken) { if (cancellationToken.IsCancellationRequested) { - return Task.FromCanceled(cancellationToken); + return new ValueTask(Task.FromCanceled(cancellationToken)); } if (_connection == null) { - return Task.FromException(new IOException(SR.net_http_io_write)); + return new ValueTask(Task.FromException(new IOException(SR.net_http_io_write))); } if (source.Length == 0) { - return Task.CompletedTask; + return default; } - Task writeTask = _connection.WriteWithoutBufferingAsync(source); + ValueTask writeTask = _connection.WriteWithoutBufferingAsync(source); return writeTask.IsCompleted ? writeTask : - WaitWithConnectionCancellationAsync(writeTask, cancellationToken); + new ValueTask(WaitWithConnectionCancellationAsync(writeTask, cancellationToken)); } public override Task FlushAsync(CancellationToken cancellationToken) @@ -155,13 +155,13 @@ public override Task FlushAsync(CancellationToken cancellationToken) return Task.CompletedTask; } - Task flushTask = _connection.FlushAsync(); + ValueTask flushTask = _connection.FlushAsync(); return flushTask.IsCompleted ? - flushTask : + flushTask.AsTask() : WaitWithConnectionCancellationAsync(flushTask, cancellationToken); } - private async Task WaitWithConnectionCancellationAsync(Task task, CancellationToken cancellationToken) + private async Task WaitWithConnectionCancellationAsync(ValueTask task, CancellationToken cancellationToken) { CancellationTokenRegistration ctr = _connection.RegisterCancellation(cancellationToken); try diff --git a/src/System.Net.Http/src/System/Net/Http/StreamContent.cs b/src/System.Net.Http/src/System/Net/Http/StreamContent.cs index 02124745593d..91fd75e2541d 100644 --- a/src/System.Net.Http/src/System/Net/Http/StreamContent.cs +++ b/src/System.Net.Http/src/System/Net/Http/StreamContent.cs @@ -167,7 +167,7 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Threading. throw new NotSupportedException(SR.net_http_content_readonly_stream); } - public override Task WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) + public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) { throw new NotSupportedException(SR.net_http_content_readonly_stream); } diff --git a/src/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.cs b/src/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.cs index 27544bab3a43..40bb8aa938b2 100644 --- a/src/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.cs +++ b/src/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.cs @@ -1891,8 +1891,8 @@ await LoopbackServer.CreateClientAndServerAsync(async uri => Assert.Throws(() => responseStream.SetLength(0)); Assert.Throws(() => responseStream.Write(new byte[1], 0, 1)); Assert.Throws(() => responseStream.Write(new Span(new byte[1]))); - await Assert.ThrowsAsync(() => responseStream.WriteAsync(new Memory(new byte[1]))); - await Assert.ThrowsAsync(() => responseStream.WriteAsync(new byte[1], 0, 1)); + await Assert.ThrowsAsync(async () => await responseStream.WriteAsync(new Memory(new byte[1]))); + await Assert.ThrowsAsync(async () => await responseStream.WriteAsync(new byte[1], 0, 1)); Assert.Throws(() => responseStream.WriteByte(1)); // Invalid arguments diff --git a/src/System.Net.Http/tests/FunctionalTests/ReadOnlyMemoryContentTest.cs b/src/System.Net.Http/tests/FunctionalTests/ReadOnlyMemoryContentTest.cs index 98bb3426b583..f33fc4bb74f9 100644 --- a/src/System.Net.Http/tests/FunctionalTests/ReadOnlyMemoryContentTest.cs +++ b/src/System.Net.Http/tests/FunctionalTests/ReadOnlyMemoryContentTest.cs @@ -66,8 +66,8 @@ public async Task ReadAsStreamAsync_TrivialMembersHaveExpectedValuesAndBehavior( Assert.Throws(() => stream.WriteByte(0)); Assert.Throws(() => stream.Write(new byte[1], 0, 1)); Assert.Throws(() => stream.Write(new ReadOnlySpan(new byte[1]))); - await Assert.ThrowsAsync(() => stream.WriteAsync(new byte[1], 0, 1)); - await Assert.ThrowsAsync(() => stream.WriteAsync(new ReadOnlyMemory(new byte[1]))); + await Assert.ThrowsAsync(async () => await stream.WriteAsync(new byte[1], 0, 1)); + await Assert.ThrowsAsync(async () => await stream.WriteAsync(new ReadOnlyMemory(new byte[1]))); // nops stream.Flush(); diff --git a/src/System.Net.Security/src/System/Net/FixedSizeReader.cs b/src/System.Net.Security/src/System/Net/FixedSizeReader.cs index 94dc81225e9f..14ed8a79de44 100644 --- a/src/System.Net.Security/src/System/Net/FixedSizeReader.cs +++ b/src/System.Net.Security/src/System/Net/FixedSizeReader.cs @@ -53,7 +53,7 @@ public static async void ReadPacketAsync(Stream transport, AsyncProtocolRequest int remainingCount = request.Count, offset = request.Offset; do { - int bytes = await transport.ReadAsync(request.Buffer, offset, remainingCount, CancellationToken.None).ConfigureAwait(false); + int bytes = await transport.ReadAsync(new Memory(request.Buffer, offset, remainingCount), CancellationToken.None).ConfigureAwait(false); if (bytes == 0) { if (remainingCount != request.Count) diff --git a/src/System.Net.Security/src/System/Net/Security/SslStream.cs b/src/System.Net.Security/src/System/Net/Security/SslStream.cs index 62027859b7c2..71f9e0181bce 100644 --- a/src/System.Net.Security/src/System/Net/Security/SslStream.cs +++ b/src/System.Net.Security/src/System/Net/Security/SslStream.cs @@ -711,7 +711,7 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati return _sslState.SecureStream.WriteAsync(buffer, offset, count, cancellationToken); } - public override Task WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken) + public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken) { return _sslState.SecureStream.WriteAsync(source, cancellationToken); } diff --git a/src/System.Net.Security/src/System/Net/Security/SslStreamInternal.Adapters.cs b/src/System.Net.Security/src/System/Net/Security/SslStreamInternal.Adapters.cs index a94abda4d5b5..2dcc93fe3171 100644 --- a/src/System.Net.Security/src/System/Net/Security/SslStreamInternal.Adapters.cs +++ b/src/System.Net.Security/src/System/Net/Security/SslStreamInternal.Adapters.cs @@ -12,7 +12,7 @@ internal partial class SslStreamInternal private interface ISslWriteAdapter { Task LockAsync(); - Task WriteAsync(byte[] buffer, int offset, int count); + ValueTask WriteAsync(byte[] buffer, int offset, int count); } private interface ISslReadAdapter @@ -61,7 +61,7 @@ public SslWriteAsync(SslState sslState, CancellationToken cancellationToken) public Task LockAsync() => _sslState.CheckEnqueueWriteAsync(); - public Task WriteAsync(byte[] buffer, int offset, int count) => _sslState.InnerStream.WriteAsync(buffer, offset, count, _cancellationToken); + public ValueTask WriteAsync(byte[] buffer, int offset, int count) => _sslState.InnerStream.WriteAsync(new ReadOnlyMemory(buffer, offset, count), _cancellationToken); } private readonly struct SslWriteSync : ISslWriteAdapter @@ -76,10 +76,10 @@ public Task LockAsync() return Task.CompletedTask; } - public Task WriteAsync(byte[] buffer, int offset, int count) + public ValueTask WriteAsync(byte[] buffer, int offset, int count) { _sslState.InnerStream.Write(buffer, offset, count); - return Task.CompletedTask; + return default; } } } diff --git a/src/System.Net.Security/src/System/Net/Security/SslStreamInternal.cs b/src/System.Net.Security/src/System/Net/Security/SslStreamInternal.cs index f2981ef0d3a5..a2581fa0b6d0 100644 --- a/src/System.Net.Security/src/System/Net/Security/SslStreamInternal.cs +++ b/src/System.Net.Security/src/System/Net/Security/SslStreamInternal.cs @@ -167,7 +167,7 @@ internal IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCall internal void EndWrite(IAsyncResult asyncResult) => TaskToApm.End(asyncResult); - internal Task WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken) + internal ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken) { SslWriteAsync writeAdapter = new SslWriteAsync(_sslState, cancellationToken); return WriteAsyncInternal(writeAdapter, buffer); @@ -176,7 +176,7 @@ internal Task WriteAsync(ReadOnlyMemory buffer, CancellationToken cancella internal Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { ValidateParameters(buffer, offset, count); - return WriteAsync(new ReadOnlyMemory(buffer, offset, count), cancellationToken); + return WriteAsync(new ReadOnlyMemory(buffer, offset, count), cancellationToken).AsTask(); } private void ResetReadBuffer() @@ -338,7 +338,7 @@ private async ValueTask ReadAsyncInternal(TReadAdapter adapte } } - private Task WriteAsyncInternal(TWriteAdapter writeAdapter, ReadOnlyMemory buffer) + private ValueTask WriteAsyncInternal(TWriteAdapter writeAdapter, ReadOnlyMemory buffer) where TWriteAdapter : struct, ISslWriteAdapter { _sslState.CheckThrow(authSuccessCheck: true, shutdownCheck: true); @@ -346,7 +346,7 @@ private Task WriteAsyncInternal(TWriteAdapter writeAdapter, ReadO if (buffer.Length == 0 && !SslStreamPal.CanEncryptEmptyMessage) { // If it's an empty message and the PAL doesn't support that, we're done. - return Task.CompletedTask; + return default; } if (Interlocked.Exchange(ref _nestedWrite, 1) == 1) @@ -354,18 +354,18 @@ private Task WriteAsyncInternal(TWriteAdapter writeAdapter, ReadO throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, nameof(WriteAsync), "write")); } - Task t = buffer.Length < _sslState.MaxDataSize ? + ValueTask t = buffer.Length < _sslState.MaxDataSize ? WriteSingleChunk(writeAdapter, buffer) : - WriteAsyncChunked(writeAdapter, buffer); + new ValueTask(WriteAsyncChunked(writeAdapter, buffer)); if (t.IsCompletedSuccessfully) { _nestedWrite = 0; return t; } - return ExitWriteAsync(t); + return new ValueTask(ExitWriteAsync(t)); - async Task ExitWriteAsync(Task task) + async Task ExitWriteAsync(ValueTask task) { try { @@ -389,7 +389,7 @@ async Task ExitWriteAsync(Task task) } } - private Task WriteSingleChunk(TWriteAdapter writeAdapter, ReadOnlyMemory buffer) + private ValueTask WriteSingleChunk(TWriteAdapter writeAdapter, ReadOnlyMemory buffer) where TWriteAdapter : struct, ISslWriteAdapter { // Request a write IO slot. @@ -397,7 +397,7 @@ private Task WriteSingleChunk(TWriteAdapter writeAdapter, ReadOnl if (!ioSlot.IsCompletedSuccessfully) { // Operation is async and has been queued, return. - return WaitForWriteIOSlot(writeAdapter, ioSlot, buffer); + return new ValueTask(WaitForWriteIOSlot(writeAdapter, ioSlot, buffer)); } byte[] rentedBuffer = ArrayPool.Shared.Rent(buffer.Length + FrameOverhead); @@ -410,10 +410,10 @@ private Task WriteSingleChunk(TWriteAdapter writeAdapter, ReadOnl // Re-handshake status is not supported. ArrayPool.Shared.Return(rentedBuffer); ProtocolToken message = new ProtocolToken(null, status); - return Task.FromException(new IOException(SR.net_io_encrypt, message.GetException())); + return new ValueTask(Task.FromException(new IOException(SR.net_io_encrypt, message.GetException()))); } - Task t = writeAdapter.WriteAsync(outBuffer, 0, encryptedBytes); + ValueTask t = writeAdapter.WriteAsync(outBuffer, 0, encryptedBytes); if (t.IsCompletedSuccessfully) { ArrayPool.Shared.Return(rentedBuffer); @@ -422,7 +422,7 @@ private Task WriteSingleChunk(TWriteAdapter writeAdapter, ReadOnl } else { - return CompleteAsync(t, rentedBuffer); + return new ValueTask(CompleteAsync(t, rentedBuffer)); } async Task WaitForWriteIOSlot(TWriteAdapter wAdapter, Task lockTask, ReadOnlyMemory buff) @@ -431,7 +431,7 @@ async Task WaitForWriteIOSlot(TWriteAdapter wAdapter, Task lockTask, ReadOnlyMem await WriteSingleChunk(wAdapter, buff).ConfigureAwait(false); } - async Task CompleteAsync(Task writeTask, byte[] bufferToReturn) + async Task CompleteAsync(ValueTask writeTask, byte[] bufferToReturn) { try { @@ -471,7 +471,7 @@ private ValueTask FillBufferAsync(TReadAdapter adapter, int m ValueTask t = adapter.ReadAsync(_internalBuffer, _internalBufferCount, _internalBuffer.Length - _internalBufferCount); if (!t.IsCompletedSuccessfully) { - return new ValueTask(InternalFillBufferAsync(adapter, t.AsTask(), minSize, initialCount)); + return new ValueTask(InternalFillBufferAsync(adapter, t, minSize, initialCount)); } int bytes = t.Result; if (bytes == 0) @@ -490,7 +490,7 @@ private ValueTask FillBufferAsync(TReadAdapter adapter, int m return new ValueTask(minSize); - async Task InternalFillBufferAsync(TReadAdapter adap, Task task, int min, int initial) + async Task InternalFillBufferAsync(TReadAdapter adap, ValueTask task, int min, int initial) { while (true) { @@ -511,7 +511,7 @@ async Task InternalFillBufferAsync(TReadAdapter adap, Task task, int m return min; } - task = adap.ReadAsync(_internalBuffer, _internalBufferCount, _internalBuffer.Length - _internalBufferCount).AsTask(); + task = adap.ReadAsync(_internalBuffer, _internalBufferCount, _internalBuffer.Length - _internalBufferCount); } } } diff --git a/src/System.Net.Security/tests/UnitTests/Fakes/FakeAuthenticatedStream.cs b/src/System.Net.Security/tests/UnitTests/Fakes/FakeAuthenticatedStream.cs index be2d2483d5b3..1a549fea08e9 100644 --- a/src/System.Net.Security/tests/UnitTests/Fakes/FakeAuthenticatedStream.cs +++ b/src/System.Net.Security/tests/UnitTests/Fakes/FakeAuthenticatedStream.cs @@ -33,7 +33,7 @@ protected override void Dispose(bool disposing) public abstract bool IsSigned { get; } public abstract bool IsServer { get; } - public new abstract Task WriteAsync(ReadOnlyMemory buffer, CancellationToken token); + public new abstract ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken token); public new abstract ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken); } } diff --git a/src/System.Net.Security/tests/UnitTests/Fakes/FakeSslState.cs b/src/System.Net.Security/tests/UnitTests/Fakes/FakeSslState.cs index 2ee788a9ea2c..c32f1d0f2e64 100644 --- a/src/System.Net.Security/tests/UnitTests/Fakes/FakeSslState.cs +++ b/src/System.Net.Security/tests/UnitTests/Fakes/FakeSslState.cs @@ -279,7 +279,7 @@ public override void Write(byte[] buffer, int offset, int count) throw new NotImplementedException(); } - public new Task WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken) + public new ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken) { throw new NotImplementedException(); } diff --git a/src/System.Net.Sockets/src/System.Net.Sockets.csproj b/src/System.Net.Sockets/src/System.Net.Sockets.csproj index e29864b9aab1..362296fe81b5 100644 --- a/src/System.Net.Sockets/src/System.Net.Sockets.csproj +++ b/src/System.Net.Sockets/src/System.Net.Sockets.csproj @@ -410,6 +410,7 @@ + diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs b/src/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs index 4cd1b3ae2184..3eb9f2626c29 100644 --- a/src/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs +++ b/src/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs @@ -204,7 +204,7 @@ public virtual bool DataAvailable #endif if (_cleanedUp) { - throw new ObjectDisposedException(this.GetType().FullName); + throw new ObjectDisposedException(GetType().FullName); } // Ask the socket how many bytes are available. If it's @@ -269,7 +269,7 @@ public override int Read(byte[] buffer, int offset, int size) bool canRead = CanRead; // Prevent race with Dispose. if (_cleanedUp) { - throw new ObjectDisposedException(this.GetType().FullName); + throw new ObjectDisposedException(GetType().FullName); } if (!canRead) { @@ -281,11 +281,11 @@ public override int Read(byte[] buffer, int offset, int size) { throw new ArgumentNullException(nameof(buffer)); } - if (offset < 0 || offset > buffer.Length) + if ((uint)offset > buffer.Length) { throw new ArgumentOutOfRangeException(nameof(offset)); } - if (size < 0 || size > buffer.Length - offset) + if ((uint)size > buffer.Length - offset) { throw new ArgumentOutOfRangeException(nameof(size)); } @@ -358,7 +358,7 @@ public override void Write(byte[] buffer, int offset, int size) bool canWrite = CanWrite; // Prevent race with Dispose. if (_cleanedUp) { - throw new ObjectDisposedException(this.GetType().FullName); + throw new ObjectDisposedException(GetType().FullName); } if (!canWrite) { @@ -370,11 +370,11 @@ public override void Write(byte[] buffer, int offset, int size) { throw new ArgumentNullException(nameof(buffer)); } - if (offset < 0 || offset > buffer.Length) + if ((uint)offset > buffer.Length) { throw new ArgumentOutOfRangeException(nameof(offset)); } - if (size < 0 || size > buffer.Length - offset) + if ((uint)size > buffer.Length - offset) { throw new ArgumentOutOfRangeException(nameof(size)); } @@ -504,7 +504,7 @@ public override IAsyncResult BeginRead(byte[] buffer, int offset, int size, Asyn bool canRead = CanRead; // Prevent race with Dispose. if (_cleanedUp) { - throw new ObjectDisposedException(this.GetType().FullName); + throw new ObjectDisposedException(GetType().FullName); } if (!canRead) { @@ -516,11 +516,11 @@ public override IAsyncResult BeginRead(byte[] buffer, int offset, int size, Asyn { throw new ArgumentNullException(nameof(buffer)); } - if (offset < 0 || offset > buffer.Length) + if ((uint)offset > buffer.Length) { throw new ArgumentOutOfRangeException(nameof(offset)); } - if (size < 0 || size > buffer.Length - offset) + if ((uint)size > buffer.Length - offset) { throw new ArgumentOutOfRangeException(nameof(size)); } @@ -562,7 +562,7 @@ public override int EndRead(IAsyncResult asyncResult) #endif if (_cleanedUp) { - throw new ObjectDisposedException(this.GetType().FullName); + throw new ObjectDisposedException(GetType().FullName); } // Validate input parameters. @@ -609,7 +609,7 @@ public override IAsyncResult BeginWrite(byte[] buffer, int offset, int size, Asy bool canWrite = CanWrite; // Prevent race with Dispose. if (_cleanedUp) { - throw new ObjectDisposedException(this.GetType().FullName); + throw new ObjectDisposedException(GetType().FullName); } if (!canWrite) { @@ -621,11 +621,11 @@ public override IAsyncResult BeginWrite(byte[] buffer, int offset, int size, Asy { throw new ArgumentNullException(nameof(buffer)); } - if (offset < 0 || offset > buffer.Length) + if ((uint)offset > buffer.Length) { throw new ArgumentOutOfRangeException(nameof(offset)); } - if (size < 0 || size > buffer.Length - offset) + if ((uint)size > buffer.Length - offset) { throw new ArgumentOutOfRangeException(nameof(size)); } @@ -664,7 +664,7 @@ public override void EndWrite(IAsyncResult asyncResult) #endif if (_cleanedUp) { - throw new ObjectDisposedException(this.GetType().FullName); + throw new ObjectDisposedException(GetType().FullName); } // Validate input parameters. @@ -708,7 +708,7 @@ public override Task ReadAsync(byte[] buffer, int offset, int size, Cancell bool canRead = CanRead; // Prevent race with Dispose. if (_cleanedUp) { - throw new ObjectDisposedException(this.GetType().FullName); + throw new ObjectDisposedException(GetType().FullName); } if (!canRead) { @@ -720,26 +720,22 @@ public override Task ReadAsync(byte[] buffer, int offset, int size, Cancell { throw new ArgumentNullException(nameof(buffer)); } - if (offset < 0 || offset > buffer.Length) + if ((uint)offset > buffer.Length) { throw new ArgumentOutOfRangeException(nameof(offset)); } - if (size < 0 || size > buffer.Length - offset) + if ((uint)size > buffer.Length - offset) { throw new ArgumentOutOfRangeException(nameof(size)); } - if (cancellationToken.IsCancellationRequested) - { - return Task.FromCanceled(cancellationToken); - } - try { return _streamSocket.ReceiveAsync( - new ArraySegment(buffer, offset, size), + new Memory(buffer, offset, size), SocketFlags.None, - fromNetworkStream: true); + fromNetworkStream: true, + cancellationToken).AsTask(); } catch (Exception exception) when (!(exception is OutOfMemoryException)) { @@ -754,7 +750,7 @@ public override ValueTask ReadAsync(Memory destination, CancellationT bool canRead = CanRead; // Prevent race with Dispose. if (_cleanedUp) { - throw new ObjectDisposedException(this.GetType().FullName); + throw new ObjectDisposedException(GetType().FullName); } if (!canRead) { @@ -797,7 +793,7 @@ public override Task WriteAsync(byte[] buffer, int offset, int size, Cancellatio bool canWrite = CanWrite; // Prevent race with Dispose. if (_cleanedUp) { - throw new ObjectDisposedException(this.GetType().FullName); + throw new ObjectDisposedException(GetType().FullName); } if (!canWrite) { @@ -809,26 +805,21 @@ public override Task WriteAsync(byte[] buffer, int offset, int size, Cancellatio { throw new ArgumentNullException(nameof(buffer)); } - if (offset < 0 || offset > buffer.Length) + if ((uint)offset > buffer.Length) { throw new ArgumentOutOfRangeException(nameof(offset)); } - if (size < 0 || size > buffer.Length - offset) + if ((uint)size > buffer.Length - offset) { throw new ArgumentOutOfRangeException(nameof(size)); } - if (cancellationToken.IsCancellationRequested) - { - return Task.FromCanceled(cancellationToken); - } - try { - return _streamSocket.SendAsync( - new ArraySegment(buffer, offset, size), + return _streamSocket.SendAsyncForNetworkStream( + new ReadOnlyMemory(buffer, offset, size), SocketFlags.None, - fromNetworkStream: true); + cancellationToken).AsTask(); } catch (Exception exception) when (!(exception is OutOfMemoryException)) { @@ -838,12 +829,12 @@ public override Task WriteAsync(byte[] buffer, int offset, int size, Cancellatio } } - public override Task WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken) + public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken) { bool canWrite = CanWrite; // Prevent race with Dispose. if (_cleanedUp) { - throw new ObjectDisposedException(this.GetType().FullName); + throw new ObjectDisposedException(GetType().FullName); } if (!canWrite) { @@ -852,15 +843,10 @@ public override Task WriteAsync(ReadOnlyMemory source, CancellationToken c try { - ValueTask t = _streamSocket.SendAsync( + return _streamSocket.SendAsyncForNetworkStream( source, SocketFlags.None, - fromNetworkStream: true, cancellationToken: cancellationToken); - - return t.IsCompletedSuccessfully ? - Task.CompletedTask : - t.AsTask(); } catch (Exception exception) when (!(exception is OutOfMemoryException)) { @@ -870,50 +856,6 @@ public override Task WriteAsync(ReadOnlyMemory source, CancellationToken c } } - public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) - { - // Validate arguments as would the base CopyToAsync - StreamHelpers.ValidateCopyToArgs(this, destination, bufferSize); - - // And bail early if cancellation has already been requested - if (cancellationToken.IsCancellationRequested) - { - return Task.FromCanceled(cancellationToken); - } - - // Do the copy. We get a copy buffer from the shared pool, and we pass both it and the - // socket into the copy as part of the event args so as to avoid additional fields in - // the async method's state machine. - return CopyToAsyncCore( - destination, - new AwaitableSocketAsyncEventArgs(_streamSocket, ArrayPool.Shared.Rent(bufferSize)), - cancellationToken); - } - - private static async Task CopyToAsyncCore(Stream destination, AwaitableSocketAsyncEventArgs ea, CancellationToken cancellationToken) - { - try - { - while (true) - { - cancellationToken.ThrowIfCancellationRequested(); - - int bytesRead = await ea.ReceiveAsync(); - if (bytesRead == 0) - { - break; - } - - await destination.WriteAsync(ea.Buffer, 0, bytesRead, cancellationToken).ConfigureAwait(false); - } - } - finally - { - ArrayPool.Shared.Return(ea.Buffer, clearArray: true); - ea.Dispose(); - } - } - // Flushes data from the stream. This is meaningless for us, so it does nothing. public override void Flush() { @@ -959,115 +901,5 @@ internal void SetSocketTimeoutOption(SocketShutdown mode, int timeout, bool sile } } } - - /// A SocketAsyncEventArgs that can be awaited to get the result of an operation. - internal sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, ICriticalNotifyCompletion - { - /// Sentinel object used to indicate that the operation has completed prior to OnCompleted being called. - private static readonly Action s_completedSentinel = () => { }; - /// - /// null if the operation has not completed, if it has, and another object - /// if OnCompleted was called before the operation could complete, in which case it's the delegate to invoke - /// when the operation does complete. - /// - private Action _continuation; - - /// Initializes the event args. - /// The associated socket. - /// The buffer to use for all operations. - public AwaitableSocketAsyncEventArgs(Socket socket, byte[] buffer) - { - Debug.Assert(socket != null); - Debug.Assert(buffer != null && buffer.Length > 0); - - // Store the socket into the base's UserToken. This avoids the need for an extra field, at the expense - // of an object=>Socket cast when we need to access it, which is only once per operation. - UserToken = socket; - - // Store the buffer for use by all operations with this instance. - SetBuffer(buffer, 0, buffer.Length); - - // Hook up the completed event. - Completed += delegate - { - // When the operation completes, see if OnCompleted was already called to hook up a continuation. - // If it was, invoke the continuation. - Action c = _continuation; - if (c != null) - { - c(); - } - else - { - // We may be racing with OnCompleted, so check with synchronization, trying to swap in our - // completion sentinel. If we lose the race and OnCompleted did hook up a continuation, - // invoke it. Otherwise, there's nothing more to be done. - Interlocked.CompareExchange(ref _continuation, s_completedSentinel, null)?.Invoke(); - } - }; - } - - /// Initiates a receive operation on the associated socket. - /// This instance. - public AwaitableSocketAsyncEventArgs ReceiveAsync() - { - if (!Socket.ReceiveAsync(this)) - { - _continuation = s_completedSentinel; - } - return this; - } - - /// Gets this instance. - public AwaitableSocketAsyncEventArgs GetAwaiter() => this; - - /// Gets whether the operation has already completed. - /// - /// This is not a generically usable IsCompleted operation that suggests the whole operation has completed. - /// Rather, it's specifically used as part of the await pattern, and is only usable to determine whether the - /// operation has completed by the time the instance is awaited. - /// - public bool IsCompleted => _continuation != null; - - /// Same as - public void UnsafeOnCompleted(Action continuation) => OnCompleted(continuation); - - /// Queues the provided continuation to be executed once the operation has completed. - public void OnCompleted(Action continuation) - { - if (ReferenceEquals(_continuation, s_completedSentinel) || - ReferenceEquals(Interlocked.CompareExchange(ref _continuation, continuation, null), s_completedSentinel)) - { - Task.Run(continuation); - } - } - - /// Gets the result of the completion operation. - /// Number of bytes transferred. - /// - /// Unlike Task's awaiter's GetResult, this does not block until the operation completes: it must only - /// be used once the operation has completed. This is handled implicitly by await. - /// - public int GetResult() - { - _continuation = null; - if (SocketError != SocketError.Success) - { - ThrowIOSocketException(); - } - return BytesTransferred; - } - - /// Gets the associated socket. - internal Socket Socket => (Socket)UserToken; // stored in the base's UserToken to avoid an extra field in the object - - /// Throws an IOException wrapping a SocketException using the current . - [MethodImpl(MethodImplOptions.NoInlining)] - private void ThrowIOSocketException() - { - var se = new SocketException((int)SocketError); - throw new IOException(SR.Format(SR.net_io_readfailure, se.Message), se); - } - } } } diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs b/src/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs index fc796379ec02..12c77e21cac4 100644 --- a/src/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs +++ b/src/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs @@ -46,12 +46,12 @@ public partial class Socket private static readonly Task s_zeroTask = Task.FromResult(0); /// Cached event args used with Task-based async operations. - private CachedTaskEventArgs _cachedTaskEventArgs; + private CachedEventArgs _cachedTaskEventArgs; internal Task AcceptAsync(Socket acceptSocket) { // Get any cached SocketAsyncEventArg we may have. - TaskSocketAsyncEventArgs saea = Interlocked.Exchange(ref LazyInitializer.EnsureInitialized(ref _cachedTaskEventArgs).Accept, s_rentedSocketSentinel); + TaskSocketAsyncEventArgs saea = Interlocked.Exchange(ref LazyInitializer.EnsureInitialized(ref _cachedTaskEventArgs).TaskAccept, s_rentedSocketSentinel); if (saea == s_rentedSocketSentinel) { // An instance was once created (or is currently being created elsewhere), but some other @@ -179,24 +179,13 @@ internal Task ConnectAsync(string host, int port) internal Task ReceiveAsync(ArraySegment buffer, SocketFlags socketFlags, bool fromNetworkStream) { - // Validate the arguments. ValidateBuffer(buffer); - - Int32TaskSocketAsyncEventArgs saea = RentSocketAsyncEventArgs(isReceive: true); - if (saea != null) - { - // We got a cached instance. Configure the buffer and initate the operation. - ConfigureBuffer(saea, buffer, socketFlags, wrapExceptionsInIOExceptions: fromNetworkStream); - return GetTaskForSendReceive(ReceiveAsync(saea), saea, fromNetworkStream, isReceive: true); - } - else - { - // We couldn't get a cached instance, due to a concurrent receive operation on the socket. - // Fall back to wrapping APM. - return ReceiveAsyncApm(buffer, socketFlags); - } + return ReceiveAsync((Memory)buffer, socketFlags, fromNetworkStream, default).AsTask(); } + // TODO https://github.com/dotnet/corefx/issues/24430: + // Fully plumb cancellation down into socket operations. + internal ValueTask ReceiveAsync(Memory buffer, SocketFlags socketFlags, bool fromNetworkStream, CancellationToken cancellationToken) { if (cancellationToken.IsCancellationRequested) @@ -204,15 +193,14 @@ internal ValueTask ReceiveAsync(Memory buffer, SocketFlags socketFlag return new ValueTask(Task.FromCanceled(cancellationToken)); } - // TODO https://github.com/dotnet/corefx/issues/24430: - // Fully plumb cancellation down into socket operations. - - Int32TaskSocketAsyncEventArgs saea = RentSocketAsyncEventArgs(isReceive: true); - if (saea != null) + AwaitableSocketAsyncEventArgs saea = LazyInitializer.EnsureInitialized(ref LazyInitializer.EnsureInitialized(ref _cachedTaskEventArgs).ValueTaskReceive); + if (saea.Reserve()) { - // We got a cached instance. Configure the buffer and initate the operation. - ConfigureBuffer(saea, buffer, socketFlags, wrapExceptionsInIOExceptions: fromNetworkStream); - return GetValueTaskForSendReceive(ReceiveAsync(saea), saea, fromNetworkStream, isReceive: true); + if (saea.BufferList != null) saea.BufferList = null; + saea.SetBuffer(buffer); + saea.SocketFlags = socketFlags; + saea.WrapExceptionsInIOExceptions = fromNetworkStream; + return saea.ReceiveAsync(this); } else { @@ -341,48 +329,57 @@ internal Task ReceiveMessageFromAsync(ArraySegme return tcs.Task; } - internal Task SendAsync(ArraySegment buffer, SocketFlags socketFlags, bool fromNetworkStream) + internal Task SendAsync(ArraySegment buffer, SocketFlags socketFlags) { - // Validate the arguments. ValidateBuffer(buffer); - - Int32TaskSocketAsyncEventArgs saea = RentSocketAsyncEventArgs(isReceive: false); - if (saea != null) + return SendAsync((ReadOnlyMemory)buffer, socketFlags, default).AsTask(); + } + + internal ValueTask SendAsync(ReadOnlyMemory buffer, SocketFlags socketFlags, CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) { - // We got a cached instance. Configure the buffer and initate the operation. - ConfigureBuffer(saea, buffer, socketFlags, wrapExceptionsInIOExceptions: fromNetworkStream); - return GetTaskForSendReceive(SendAsync(saea), saea, fromNetworkStream, isReceive: false); + return new ValueTask(Task.FromCanceled(cancellationToken)); + } + + AwaitableSocketAsyncEventArgs saea = LazyInitializer.EnsureInitialized(ref LazyInitializer.EnsureInitialized(ref _cachedTaskEventArgs).ValueTaskSend); + if (saea.Reserve()) + { + if (saea.BufferList != null) saea.BufferList = null; + saea.SetBuffer(MemoryMarshal.AsMemory(buffer)); + saea.SocketFlags = socketFlags; + saea.WrapExceptionsInIOExceptions = false; + return saea.SendAsync(this); } else { // We couldn't get a cached instance, due to a concurrent send operation on the socket. // Fall back to wrapping APM. - return SendAsyncApm(buffer, socketFlags); + return new ValueTask(SendAsyncApm(buffer, socketFlags)); } } - internal ValueTask SendAsync(ReadOnlyMemory buffer, SocketFlags socketFlags, bool fromNetworkStream, CancellationToken cancellationToken) + internal ValueTask SendAsyncForNetworkStream(ReadOnlyMemory buffer, SocketFlags socketFlags, CancellationToken cancellationToken) { if (cancellationToken.IsCancellationRequested) { - return new ValueTask(Task.FromCanceled(cancellationToken)); + return new ValueTask(Task.FromCanceled(cancellationToken)); } - // TODO https://github.com/dotnet/corefx/issues/24430: - // Fully plumb cancellation down into socket operations. - - Int32TaskSocketAsyncEventArgs saea = RentSocketAsyncEventArgs(isReceive: false); - if (saea != null) + AwaitableSocketAsyncEventArgs saea = LazyInitializer.EnsureInitialized(ref LazyInitializer.EnsureInitialized(ref _cachedTaskEventArgs).ValueTaskSend); + if (saea.Reserve()) { - // We got a cached instance. Configure the buffer and initate the operation. - ConfigureBuffer(saea, MemoryMarshal.AsMemory(buffer), socketFlags, wrapExceptionsInIOExceptions: fromNetworkStream); - return GetValueTaskForSendReceive(SendAsync(saea), saea, fromNetworkStream, isReceive: false); + if (saea.BufferList != null) saea.BufferList = null; + saea.SetBuffer(MemoryMarshal.AsMemory(buffer)); + saea.SocketFlags = socketFlags; + saea.WrapExceptionsInIOExceptions = true; + return saea.SendAsyncForNetworkStream(this); } else { // We couldn't get a cached instance, due to a concurrent send operation on the socket. // Fall back to wrapping APM. - return new ValueTask(SendAsyncApm(buffer, socketFlags)); + return new ValueTask(SendAsyncApm(buffer, socketFlags)); } } @@ -502,19 +499,6 @@ private static void ValidateBuffersList(IList> buffers) } } - private static void ConfigureBuffer( - Int32TaskSocketAsyncEventArgs saea, Memory buffer, SocketFlags socketFlags, bool wrapExceptionsInIOExceptions) - { - // Configure the buffer. We don't clear the buffers when returning the SAEA to the pool, - // so as to minimize overhead if the same buffer is used for subsequent operations (which is likely). - // But SAEA doesn't support having both a buffer and a buffer list configured, so clear out a buffer list - // if there is one before we set the desired buffer. - if (saea.BufferList != null) saea.BufferList = null; - saea.SetBuffer(buffer); - saea.SocketFlags = socketFlags; - saea._wrapExceptionsInIOExceptions = wrapExceptionsInIOExceptions; - } - private static void ConfigureBufferList( Int32TaskSocketAsyncEventArgs saea, IList> buffers, SocketFlags socketFlags) { @@ -572,16 +556,10 @@ private Task GetTaskForSendReceive( } else { - // Get any cached, successfully-completed cached task that may exist on this SAEA. - Task lastTask = saea._successfullyCompletedTask; - Debug.Assert(lastTask == null || lastTask.IsCompletedSuccessfully); - // If there is a task and if it has the desired result, simply reuse it. // Otherwise, create a new one for this result value, and in addition to returning it, // also store it into the SAEA for potential future reuse. - t = lastTask != null && lastTask.Result == bytesTransferred ? - lastTask : - (saea._successfullyCompletedTask = Task.FromResult(bytesTransferred)); + t = Task.FromResult(bytesTransferred); } } else @@ -596,48 +574,6 @@ private Task GetTaskForSendReceive( return t; } - /// Gets a value task to represent the operation. - /// true if the operation completes asynchronously; false if it completed synchronously. - /// The event args instance used with the operation. - /// - /// true if the request is coming from NetworkStream, which has special semantics for - /// exceptions and cached tasks; otherwise, false. - /// - /// true if this is a receive; false if this is a send. - private ValueTask GetValueTaskForSendReceive( - bool pending, Int32TaskSocketAsyncEventArgs saea, - bool fromNetworkStream, bool isReceive) - { - ValueTask t; - - if (pending) - { - // The operation is completing asynchronously (it may have already completed). - // Get the task for the operation, with appropriate synchronization to coordinate - // with the async callback that'll be completing the task. - bool responsibleForReturningToPool; - t = new ValueTask(saea.GetCompletionResponsibility(out responsibleForReturningToPool).Task); - if (responsibleForReturningToPool) - { - // We're responsible for returning it only if the callback has already been invoked - // and gotten what it needs from the SAEA; otherwise, the callback will return it. - ReturnSocketAsyncEventArgs(saea, isReceive); - } - } - else - { - // The operation completed synchronously. Return a ValueTask for it. - t = saea.SocketError == SocketError.Success ? - new ValueTask(saea.BytesTransferred) : - new ValueTask(Task.FromException(GetException(saea.SocketError, wrapExceptionsInIOExceptions: fromNetworkStream))); - - // There won't be a callback, and we're done with the SAEA, so return it to the pool. - ReturnSocketAsyncEventArgs(saea, isReceive); - } - - return t; - } - /// Completes the SocketAsyncEventArg's Task with the result of the send or receive, and returns it to the specified pool. private static void CompleteAccept(Socket s, TaskSocketAsyncEventArgs saea) { @@ -709,10 +645,10 @@ private static Exception GetException(SocketError error, bool wrapExceptionsInIO private Int32TaskSocketAsyncEventArgs RentSocketAsyncEventArgs(bool isReceive) { // Get any cached SocketAsyncEventArg we may have. - CachedTaskEventArgs cea = LazyInitializer.EnsureInitialized(ref _cachedTaskEventArgs); + CachedEventArgs cea = LazyInitializer.EnsureInitialized(ref _cachedTaskEventArgs); Int32TaskSocketAsyncEventArgs saea = isReceive ? - Interlocked.Exchange(ref cea.Receive, s_rentedInt32Sentinel) : - Interlocked.Exchange(ref cea.Send, s_rentedInt32Sentinel); + Interlocked.Exchange(ref cea.TaskReceive, s_rentedInt32Sentinel) : + Interlocked.Exchange(ref cea.TaskSend, s_rentedInt32Sentinel); if (saea == s_rentedInt32Sentinel) { @@ -752,13 +688,13 @@ private void ReturnSocketAsyncEventArgs(Int32TaskSocketAsyncEventArgs saea, bool // never null or another instance. if (isReceive) { - Debug.Assert(_cachedTaskEventArgs.Receive == s_rentedInt32Sentinel); - Volatile.Write(ref _cachedTaskEventArgs.Receive, saea); + Debug.Assert(_cachedTaskEventArgs.TaskReceive == s_rentedInt32Sentinel); + Volatile.Write(ref _cachedTaskEventArgs.TaskReceive, saea); } else { - Debug.Assert(_cachedTaskEventArgs.Send == s_rentedInt32Sentinel); - Volatile.Write(ref _cachedTaskEventArgs.Send, saea); + Debug.Assert(_cachedTaskEventArgs.TaskSend == s_rentedInt32Sentinel); + Volatile.Write(ref _cachedTaskEventArgs.TaskSend, saea); } } @@ -779,19 +715,21 @@ private void ReturnSocketAsyncEventArgs(TaskSocketAsyncEventArgs saea) // Write this instance back as a cached instance. It should only ever be overwriting the sentinel, // never null or another instance. - Debug.Assert(_cachedTaskEventArgs.Accept == s_rentedSocketSentinel); - Volatile.Write(ref _cachedTaskEventArgs.Accept, saea); + Debug.Assert(_cachedTaskEventArgs.TaskAccept == s_rentedSocketSentinel); + Volatile.Write(ref _cachedTaskEventArgs.TaskAccept, saea); } /// Dispose of any cached instances. private void DisposeCachedTaskSocketAsyncEventArgs() { - CachedTaskEventArgs cea = _cachedTaskEventArgs; + CachedEventArgs cea = _cachedTaskEventArgs; if (cea != null) { - Interlocked.Exchange(ref cea.Accept, s_rentedSocketSentinel)?.Dispose(); - Interlocked.Exchange(ref cea.Receive, s_rentedInt32Sentinel)?.Dispose(); - Interlocked.Exchange(ref cea.Send, s_rentedInt32Sentinel)?.Dispose(); + Interlocked.Exchange(ref cea.TaskAccept, s_rentedSocketSentinel)?.Dispose(); + Interlocked.Exchange(ref cea.TaskReceive, s_rentedInt32Sentinel)?.Dispose(); + Interlocked.Exchange(ref cea.TaskSend, s_rentedInt32Sentinel)?.Dispose(); + Interlocked.Exchange(ref cea.ValueTaskReceive, AwaitableSocketAsyncEventArgs.Reserved)?.Dispose(); + Interlocked.Exchange(ref cea.ValueTaskSend, AwaitableSocketAsyncEventArgs.Reserved)?.Dispose(); } } @@ -810,14 +748,18 @@ public StateTaskCompletionSource(object baseState) : base(baseState) { } } /// Cached event args used with Task-based async operations. - private sealed class CachedTaskEventArgs + private sealed class CachedEventArgs { /// Cached instance for accept operations. - public TaskSocketAsyncEventArgs Accept; - /// Cached instance for receive operations. - public Int32TaskSocketAsyncEventArgs Receive; - /// Cached instance for send operations. - public Int32TaskSocketAsyncEventArgs Send; + public TaskSocketAsyncEventArgs TaskAccept; + /// Cached instance for receive operations that return . + public Int32TaskSocketAsyncEventArgs TaskReceive; + /// Cached instance for send operations that return . + public Int32TaskSocketAsyncEventArgs TaskSend; + /// Cached instance for receive operations that return . + public AwaitableSocketAsyncEventArgs ValueTaskReceive; + /// Cached instance for send operations that return . + public AwaitableSocketAsyncEventArgs ValueTaskSend; } /// A SocketAsyncEventArgs with an associated async method builder. @@ -854,10 +796,246 @@ internal AsyncTaskMethodBuilder GetCompletionResponsibility(out bool re /// A SocketAsyncEventArgs with an associated async method builder. private sealed class Int32TaskSocketAsyncEventArgs : TaskSocketAsyncEventArgs { - /// A cached, successfully completed task. - internal Task _successfullyCompletedTask; /// Whether exceptions that emerge should be wrapped in IOExceptions. internal bool _wrapExceptionsInIOExceptions; } + + /// A SocketAsyncEventArgs that can be awaited to get the result of an operation. + internal sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, IValueTaskSource, IValueTaskSource + { + internal static readonly AwaitableSocketAsyncEventArgs Reserved = new AwaitableSocketAsyncEventArgs() { _continuation = null }; + /// Sentinel object used to indicate that the operation has completed prior to OnCompleted being called. + private static readonly Action s_completedSentinel = state => throw new Exception(nameof(s_completedSentinel)); + /// Sentinel object used to indicate that the instance is available for use. + private static readonly Action s_availableSentinel = state => throw new Exception(nameof(s_availableSentinel)); + /// Event handler for the Completed event. + private static readonly EventHandler s_completedHandler = (s, e) => + { + // When the operation completes, see if OnCompleted was already called to hook up a continuation. + // If it was, invoke the continuation. + AwaitableSocketAsyncEventArgs ea = (AwaitableSocketAsyncEventArgs)e; + Action c = ea._continuation; + if (c != null || (c = Interlocked.CompareExchange(ref ea._continuation, s_completedSentinel, null)) != null) + { + Debug.Assert(c != s_availableSentinel, "The delegate should not have been the available sentinel."); + Debug.Assert(c != s_completedSentinel, "The delegate should not have been the completed sentinel."); + object continuationState = ea.UserToken; + ea.UserToken = null; + ea._continuation = s_completedSentinel; // in case someone's polling IsCompleted + ea.InvokeContinuation(c, continuationState, forceAsync: false); + } + }; + /// + /// if the object is available for use, after GetResult has been called on a previous use. + /// null if the operation has not completed. + /// if it has completed. + /// Another delegate if OnCompleted was called before the operation could complete, in which case it's the delegate to invoke + /// when the operation does complete. + /// + private Action _continuation = s_availableSentinel; + private ExecutionContext _executionContext; + private object _scheduler; + + /// Initializes the event args. + /// The associated socket. + /// The buffer to use for all operations. + public AwaitableSocketAsyncEventArgs() => Completed += s_completedHandler; + + public bool WrapExceptionsInIOExceptions { get; set; } + + public bool Reserve() => Interlocked.CompareExchange(ref _continuation, null, s_availableSentinel) == s_availableSentinel; + + private void Release() => Volatile.Write(ref _continuation, s_availableSentinel); + + /// Initiates a receive operation on the associated socket. + /// This instance. + public ValueTask ReceiveAsync(Socket socket) + { + Debug.Assert(Volatile.Read(ref _continuation) == null, $"Expected null continuation to indicate reserved for use"); + + if (socket.ReceiveAsync(this)) + { + return new ValueTask(this); + } + + int bytesTransferred = BytesTransferred; + SocketError error = SocketError; + + Release(); + + return error == SocketError.Success ? + new ValueTask(bytesTransferred) : + new ValueTask(Task.FromException(CreateException(error))); + } + + /// Initiates a send operation on the associated socket. + /// This instance. + public ValueTask SendAsync(Socket socket) + { + Debug.Assert(Volatile.Read(ref _continuation) == null, $"Expected null continuation to indicate reserved for use"); + + if (socket.SendAsync(this)) + { + return new ValueTask(this); + } + + int bytesTransferred = BytesTransferred; + SocketError error = SocketError; + + Release(); + + return error == SocketError.Success ? + new ValueTask(bytesTransferred) : + new ValueTask(Task.FromException(CreateException(error))); + } + + public ValueTask SendAsyncForNetworkStream(Socket socket) + { + Debug.Assert(Volatile.Read(ref _continuation) == null, $"Expected null continuation to indicate reserved for use"); + + if (socket.SendAsync(this)) + { + return new ValueTask(this); + } + + SocketError error = SocketError; + + Release(); + + return error == SocketError.Success ? + default : + new ValueTask(Task.FromException(CreateException(error))); + } + + /// Gets the status of the operation. + public ValueTaskSourceStatus Status => + _continuation != s_completedSentinel ? ValueTaskSourceStatus.Pending : + base.SocketError == SocketError.Success ? ValueTaskSourceStatus.Succeeded : + ValueTaskSourceStatus.Faulted; + + /// Queues the provided continuation to be executed once the operation has completed. + public void OnCompleted(Action continuation, object state, ValueTaskSourceOnCompletedFlags flags) + { + if ((flags & ValueTaskSourceOnCompletedFlags.FlowExecutionContext) != 0) + { + _executionContext = ExecutionContext.Capture(); + } + + if ((flags & ValueTaskSourceOnCompletedFlags.UseSchedulingContext) != 0) + { + SynchronizationContext sc = SynchronizationContext.Current; + if (sc != null && sc.GetType() != typeof(SynchronizationContext)) + { + _scheduler = sc; + } + else + { + TaskScheduler ts = TaskScheduler.Current; + if (ts != TaskScheduler.Default) + { + _scheduler = ts; + } + } + } + + UserToken = state; // Use UserToken to carry the continuation state around + if (ReferenceEquals(Interlocked.CompareExchange(ref _continuation, continuation, null), s_completedSentinel)) + { + UserToken = null; + InvokeContinuation(continuation, state, forceAsync: true); + } + } + + private void InvokeContinuation(Action continuation, object state, bool forceAsync) + { + ExecutionContext ec = _executionContext; + if (ec == null) + { + InvokeContinuationCore(continuation, state, forceAsync); + } + else + { + _executionContext = null; + ExecutionContext.Run(ec, s => + { + var t = (Tuple, object, bool>)s; + t.Item1.InvokeContinuationCore(t.Item2, t.Item3, t.Item4); + }, Tuple.Create(this, continuation, state, forceAsync)); + } + } + + private void InvokeContinuationCore(Action continuation, object state, bool forceAsync) + { + object scheduler = _scheduler; + _scheduler = null; + + if (scheduler != null) + { + if (scheduler is SynchronizationContext sc) + { + sc.Post(s => + { + var t = (Tuple, object>)s; + t.Item1(t.Item2); + }, Tuple.Create(continuation, state)); + } + else + { + Task.Factory.StartNew(continuation, state, CancellationToken.None, TaskCreationOptions.DenyChildAttach, (TaskScheduler)scheduler); + } + } + else if (forceAsync) + { + // TODO #27464: Use QueueUserWorkItem when it has a compatible signature. + Task.Factory.StartNew(continuation, state, CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskScheduler.Default); + } + else + { + continuation(state); + } + } + + /// Gets the result of the completion operation. + /// Number of bytes transferred. + /// + /// Unlike Task's awaiter's GetResult, this does not block until the operation completes: it must only + /// be used once the operation has completed. This is handled implicitly by await. + /// + public int GetResult() + { + SocketError error = SocketError; + int bytes = BytesTransferred; + + Release(); + + if (error != SocketError.Success) + { + ThrowException(error); + } + return bytes; + } + + void IValueTaskSource.GetResult() + { + SocketError error = SocketError; + + Release(); + + if (error != SocketError.Success) + { + ThrowException(error); + } + } + + private void ThrowException(SocketError error) => throw CreateException(error); + + private Exception CreateException(SocketError error) + { + var se = new SocketException((int)error); + return WrapExceptionsInIOExceptions ? (Exception) + new IOException(SR.Format(SR.net_io_readfailure, se.Message), se) : + se; + } + } } } diff --git a/src/System.Net.Sockets/src/System/Net/Sockets/SocketTaskExtensions.cs b/src/System.Net.Sockets/src/System/Net/Sockets/SocketTaskExtensions.cs index 7407206222de..e32adfba412b 100644 --- a/src/System.Net.Sockets/src/System/Net/Sockets/SocketTaskExtensions.cs +++ b/src/System.Net.Sockets/src/System/Net/Sockets/SocketTaskExtensions.cs @@ -36,9 +36,9 @@ public static Task ReceiveMessageFromAsync(this socket.ReceiveMessageFromAsync(buffer, socketFlags, remoteEndPoint); public static Task SendAsync(this Socket socket, ArraySegment buffer, SocketFlags socketFlags) => - socket.SendAsync(buffer, socketFlags, fromNetworkStream: false); + socket.SendAsync(buffer, socketFlags); public static ValueTask SendAsync(this Socket socket, ReadOnlyMemory buffer, SocketFlags socketFlags, CancellationToken cancellationToken = default) => - socket.SendAsync(buffer, socketFlags, fromNetworkStream: false, cancellationToken: cancellationToken); + socket.SendAsync(buffer, socketFlags, cancellationToken); public static Task SendAsync(this Socket socket, IList> buffers, SocketFlags socketFlags) => socket.SendAsync(buffers, socketFlags); public static Task SendToAsync(this Socket socket, ArraySegment buffer, SocketFlags socketFlags, EndPoint remoteEP) => diff --git a/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs b/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs index d4a50d03a6c2..af10c538d730 100644 --- a/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs +++ b/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs @@ -454,6 +454,21 @@ public async Task ReadableWriteableProperties_Roundtrip() } } + [Fact] + public async Task ReadWrite_Byte_Success() + { + await RunWithConnectedNetworkStreamsAsync(async (server, client) => + { + for (byte i = 0; i < 10; i++) + { + Task read = Task.Run(() => client.ReadByte()); + Task write = Task.Run(() => server.WriteByte(i)); + await Task.WhenAll(read, write); + Assert.Equal(i, await read); + } + }); + } + [Fact] public async Task ReadWrite_Array_Success() { diff --git a/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.netcoreapp.cs b/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.netcoreapp.cs index bcf29d4bb539..98a06da7e2b1 100644 --- a/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.netcoreapp.cs +++ b/src/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.netcoreapp.cs @@ -2,6 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Collections.Generic; +using System.Linq; +using System.Security.Cryptography; using System.Threading; using System.Threading.Tasks; using Xunit; @@ -43,17 +46,303 @@ await RunWithConnectedNetworkStreamsAsync(async (server, client) => }); } + [Fact] + public async Task ReadWrite_Memory_LargeWrite_Success() + { + await RunWithConnectedNetworkStreamsAsync(async (server, client) => + { + byte[] writeBuffer = new byte[10 * 1024 * 1024]; + byte[] readBuffer = new byte[writeBuffer.Length]; + RandomNumberGenerator.Fill(writeBuffer); + + ValueTask writeTask = client.WriteAsync((ReadOnlyMemory)writeBuffer); + + int totalRead = 0; + while (totalRead < readBuffer.Length) + { + int bytesRead = await server.ReadAsync(new Memory(readBuffer).Slice(totalRead)); + Assert.InRange(bytesRead, 0, int.MaxValue); + if (bytesRead == 0) + { + break; + } + totalRead += bytesRead; + } + Assert.Equal(readBuffer.Length, totalRead); + Assert.Equal(writeBuffer, readBuffer); + + await writeTask; + }); + } + [Fact] public async Task ReadWrite_Precanceled_Throws() { await RunWithConnectedNetworkStreamsAsync(async (server, client) => { - await Assert.ThrowsAnyAsync(() => server.WriteAsync((ArraySegment)new byte[0], new CancellationToken(true))); - await Assert.ThrowsAnyAsync(() => server.ReadAsync((ArraySegment)new byte[0], new CancellationToken(true)).AsTask()); + await Assert.ThrowsAnyAsync(async () => await server.WriteAsync((ArraySegment)new byte[0], new CancellationToken(true))); + await Assert.ThrowsAnyAsync(async () => await server.ReadAsync((ArraySegment)new byte[0], new CancellationToken(true))); + + await Assert.ThrowsAnyAsync(async () => await server.WriteAsync((ReadOnlyMemory)new byte[0], new CancellationToken(true))); + await Assert.ThrowsAnyAsync(async () => await server.ReadAsync((Memory)new byte[0], new CancellationToken(true))); + }); + } + + [Fact] + public async Task ReadAsync_MultipleConcurrentValueTaskReads_Success() + { + await RunWithConnectedNetworkStreamsAsync(async (server, client) => + { + // Technically this isn't supported behavior, but it happens to work because it's supported on socket. + // So validate it to alert us to any potential future breaks. + + byte[] b1 = new byte[1], b2 = new byte[1], b3 = new byte[1]; + ValueTask r1 = server.ReadAsync(b1); + ValueTask r2 = server.ReadAsync(b2); + ValueTask r3 = server.ReadAsync(b3); + + await client.WriteAsync(new byte[] { 42, 43, 44 }); + + Assert.Equal(3, await r1 + await r2 + await r3); + Assert.Equal(42 + 43 + 44, b1[0] + b2[0] + b3[0]); + }); + } + + [Fact] + public async Task ReadAsync_MultipleConcurrentValueTaskReads_AsTask_Success() + { + await RunWithConnectedNetworkStreamsAsync(async (server, client) => + { + // Technically this isn't supported behavior, but it happens to work because it's supported on socket. + // So validate it to alert us to any potential future breaks. + + byte[] b1 = new byte[1], b2 = new byte[1], b3 = new byte[1]; + Task r1 = server.ReadAsync((Memory)b1).AsTask(); + Task r2 = server.ReadAsync((Memory)b2).AsTask(); + Task r3 = server.ReadAsync((Memory)b3).AsTask(); + + await client.WriteAsync(new byte[] { 42, 43, 44 }); + + Assert.Equal(3, await r1 + await r2 + await r3); + Assert.Equal(42 + 43 + 44, b1[0] + b2[0] + b3[0]); + }); + } + + [Fact] + public async Task WriteAsync_MultipleConcurrentValueTaskWrites_Success() + { + await RunWithConnectedNetworkStreamsAsync(async (server, client) => + { + // Technically this isn't supported behavior, but it happens to work because it's supported on socket. + // So validate it to alert us to any potential future breaks. + + ValueTask s1 = server.WriteAsync(new ReadOnlyMemory(new byte[] { 42 })); + ValueTask s2 = server.WriteAsync(new ReadOnlyMemory(new byte[] { 43 })); + ValueTask s3 = server.WriteAsync(new ReadOnlyMemory(new byte[] { 44 })); + + byte[] b1 = new byte[1], b2 = new byte[1], b3 = new byte[1]; + Assert.Equal(3, + await client.ReadAsync((Memory)b1) + + await client.ReadAsync((Memory)b2) + + await client.ReadAsync((Memory)b3)); + + await s1; + await s2; + await s3; + + Assert.Equal(42 + 43 + 44, b1[0] + b2[0] + b3[0]); + }); + } + + [Fact] + public async Task WriteAsync_MultipleConcurrentValueTaskWrites_AsTask_Success() + { + await RunWithConnectedNetworkStreamsAsync(async (server, client) => + { + // Technically this isn't supported behavior, but it happens to work because it's supported on socket. + // So validate it to alert us to any potential future breaks. + + Task s1 = server.WriteAsync(new ReadOnlyMemory(new byte[] { 42 })).AsTask(); + Task s2 = server.WriteAsync(new ReadOnlyMemory(new byte[] { 43 })).AsTask(); + Task s3 = server.WriteAsync(new ReadOnlyMemory(new byte[] { 44 })).AsTask(); + + byte[] b1 = new byte[1], b2 = new byte[1], b3 = new byte[1]; + Task r1 = client.ReadAsync((Memory)b1).AsTask(); + Task r2 = client.ReadAsync((Memory)b2).AsTask(); + Task r3 = client.ReadAsync((Memory)b3).AsTask(); + + await Task.WhenAll(s1, s2, s3, r1, r2, r3); - await Assert.ThrowsAnyAsync(() => server.WriteAsync((ReadOnlyMemory)new byte[0], new CancellationToken(true))); - await Assert.ThrowsAnyAsync(() => server.ReadAsync((Memory)new byte[0], new CancellationToken(true)).AsTask()); + Assert.Equal(3, await r1 + await r2 + await r3); + Assert.Equal(42 + 43 + 44, b1[0] + b2[0] + b3[0]); }); } + + public static IEnumerable ReadAsync_ContinuesOnCurrentContextIfDesired_MemberData() => + from flowExecutionContext in new[] { true, false } + from continueOnCapturedContext in new bool?[] { null, false, true } + select new object[] { flowExecutionContext, continueOnCapturedContext }; + + [Theory] + [MemberData(nameof(ReadAsync_ContinuesOnCurrentContextIfDesired_MemberData))] + public async Task ReadAsync_ContinuesOnCurrentSynchronizationContextIfDesired( + bool flowExecutionContext, bool? continueOnCapturedContext) + { + await Task.Run(async () => // escape xunit sync ctx + { + await RunWithConnectedNetworkStreamsAsync(async (server, client) => + { + Assert.Null(SynchronizationContext.Current); + + var continuationRan = new TaskCompletionSource(); + var asyncLocal = new AsyncLocal(); + bool schedulerWasFlowed = false; + bool executionContextWasFlowed = false; + Action continuation = () => + { + schedulerWasFlowed = SynchronizationContext.Current is CustomSynchronizationContext; + executionContextWasFlowed = 42 == asyncLocal.Value; + continuationRan.SetResult(true); + }; + + var readBuffer = new byte[1]; + ValueTask readValueTask = client.ReadAsync((Memory)new byte[1]); + + SynchronizationContext.SetSynchronizationContext(new CustomSynchronizationContext()); + asyncLocal.Value = 42; + switch (continueOnCapturedContext) + { + case null: + if (flowExecutionContext) + { + readValueTask.GetAwaiter().OnCompleted(continuation); + } + else + { + readValueTask.GetAwaiter().UnsafeOnCompleted(continuation); + } + break; + default: + if (flowExecutionContext) + { + readValueTask.ConfigureAwait(continueOnCapturedContext.Value).GetAwaiter().OnCompleted(continuation); + } + else + { + readValueTask.ConfigureAwait(continueOnCapturedContext.Value).GetAwaiter().UnsafeOnCompleted(continuation); + } + break; + } + asyncLocal.Value = 0; + SynchronizationContext.SetSynchronizationContext(null); + + Assert.False(readValueTask.IsCompleted); + Assert.False(readValueTask.IsCompletedSuccessfully); + await server.WriteAsync(new byte[] { 42 }); + + await continuationRan.Task; + Assert.True(readValueTask.IsCompleted); + Assert.True(readValueTask.IsCompletedSuccessfully); + + Assert.Equal(continueOnCapturedContext != false, schedulerWasFlowed); + Assert.Equal(flowExecutionContext, executionContextWasFlowed); + }); + }); + } + + [Theory] + [MemberData(nameof(ReadAsync_ContinuesOnCurrentContextIfDesired_MemberData))] + public async Task ReadAsync_ContinuesOnCurrentTaskSchedulerIfDesired( + bool flowExecutionContext, bool? continueOnCapturedContext) + { + await Task.Run(async () => // escape xunit sync ctx + { + await RunWithConnectedNetworkStreamsAsync(async (server, client) => + { + Assert.Null(SynchronizationContext.Current); + + var continuationRan = new TaskCompletionSource(); + var asyncLocal = new AsyncLocal(); + bool schedulerWasFlowed = false; + bool executionContextWasFlowed = false; + Action continuation = () => + { + schedulerWasFlowed = TaskScheduler.Current is CustomTaskScheduler; + executionContextWasFlowed = 42 == asyncLocal.Value; + continuationRan.SetResult(true); + }; + + var readBuffer = new byte[1]; + ValueTask readValueTask = client.ReadAsync((Memory)new byte[1]); + + await Task.Factory.StartNew(() => + { + Assert.IsType(TaskScheduler.Current); + asyncLocal.Value = 42; + switch (continueOnCapturedContext) + { + case null: + if (flowExecutionContext) + { + readValueTask.GetAwaiter().OnCompleted(continuation); + } + else + { + readValueTask.GetAwaiter().UnsafeOnCompleted(continuation); + } + break; + default: + if (flowExecutionContext) + { + readValueTask.ConfigureAwait(continueOnCapturedContext.Value).GetAwaiter().OnCompleted(continuation); + } + else + { + readValueTask.ConfigureAwait(continueOnCapturedContext.Value).GetAwaiter().UnsafeOnCompleted(continuation); + } + break; + } + asyncLocal.Value = 0; + }, CancellationToken.None, TaskCreationOptions.None, new CustomTaskScheduler()); + + Assert.False(readValueTask.IsCompleted); + Assert.False(readValueTask.IsCompletedSuccessfully); + await server.WriteAsync(new byte[] { 42 }); + + await continuationRan.Task; + Assert.True(readValueTask.IsCompleted); + Assert.True(readValueTask.IsCompletedSuccessfully); + + Assert.Equal(continueOnCapturedContext != false, schedulerWasFlowed); + Assert.Equal(flowExecutionContext, executionContextWasFlowed); + }); + }); + } + + private sealed class CustomSynchronizationContext : SynchronizationContext + { + public override void Post(SendOrPostCallback d, object state) + { + ThreadPool.QueueUserWorkItem(delegate + { + SetSynchronizationContext(this); + try + { + d(state); + } + finally + { + SetSynchronizationContext(null); + } + }, null); + } + } + + private sealed class CustomTaskScheduler : TaskScheduler + { + protected override void QueueTask(Task task) => ThreadPool.QueueUserWorkItem(_ => TryExecuteTask(task)); + protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued) => false; + protected override IEnumerable GetScheduledTasks() => null; + } } } diff --git a/src/System.Net.Sockets/tests/FunctionalTests/SendReceive.netcoreapp.cs b/src/System.Net.Sockets/tests/FunctionalTests/SendReceive.netcoreapp.cs index 9b511de045bd..ae34a6e90bcf 100644 --- a/src/System.Net.Sockets/tests/FunctionalTests/SendReceive.netcoreapp.cs +++ b/src/System.Net.Sockets/tests/FunctionalTests/SendReceive.netcoreapp.cs @@ -2,10 +2,58 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Threading; +using System.Threading.Tasks; +using Xunit; + namespace System.Net.Sockets.Tests { public sealed class SendReceiveSpanSync : SendReceive { } public sealed class SendReceiveSpanSyncForceNonBlocking : SendReceive { } - public sealed class SendReceiveMemoryArrayTask : SendReceive { } + public sealed class SendReceiveMemoryArrayTask : SendReceive + { + [Fact] + public async Task Precanceled_Throws() + { + using (var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + listener.BindToAnonymousPort(IPAddress.Loopback); + listener.Listen(1); + + await client.ConnectAsync(listener.LocalEndPoint); + using (Socket server = await listener.AcceptAsync()) + { + var cts = new CancellationTokenSource(); + cts.Cancel(); + + await Assert.ThrowsAnyAsync(async () => await server.SendAsync((ReadOnlyMemory)new byte[0], SocketFlags.None, cts.Token)); + await Assert.ThrowsAnyAsync(async () => await server.ReceiveAsync((Memory)new byte[0], SocketFlags.None, cts.Token)); + } + } + } + + [Fact] + public async Task DisposedSocket_ThrowsOperationCanceledException() + { + using (var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + listener.BindToAnonymousPort(IPAddress.Loopback); + listener.Listen(1); + + await client.ConnectAsync(listener.LocalEndPoint); + using (Socket server = await listener.AcceptAsync()) + { + var cts = new CancellationTokenSource(); + cts.Cancel(); + + server.Shutdown(SocketShutdown.Both); + await Assert.ThrowsAnyAsync(async () => await server.SendAsync((ReadOnlyMemory)new byte[0], SocketFlags.None, cts.Token)); + await Assert.ThrowsAnyAsync(async () => await server.ReceiveAsync((Memory)new byte[0], SocketFlags.None, cts.Token)); + } + } + } + } public sealed class SendReceiveMemoryNativeTask : SendReceive { } } diff --git a/src/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj b/src/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj index 47411a841d6a..2ea10374e82c 100644 --- a/src/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj +++ b/src/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj @@ -119,5 +119,8 @@ + + + - + \ No newline at end of file diff --git a/src/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs b/src/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs index af17a979cb9d..4e82b1579037 100644 --- a/src/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs +++ b/src/System.Net.Sockets/tests/FunctionalTests/UnixDomainSocketTest.cs @@ -18,6 +18,7 @@ public UnixDomainSocketTest(ITestOutputHelper output) _log = TestLogging.GetInstance(); } + [ActiveIssue(27542)] [OuterLoop] // TODO: Issue #11345 [Fact] [PlatformSpecific(TestPlatforms.Windows)] // CreateUnixDomainSocket should throw on Windows diff --git a/src/System.Net.WebClient/src/System/Net/WebClient.cs b/src/System.Net.WebClient/src/System/Net/WebClient.cs index 364046b724b6..eaf5e9bffc5e 100644 --- a/src/System.Net.WebClient/src/System/Net/WebClient.cs +++ b/src/System.Net.WebClient/src/System/Net/WebClient.cs @@ -888,7 +888,7 @@ private async void DownloadBitsAsync( { while (true) { - int bytesRead = await readStream.ReadAsync(copyBuffer, 0, copyBuffer.Length).ConfigureAwait(false); + int bytesRead = await readStream.ReadAsync(new Memory(copyBuffer)).ConfigureAwait(false); if (bytesRead == 0) { break; @@ -900,7 +900,7 @@ private async void DownloadBitsAsync( PostProgressChanged(asyncOp, _progress); } - await writeStream.WriteAsync(copyBuffer, 0, bytesRead).ConfigureAwait(false); + await writeStream.WriteAsync(new ReadOnlyMemory(copyBuffer, 0, bytesRead)).ConfigureAwait(false); } } @@ -1010,7 +1010,7 @@ private async void UploadBitsAsync( { if (header != null) { - await writeStream.WriteAsync(header, 0, header.Length).ConfigureAwait(false); + await writeStream.WriteAsync(new ReadOnlyMemory(header)).ConfigureAwait(false); _progress.BytesSent += header.Length; PostProgressChanged(asyncOp, _progress); } @@ -1021,9 +1021,9 @@ private async void UploadBitsAsync( { while (true) { - int bytesRead = await readStream.ReadAsync(buffer, 0, buffer.Length).ConfigureAwait(false); + int bytesRead = await readStream.ReadAsync(new Memory(buffer)).ConfigureAwait(false); if (bytesRead <= 0) break; - await writeStream.WriteAsync(buffer, 0, bytesRead).ConfigureAwait(false); + await writeStream.WriteAsync(new ReadOnlyMemory(buffer, 0, bytesRead)).ConfigureAwait(false); _progress.BytesSent += bytesRead; PostProgressChanged(asyncOp, _progress); @@ -1039,7 +1039,7 @@ private async void UploadBitsAsync( { toWrite = chunkSize; } - await writeStream.WriteAsync(buffer, pos, toWrite).ConfigureAwait(false); + await writeStream.WriteAsync(new ReadOnlyMemory(buffer, pos, toWrite)).ConfigureAwait(false); pos += toWrite; _progress.BytesSent += toWrite; PostProgressChanged(asyncOp, _progress); @@ -1048,7 +1048,7 @@ private async void UploadBitsAsync( if (footer != null) { - await writeStream.WriteAsync(footer, 0, footer.Length).ConfigureAwait(false); + await writeStream.WriteAsync(new ReadOnlyMemory(footer)).ConfigureAwait(false); _progress.BytesSent += footer.Length; PostProgressChanged(asyncOp, _progress); } diff --git a/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocket.cs b/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocket.cs index b0409ce72833..70e35e96feda 100644 --- a/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocket.cs +++ b/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocket.cs @@ -165,7 +165,7 @@ public override Task SendAsync(ArraySegment buffer, WebSocketMessageType m return _innerWebSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken); } - public override Task SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) + public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) { ThrowIfNotConnected(); return _innerWebSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken); diff --git a/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs b/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs index d57381419a12..a84c422e56cc 100644 --- a/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs +++ b/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs @@ -52,7 +52,7 @@ public void Abort() public Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) => _webSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken); - public Task SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) => + public ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) => _webSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken); public Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken) => diff --git a/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Windows.cs b/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Windows.cs index a7f9d1083d7f..f496886b70e2 100644 --- a/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Windows.cs +++ b/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Windows.cs @@ -80,7 +80,7 @@ public Task SendAsync( return _webSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken); } - public Task SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) + public ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) { if (messageType != WebSocketMessageType.Text && messageType != WebSocketMessageType.Binary) { @@ -94,7 +94,7 @@ public Task SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageT throw new ArgumentException(errorMessage, nameof(messageType)); } - return _webSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken); + return new ValueTask(_webSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken)); } public Task ReceiveAsync( diff --git a/src/System.Net.WebSockets.Client/tests/SendReceiveTest.netcoreapp.cs b/src/System.Net.WebSockets.Client/tests/SendReceiveTest.netcoreapp.cs index 4c7759cba0ef..778ba03213bf 100644 --- a/src/System.Net.WebSockets.Client/tests/SendReceiveTest.netcoreapp.cs +++ b/src/System.Net.WebSockets.Client/tests/SendReceiveTest.netcoreapp.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System.Net.Sockets; using System.Threading; using System.Threading.Tasks; using Xunit.Abstractions; @@ -16,16 +15,16 @@ public MemorySendReceiveTest(ITestOutputHelper output) : base(output) { } protected override async Task ReceiveAsync(WebSocket ws, ArraySegment arraySegment, CancellationToken cancellationToken) { ValueWebSocketReceiveResult r = await ws.ReceiveAsync( - arraySegment == default(ArraySegment) ? Memory.Empty : (Memory)arraySegment, + (Memory)arraySegment, cancellationToken); return new WebSocketReceiveResult(r.Count, r.MessageType, r.EndOfMessage, ws.CloseStatus, ws.CloseStatusDescription); } protected override Task SendAsync(WebSocket ws, ArraySegment arraySegment, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) => ws.SendAsync( - arraySegment == default(ArraySegment) ? ReadOnlyMemory.Empty : (ReadOnlyMemory)arraySegment, + (ReadOnlyMemory)arraySegment, messageType, endOfMessage, - cancellationToken); + cancellationToken).AsTask(); } } diff --git a/src/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.cs b/src/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.cs index 3f2143990e1b..0df60484eeb1 100644 --- a/src/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.cs +++ b/src/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.cs @@ -23,16 +23,16 @@ internal static unsafe string GetString(this UTF8Encoding encoding, Span b } } - internal static Task ReadAsync(this Stream stream, Memory destination, CancellationToken cancellationToken) + internal static ValueTask ReadAsync(this Stream stream, Memory destination, CancellationToken cancellationToken = default) { if (destination.TryGetArray(out ArraySegment array)) { - return stream.ReadAsync(array.Array, array.Offset, array.Count, cancellationToken); + return new ValueTask(stream.ReadAsync(array.Array, array.Offset, array.Count, cancellationToken)); } else { byte[] buffer = ArrayPool.Shared.Rent(destination.Length); - return FinishReadAsync(stream.ReadAsync(buffer, 0, destination.Length, cancellationToken), buffer, destination); + return new ValueTask(FinishReadAsync(stream.ReadAsync(buffer, 0, destination.Length, cancellationToken), buffer, destination)); async Task FinishReadAsync(Task readTask, byte[] localBuffer, Memory localDestination) { @@ -49,6 +49,32 @@ async Task FinishReadAsync(Task readTask, byte[] localBuffer, Memory source, CancellationToken cancellationToken = default) + { + if (MemoryMarshal.TryGetArray(source, out ArraySegment array)) + { + return new ValueTask(stream.WriteAsync(array.Array, array.Offset, array.Count, cancellationToken)); + } + else + { + byte[] buffer = ArrayPool.Shared.Rent(source.Length); + source.Span.CopyTo(buffer); + return new ValueTask(FinishWriteAsync(stream.WriteAsync(buffer, 0, source.Length, cancellationToken), buffer)); + + async Task FinishWriteAsync(Task writeTask, byte[] localBuffer) + { + try + { + await writeTask.ConfigureAwait(false); + } + finally + { + ArrayPool.Shared.Return(localBuffer); + } + } + } + } } internal static class BitConverter diff --git a/src/System.Net.WebSockets/ref/System.Net.WebSockets.cs b/src/System.Net.WebSockets/ref/System.Net.WebSockets.cs index 924dc869fe60..126887bc2ee7 100644 --- a/src/System.Net.WebSockets/ref/System.Net.WebSockets.cs +++ b/src/System.Net.WebSockets/ref/System.Net.WebSockets.cs @@ -41,7 +41,7 @@ protected WebSocket() { } [System.ComponentModel.EditorBrowsableAttribute((System.ComponentModel.EditorBrowsableState)(1))] public static void RegisterPrefixes() { } public abstract System.Threading.Tasks.Task SendAsync(System.ArraySegment buffer, System.Net.WebSockets.WebSocketMessageType messageType, bool endOfMessage, System.Threading.CancellationToken cancellationToken); - public virtual System.Threading.Tasks.Task SendAsync(System.ReadOnlyMemory buffer, System.Net.WebSockets.WebSocketMessageType messageType, bool endOfMessage, System.Threading.CancellationToken cancellationToken) { throw null; } + public virtual System.Threading.Tasks.ValueTask SendAsync(System.ReadOnlyMemory buffer, System.Net.WebSockets.WebSocketMessageType messageType, bool endOfMessage, System.Threading.CancellationToken cancellationToken) { throw null; } protected static void ThrowOnInvalidState(System.Net.WebSockets.WebSocketState state, params System.Net.WebSockets.WebSocketState[] validStates) { } } public enum WebSocketCloseStatus diff --git a/src/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.netcoreapp.cs b/src/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.netcoreapp.cs index 23e5f6949c86..db147654fe0e 100644 --- a/src/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.netcoreapp.cs +++ b/src/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.netcoreapp.cs @@ -10,7 +10,7 @@ namespace System.Net.WebSockets { internal sealed partial class ManagedWebSocket : WebSocket { - public override Task SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) + public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) { return SendPrivateAsync(buffer, messageType, endOfMessage, cancellationToken); } @@ -24,7 +24,7 @@ public override ValueTask ReceiveAsync(Memory Debug.Assert(!Monitor.IsEntered(StateUpdateLock), $"{nameof(StateUpdateLock)} must never be held when acquiring {nameof(ReceiveAsyncLock)}"); lock (ReceiveAsyncLock) // synchronize with receives in CloseAsync { - ThrowIfOperationInProgress(_lastReceiveAsync); + ThrowIfOperationInProgress(_lastReceiveAsync.IsCompleted); ValueTask t = ReceiveAsyncPrivate(buffer, cancellationToken); _lastReceiveAsync = t.IsCompletedSuccessfully ? (t.Result.MessageType == WebSocketMessageType.Close ? s_cachedCloseTask : Task.CompletedTask) : diff --git a/src/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs b/src/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs index 8013388494bc..55a5481ca569 100644 --- a/src/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs +++ b/src/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs @@ -56,10 +56,10 @@ public virtual async ValueTask ReceiveAsync(Memory< } } - public virtual Task SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) => - MemoryMarshal.TryGetArray(buffer, out ArraySegment arraySegment) ? + public virtual ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) => + new ValueTask(MemoryMarshal.TryGetArray(buffer, out ArraySegment arraySegment) ? SendAsync(arraySegment, messageType, endOfMessage, cancellationToken) : - SendWithArrayPoolAsync(buffer, messageType, endOfMessage, cancellationToken); + SendWithArrayPoolAsync(buffer, messageType, endOfMessage, cancellationToken)); private async Task SendWithArrayPoolAsync( ReadOnlyMemory buffer, diff --git a/src/System.Runtime.Extensions/ref/System.Runtime.Extensions.cs b/src/System.Runtime.Extensions/ref/System.Runtime.Extensions.cs index a9a829da969a..b90ae99f016e 100644 --- a/src/System.Runtime.Extensions/ref/System.Runtime.Extensions.cs +++ b/src/System.Runtime.Extensions/ref/System.Runtime.Extensions.cs @@ -1385,7 +1385,7 @@ public override void SetLength(long value) { } public override void Write(byte[] buffer, int offset, int count) { } public override void Write(System.ReadOnlySpan source) { } public override System.Threading.Tasks.Task WriteAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) { throw null; } - public override System.Threading.Tasks.Task WriteAsync(System.ReadOnlyMemory source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public override System.Threading.Tasks.ValueTask WriteAsync(System.ReadOnlyMemory source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public override void WriteByte(byte value) { } public virtual void WriteTo(System.IO.Stream stream) { } } diff --git a/src/System.Runtime.Extensions/src/System/IO/BufferedStream.cs b/src/System.Runtime.Extensions/src/System/IO/BufferedStream.cs index 6e918aab3d14..a79a60bf4777 100644 --- a/src/System.Runtime.Extensions/src/System/IO/BufferedStream.cs +++ b/src/System.Runtime.Extensions/src/System/IO/BufferedStream.cs @@ -413,7 +413,7 @@ private async Task FlushWriteAsync(CancellationToken cancellationToken) Debug.Assert(_buffer != null && _bufferSize >= _writePos, "BufferedStream: Write buffer must be allocated and write position must be in the bounds of the buffer in FlushWrite!"); - await _stream.WriteAsync(_buffer, 0, _writePos, cancellationToken).ConfigureAwait(false); + await _stream.WriteAsync(new ReadOnlyMemory(_buffer, 0, _writePos), cancellationToken).ConfigureAwait(false); _writePos = 0; await _stream.FlushAsync(cancellationToken).ConfigureAwait(false); } @@ -731,7 +731,7 @@ private async ValueTask ReadFromUnderlyingStreamAsync( // Ok. We can fill the buffer: EnsureBufferAllocated(); - _readLen = await _stream.ReadAsync(_buffer, 0, _bufferSize, cancellationToken).ConfigureAwait(false); + _readLen = await _stream.ReadAsync(new Memory(_buffer, 0, _bufferSize), cancellationToken).ConfigureAwait(false); bytesFromBuffer = ReadFromBuffer(buffer.Span); return bytesAlreadySatisfied + bytesFromBuffer; @@ -1040,15 +1040,15 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati if (buffer.Length - offset < count) throw new ArgumentException(SR.Argument_InvalidOffLen); - return WriteAsync(new ReadOnlyMemory(buffer, offset, count), cancellationToken); + return WriteAsync(new ReadOnlyMemory(buffer, offset, count), cancellationToken).AsTask(); } - public override Task WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default(CancellationToken)) + public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default(CancellationToken)) { // Fast path check for cancellation already requested if (cancellationToken.IsCancellationRequested) { - return Task.FromCanceled(cancellationToken); + return new ValueTask(Task.FromCanceled(cancellationToken)); } EnsureNotClosed(); @@ -1075,7 +1075,7 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati { int bytesWritten = WriteToBuffer(source.Span); Debug.Assert(bytesWritten == source.Length); - return Task.CompletedTask; + return default; } } finally @@ -1086,7 +1086,7 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati } // Delegate to the async implementation. - return WriteToUnderlyingStreamAsync(source, cancellationToken, semaphoreLockTask); + return new ValueTask(WriteToUnderlyingStreamAsync(source, cancellationToken, semaphoreLockTask)); } /// BufferedStream should be as thin a wrapper as possible. We want WriteAsync to delegate to @@ -1136,7 +1136,7 @@ private async Task WriteToUnderlyingStreamAsync( Debug.Assert(_writePos == _bufferSize); Debug.Assert(_buffer != null); - await _stream.WriteAsync(_buffer, 0, _writePos, cancellationToken).ConfigureAwait(false); + await _stream.WriteAsync(new ReadOnlyMemory(_buffer, 0, _writePos), cancellationToken).ConfigureAwait(false); _writePos = 0; int bytesWritten = WriteToBuffer(source.Span); @@ -1159,12 +1159,12 @@ private async Task WriteToUnderlyingStreamAsync( EnsureShadowBufferAllocated(); source.Span.CopyTo(new Span(_buffer, _writePos, source.Length)); - await _stream.WriteAsync(_buffer, 0, totalUserBytes, cancellationToken).ConfigureAwait(false); + await _stream.WriteAsync(new ReadOnlyMemory(_buffer, 0, totalUserBytes), cancellationToken).ConfigureAwait(false); _writePos = 0; return; } - await _stream.WriteAsync(_buffer, 0, _writePos, cancellationToken).ConfigureAwait(false); + await _stream.WriteAsync(new ReadOnlyMemory(_buffer, 0, _writePos), cancellationToken).ConfigureAwait(false); _writePos = 0; } @@ -1312,7 +1312,7 @@ private async Task CopyToAsyncCore(Stream destination, int bufferSize, Cancellat { // If there's any read data in the buffer, write it all to the destination stream. Debug.Assert(_writePos == 0, "Write buffer must be empty if there's data in the read buffer"); - await destination.WriteAsync(_buffer, _readPos, readBytes, cancellationToken).ConfigureAwait(false); + await destination.WriteAsync(new ReadOnlyMemory(_buffer, _readPos, readBytes), cancellationToken).ConfigureAwait(false); _readPos = _readLen = 0; } else if (_writePos > 0) diff --git a/src/System.Runtime/ref/System.Runtime.cs b/src/System.Runtime/ref/System.Runtime.cs index 80f1799cac5b..9b1a68455354 100644 --- a/src/System.Runtime/ref/System.Runtime.cs +++ b/src/System.Runtime/ref/System.Runtime.cs @@ -5137,7 +5137,7 @@ protected virtual void ObjectInvariant() { } public virtual void Write(System.ReadOnlySpan source) { } public System.Threading.Tasks.Task WriteAsync(byte[] buffer, int offset, int count) { throw null; } public virtual System.Threading.Tasks.Task WriteAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) { throw null; } - public virtual System.Threading.Tasks.Task WriteAsync(System.ReadOnlyMemory source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public virtual System.Threading.Tasks.ValueTask WriteAsync(System.ReadOnlyMemory source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual void WriteByte(byte value) { } } } @@ -6337,6 +6337,18 @@ public sealed partial class AsyncStateMachineAttribute : System.Runtime.Compiler { public AsyncStateMachineAttribute(System.Type stateMachineType) : base (default(System.Type)) { } } + public partial struct AsyncValueTaskMethodBuilder + { + private object _dummy; + public System.Threading.Tasks.ValueTask Task { get { throw null; } } + public void AwaitOnCompleted(ref TAwaiter awaiter, ref TStateMachine stateMachine) where TAwaiter : System.Runtime.CompilerServices.INotifyCompletion where TStateMachine : System.Runtime.CompilerServices.IAsyncStateMachine { } + public void AwaitUnsafeOnCompleted(ref TAwaiter awaiter, ref TStateMachine stateMachine) where TAwaiter : System.Runtime.CompilerServices.ICriticalNotifyCompletion where TStateMachine : System.Runtime.CompilerServices.IAsyncStateMachine { } + public static System.Runtime.CompilerServices.AsyncValueTaskMethodBuilder Create() { throw null; } + public void SetException(System.Exception exception) { } + public void SetResult() { } + public void SetStateMachine(System.Runtime.CompilerServices.IAsyncStateMachine stateMachine) { } + public void Start(ref TStateMachine stateMachine) where TStateMachine : System.Runtime.CompilerServices.IAsyncStateMachine { } + } public partial struct AsyncValueTaskMethodBuilder { private TResult _result; @@ -6427,6 +6439,19 @@ public void OnCompleted(System.Action continuation) { } public void UnsafeOnCompleted(System.Action continuation) { } } } + public readonly partial struct ConfiguredValueTaskAwaitable + { + private readonly object _dummy; + public System.Runtime.CompilerServices.ConfiguredValueTaskAwaitable.ConfiguredValueTaskAwaiter GetAwaiter() { throw null; } + public partial struct ConfiguredValueTaskAwaiter : System.Runtime.CompilerServices.ICriticalNotifyCompletion, System.Runtime.CompilerServices.INotifyCompletion + { + private readonly object _dummy; + public bool IsCompleted { get { throw null; } } + public void GetResult() { } + public void OnCompleted(System.Action continuation) { } + public void UnsafeOnCompleted(System.Action continuation) { } + } + } public readonly partial struct ConfiguredValueTaskAwaitable { private readonly object _dummy; @@ -6713,6 +6738,14 @@ public sealed partial class UnsafeValueTypeAttribute : System.Attribute { public UnsafeValueTypeAttribute() { } } + public readonly partial struct ValueTaskAwaiter : System.Runtime.CompilerServices.ICriticalNotifyCompletion, System.Runtime.CompilerServices.INotifyCompletion + { + private readonly object _dummy; + public bool IsCompleted { get { throw null; } } + public void GetResult() { } + public void OnCompleted(System.Action continuation) { } + public void UnsafeOnCompleted(System.Action continuation) { } + } public readonly partial struct ValueTaskAwaiter : System.Runtime.CompilerServices.ICriticalNotifyCompletion, System.Runtime.CompilerServices.INotifyCompletion { private readonly object _dummy; @@ -7964,11 +7997,58 @@ public UnobservedTaskExceptionEventArgs(System.AggregateException exception) { } public bool Observed { get { throw null; } } public void SetObserved() { } } + [Flags] + public enum ValueTaskSourceOnCompletedFlags + { + None, + UseSchedulingContext = 0x1, + FlowExecutionContext = 0x2, + } + public enum ValueTaskSourceStatus + { + Pending = 0, + Succeeded = 1, + Faulted = 2, + Canceled = 3 + } + public interface IValueTaskSource + { + System.Threading.Tasks.ValueTaskSourceStatus Status { get; } + void OnCompleted(System.Action continuation, object state, System.Threading.Tasks.ValueTaskSourceOnCompletedFlags flags); + void GetResult(); + } + public interface IValueTaskSource + { + System.Threading.Tasks.ValueTaskSourceStatus Status { get; } + void OnCompleted(System.Action continuation, object state, System.Threading.Tasks.ValueTaskSourceOnCompletedFlags flags); + TResult GetResult(); + } + [System.Runtime.CompilerServices.AsyncMethodBuilderAttribute(typeof(System.Runtime.CompilerServices.AsyncValueTaskMethodBuilder))] + public readonly partial struct ValueTask : System.IEquatable + { + internal readonly object _dummy; + public ValueTask(System.Threading.Tasks.Task task) { throw null; } + public ValueTask(System.Threading.Tasks.IValueTaskSource source) { throw null; } + public bool IsCanceled { get { throw null; } } + public bool IsCompleted { get { throw null; } } + public bool IsCompletedSuccessfully { get { throw null; } } + public bool IsFaulted { get { throw null; } } + public System.Threading.Tasks.Task AsTask() { throw null; } + public System.Runtime.CompilerServices.ConfiguredValueTaskAwaitable ConfigureAwait(bool continueOnCapturedContext) { throw null; } + public override bool Equals(object obj) { throw null; } + public bool Equals(System.Threading.Tasks.ValueTask other) { throw null; } + public System.Runtime.CompilerServices.ValueTaskAwaiter GetAwaiter() { throw null; } + public override int GetHashCode() { throw null; } + public System.Threading.Tasks.ValueTask Preserve() { throw null; } + public static bool operator ==(System.Threading.Tasks.ValueTask left, System.Threading.Tasks.ValueTask right) { throw null; } + public static bool operator !=(System.Threading.Tasks.ValueTask left, System.Threading.Tasks.ValueTask right) { throw null; } + } [System.Runtime.CompilerServices.AsyncMethodBuilderAttribute(typeof(System.Runtime.CompilerServices.AsyncValueTaskMethodBuilder<>))] public readonly partial struct ValueTask : System.IEquatable> { internal readonly TResult _result; public ValueTask(System.Threading.Tasks.Task task) { throw null; } + public ValueTask(System.Threading.Tasks.IValueTaskSource source) { throw null; } public ValueTask(TResult result) { throw null; } public bool IsCanceled { get { throw null; } } public bool IsCompleted { get { throw null; } } @@ -7981,6 +8061,7 @@ public void SetObserved() { } public bool Equals(System.Threading.Tasks.ValueTask other) { throw null; } public System.Runtime.CompilerServices.ValueTaskAwaiter GetAwaiter() { throw null; } public override int GetHashCode() { throw null; } + public System.Threading.Tasks.ValueTask Preserve() { throw null; } public static bool operator ==(System.Threading.Tasks.ValueTask left, System.Threading.Tasks.ValueTask right) { throw null; } public static bool operator !=(System.Threading.Tasks.ValueTask left, System.Threading.Tasks.ValueTask right) { throw null; } public override string ToString() { throw null; } diff --git a/src/System.Security.Cryptography.Primitives/src/System/Security/Cryptography/CryptoStream.cs b/src/System.Security.Cryptography.Primitives/src/System/Security/Cryptography/CryptoStream.cs index c9f420d593a7..8f38d707bff8 100644 --- a/src/System.Security.Cryptography.Primitives/src/System/Security/Cryptography/CryptoStream.cs +++ b/src/System.Security.Cryptography.Primitives/src/System/Security/Cryptography/CryptoStream.cs @@ -238,7 +238,7 @@ public override void WriteByte(byte value) public override int Read(byte[] buffer, int offset, int count) { CheckReadArguments(buffer, offset, count); - return ReadAsyncCore(buffer, offset, count, default(CancellationToken), useAsync: false).ConfigureAwait(false).GetAwaiter().GetResult(); + return ReadAsyncCore(buffer, offset, count, default(CancellationToken), useAsync: false).GetAwaiter().GetResult(); } private void CheckReadArguments(byte[] buffer, int offset, int count) @@ -305,7 +305,7 @@ private async Task ReadAsyncCore(byte[] buffer, int offset, int count, Canc Buffer.BlockCopy(_inputBuffer, 0, tempInputBuffer, 0, _inputBufferIndex); amountRead = _inputBufferIndex; amountRead += useAsync ? - await _stream.ReadAsync(tempInputBuffer, _inputBufferIndex, numWholeBlocksInBytes - _inputBufferIndex, cancellationToken) : + await _stream.ReadAsync(new Memory(tempInputBuffer, _inputBufferIndex, numWholeBlocksInBytes - _inputBufferIndex), cancellationToken) : _stream.Read(tempInputBuffer, _inputBufferIndex, numWholeBlocksInBytes - _inputBufferIndex); _inputBufferIndex = 0; @@ -341,7 +341,7 @@ await _stream.ReadAsync(tempInputBuffer, _inputBufferIndex, numWholeBlocksInByte while (_inputBufferIndex < _inputBlockSize) { amountRead = useAsync ? - await _stream.ReadAsync(_inputBuffer, _inputBufferIndex, _inputBlockSize - _inputBufferIndex, cancellationToken) : + await _stream.ReadAsync(new Memory(_inputBuffer, _inputBufferIndex, _inputBlockSize - _inputBufferIndex), cancellationToken) : _stream.Read(_inputBuffer, _inputBufferIndex, _inputBlockSize - _inputBufferIndex); // first, check to see if we're at the end of the input stream @@ -475,7 +475,7 @@ private async Task WriteAsyncCore(byte[] buffer, int offset, int count, Cancella if (_outputBufferIndex > 0) { if (useAsync) - await _stream.WriteAsync(_outputBuffer, 0, _outputBufferIndex, cancellationToken); + await _stream.WriteAsync(new ReadOnlyMemory(_outputBuffer, 0, _outputBufferIndex), cancellationToken); else _stream.Write(_outputBuffer, 0, _outputBufferIndex); _outputBufferIndex = 0; @@ -488,7 +488,7 @@ private async Task WriteAsyncCore(byte[] buffer, int offset, int count, Cancella numOutputBytes = _transform.TransformBlock(_inputBuffer, 0, _inputBlockSize, _outputBuffer, 0); // write out the bytes we just got if (useAsync) - await _stream.WriteAsync(_outputBuffer, 0, numOutputBytes, cancellationToken); + await _stream.WriteAsync(new ReadOnlyMemory(_outputBuffer, 0, numOutputBytes), cancellationToken); else _stream.Write(_outputBuffer, 0, numOutputBytes); @@ -509,7 +509,7 @@ private async Task WriteAsyncCore(byte[] buffer, int offset, int count, Cancella numOutputBytes = _transform.TransformBlock(buffer, currentInputIndex, numWholeBlocksInBytes, _tempOutputBuffer, 0); if (useAsync) - await _stream.WriteAsync(_tempOutputBuffer, 0, numOutputBytes, cancellationToken); + await _stream.WriteAsync(new ReadOnlyMemory(_tempOutputBuffer, 0, numOutputBytes), cancellationToken); else _stream.Write(_tempOutputBuffer, 0, numOutputBytes); @@ -522,7 +522,7 @@ private async Task WriteAsyncCore(byte[] buffer, int offset, int count, Cancella numOutputBytes = _transform.TransformBlock(buffer, currentInputIndex, _inputBlockSize, _outputBuffer, 0); if (useAsync) - await _stream.WriteAsync(_outputBuffer, 0, numOutputBytes, cancellationToken); + await _stream.WriteAsync(new ReadOnlyMemory(_outputBuffer, 0, numOutputBytes), cancellationToken); else _stream.Write(_outputBuffer, 0, numOutputBytes); diff --git a/src/System.Threading.Channels/ref/System.Threading.Channels.cs b/src/System.Threading.Channels/ref/System.Threading.Channels.cs index b387b5da2c2b..d484aab48772 100644 --- a/src/System.Threading.Channels/ref/System.Threading.Channels.cs +++ b/src/System.Threading.Channels/ref/System.Threading.Channels.cs @@ -47,7 +47,7 @@ protected ChannelReader() { } public virtual System.Threading.Tasks.Task Completion { get { throw null; } } public virtual System.Threading.Tasks.ValueTask ReadAsync(CancellationToken cancellationToken = default) { throw null; } public abstract bool TryRead(out T item); - public abstract System.Threading.Tasks.Task WaitToReadAsync(System.Threading.CancellationToken cancellationToken=default); + public abstract System.Threading.Tasks.ValueTask WaitToReadAsync(System.Threading.CancellationToken cancellationToken=default); } public abstract partial class ChannelWriter { @@ -55,8 +55,8 @@ protected ChannelWriter() { } public void Complete(System.Exception error=null) { } public virtual bool TryComplete(System.Exception error=null) { throw null; } public abstract bool TryWrite(T item); - public abstract System.Threading.Tasks.Task WaitToWriteAsync(System.Threading.CancellationToken cancellationToken=default); - public virtual System.Threading.Tasks.Task WriteAsync(T item, System.Threading.CancellationToken cancellationToken=default) { throw null; } + public abstract System.Threading.Tasks.ValueTask WaitToWriteAsync(System.Threading.CancellationToken cancellationToken=default); + public virtual System.Threading.Tasks.ValueTask WriteAsync(T item, System.Threading.CancellationToken cancellationToken=default) { throw null; } } public abstract partial class Channel : System.Threading.Channels.Channel { diff --git a/src/System.Threading.Channels/src/Configurations.props b/src/System.Threading.Channels/src/Configurations.props index 5f3b2623edf6..7eb3ac6025bf 100644 --- a/src/System.Threading.Channels/src/Configurations.props +++ b/src/System.Threading.Channels/src/Configurations.props @@ -4,6 +4,7 @@ netstandard1.3; netstandard; + netcoreapp; diff --git a/src/System.Threading.Channels/src/Resources/Strings.resx b/src/System.Threading.Channels/src/Resources/Strings.resx index 2beea8a35762..83acd241208f 100644 --- a/src/System.Threading.Channels/src/Resources/Strings.resx +++ b/src/System.Threading.Channels/src/Resources/Strings.resx @@ -120,4 +120,10 @@ The channel has been closed. + + The asynchronous operation has not completed. + + + OnCompleted has already been used to register another continuation. + \ No newline at end of file diff --git a/src/System.Threading.Channels/src/System.Threading.Channels.csproj b/src/System.Threading.Channels/src/System.Threading.Channels.csproj index 5db0c65e1293..2e0659f4983f 100644 --- a/src/System.Threading.Channels/src/System.Threading.Channels.csproj +++ b/src/System.Threading.Channels/src/System.Threading.Channels.csproj @@ -22,7 +22,7 @@ - + @@ -37,6 +37,7 @@ + diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/AsyncOperation.cs b/src/System.Threading.Channels/src/System/Threading/Channels/AsyncOperation.cs new file mode 100644 index 000000000000..acf48c02245d --- /dev/null +++ b/src/System.Threading.Channels/src/System/Threading/Channels/AsyncOperation.cs @@ -0,0 +1,342 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Diagnostics; +using System.Runtime.ExceptionServices; +using System.Threading.Tasks; + +namespace System.Threading.Channels +{ + internal abstract class ResettableValueTaskSource + { + protected static readonly Action s_completedSentinel = s => Debug.Fail($"{nameof(ResettableValueTaskSource)}.{nameof(s_completedSentinel)} invoked."); + + protected static void ThrowIncompleteOperationException() => + throw new InvalidOperationException(SR.InvalidOperation_IncompleteAsyncOperation); + + protected static void ThrowMultipleContinuations() => + throw new InvalidOperationException(SR.InvalidOperation_MultipleContinuations); + + public enum States + { + Owned = 0, + CompletionReserved = 1, + CompletionSet = 2, + Released = 3 + } + } + + internal abstract class ResettableValueTaskSource : ResettableValueTaskSource, IValueTaskSource, IValueTaskSource + { + private volatile int _state = (int)States.Owned; + private T _result; + private ExceptionDispatchInfo _error; + private Action _continuation; + private object _continuationState; + private object _schedulingContext; + private ExecutionContext _executionContext; + + public bool RunContinutationsAsynchronously { get; protected set; } + public ValueTaskSourceStatus Status + { + get + { + switch ((States)_state) + { + case States.Owned: + case States.CompletionReserved: + return ValueTaskSourceStatus.Pending; + + case States.CompletionSet: + case States.Released: + return + _error == null ? ValueTaskSourceStatus.Succeeded : + _error.SourceException is OperationCanceledException ? ValueTaskSourceStatus.Canceled : + ValueTaskSourceStatus.Faulted; + + default: + Debug.Fail($"Shouldn't be accessed in the '{(States)_state}' state."); + goto case States.CompletionSet; + } + } + } + + public bool IsCompleted => _state >= (int)States.CompletionSet; + public States UnsafeState { get => (States)_state; set => _state = (int)value; } + + public T GetResult() + { + if (!IsCompleted) + { + ThrowIncompleteOperationException(); + } + + ExceptionDispatchInfo error = _error; + T result = _result; + + _state = (int)States.Released; // only after fetching all needed data + + error?.Throw(); + return result; + } + + void IValueTaskSource.GetResult() + { + if (!IsCompleted) + { + ThrowIncompleteOperationException(); + } + + ExceptionDispatchInfo error = _error; + + _state = (int)States.Released; // only after fetching all needed data + + error?.Throw(); + } + + public bool TryOwnAndReset() + { + if (Interlocked.CompareExchange(ref _state, (int)States.Owned, (int)States.Released) == (int)States.Released) + { + _continuation = null; + _continuationState = null; + _result = default; + _error = null; + _schedulingContext = null; + _executionContext = null; + return true; + } + + return false; + } + + public void OnCompleted(Action continuation, object state, ValueTaskSourceOnCompletedFlags flags) + { + // We need to store the state before the CompareExchange, so that if it completes immediately + // after the CompareExchange, it'll find the state already stored. If someone misuses this + // and schedules multiple continuations erroneously, we could end up using the wrong state. + // Make a best-effort attempt to catch such misuse. + if (_continuationState != null) + { + ThrowMultipleContinuations(); + } + _continuationState = state; + + Debug.Assert(_executionContext == null); + if ((flags & ValueTaskSourceOnCompletedFlags.FlowExecutionContext) != 0) + { + _executionContext = ExecutionContext.Capture(); + } + + Debug.Assert(_schedulingContext == null); + SynchronizationContext sc = null; + TaskScheduler ts = null; + if ((flags & ValueTaskSourceOnCompletedFlags.UseSchedulingContext) != 0) + { + sc = SynchronizationContext.Current; + if (sc != null && sc.GetType() != typeof(SynchronizationContext)) + { + _schedulingContext = sc; + } + else + { + ts = TaskScheduler.Current; + if (ts != TaskScheduler.Default) + { + _schedulingContext = ts; + } + } + } + + Action prevContinuation = Interlocked.CompareExchange(ref _continuation, continuation, null); + if (prevContinuation != null) + { + if (prevContinuation != s_completedSentinel) + { + ThrowMultipleContinuations(); + } + + Debug.Assert(IsCompleted, $"Expected IsCompleted, got {(States)_state}"); + if (sc != null) + { + sc.Post(s => + { + var t = (Tuple, object>)s; + t.Item1(t.Item2); + }, Tuple.Create(continuation, state)); + } + else if (ts != null) + { + Task.Factory.StartNew(continuation, state, CancellationToken.None, TaskCreationOptions.DenyChildAttach, ts); + } + else + { + // TODO #27464: Change this to use the new QueueUserWorkItem signature when it's available. + Debug.Assert(_schedulingContext == null, $"Expected null context, got {_schedulingContext}"); + Task.Factory.StartNew(continuation, state, CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskScheduler.Default); + } + } + } + + public bool TrySetResult(T result) + { + if (Interlocked.CompareExchange(ref _state, (int)States.CompletionReserved, (int)States.Owned) == (int)States.Owned) + { + _result = result; + SignalCompletion(); + return true; + } + + return false; + } + + public bool TrySetException(Exception error) + { + if (Interlocked.CompareExchange(ref _state, (int)States.CompletionReserved, (int)States.Owned) == (int)States.Owned) + { + _error = ExceptionDispatchInfo.Capture(error); + SignalCompletion(); + return true; + } + + return false; + } + + public bool TrySetCanceled(CancellationToken cancellationToken = default) + { + if (Interlocked.CompareExchange(ref _state, (int)States.CompletionReserved, (int)States.Owned) == (int)States.Owned) + { + _error = ExceptionDispatchInfo.Capture(new OperationCanceledException(cancellationToken)); + SignalCompletion(); + return true; + } + + return false; + } + + private void SignalCompletion() + { + _state = (int)States.CompletionSet; + if (_continuation != null || Interlocked.CompareExchange(ref _continuation, s_completedSentinel, null) != null) + { + ExecutionContext ec = _executionContext; + if (ec != null) + { + ExecutionContext.Run(ec, s => ((ResettableValueTaskSource)s).InvokeContinuation(), this); + } + else + { + InvokeContinuation(); + } + } + } + + private void InvokeContinuation() + { + Debug.Assert(_continuation != s_completedSentinel, $"The continuation was the completion sentinel. State={(States)_state}."); + + if (_schedulingContext == null) + { + if (RunContinutationsAsynchronously) + { + ThreadPool.QueueUserWorkItem(s => + { + var vts = (ResettableValueTaskSource)s; + vts._continuation(vts._continuationState); + }, this); + return; + } + } + else if (_schedulingContext is SynchronizationContext sc) + { + if (RunContinutationsAsynchronously || sc != SynchronizationContext.Current) + { + sc.Post(s => + { + var vts = (ResettableValueTaskSource)s; + vts._continuation(vts._continuationState); + }, this); + return; + } + } + else + { + TaskScheduler ts = (TaskScheduler)_schedulingContext; + if (RunContinutationsAsynchronously || ts != TaskScheduler.Current) + { + Task.Factory.StartNew(s => + { + var vts = (ResettableValueTaskSource)s; + vts._continuation(vts._continuationState); + }, this, CancellationToken.None, TaskCreationOptions.DenyChildAttach, ts); + return; + } + } + + _continuation(_continuationState); + } + } + + /// The representation of an asynchronous operation that has a result value. + /// Specifies the type of the result. May be . + internal class AsyncOperation : ResettableValueTaskSource + { + /// Registration in that should be disposed of when the operation has completed. + private CancellationTokenRegistration _registration; + + /// Initializes the interactor. + /// true if continuations should be forced to run asynchronously; otherwise, false. + /// The cancellation token used to cancel the operation. + public AsyncOperation(bool runContinuationsAsynchronously, CancellationToken cancellationToken = default) + { + RunContinutationsAsynchronously = runContinuationsAsynchronously; + CancellationToken = cancellationToken; + _registration = cancellationToken.Register(s => + { + var thisRef = (AsyncOperation)s; + thisRef.TrySetCanceled(thisRef.CancellationToken); + }, this); + } + + /// Next operation in the linked list of operations. + public AsyncOperation Next { get; set; } + public CancellationToken CancellationToken { get; } + + /// Completes the interactor with a success state and the specified result. + /// The result value. + /// true if the interactor could be successfully transitioned to a completed state; false if it was already completed. + public bool Success(TResult item) + { + UnregisterCancellation(); + return TrySetResult(item); + } + + /// Completes the interactor with a failed state and the specified error. + /// The error. + /// true if the interactor could be successfully transitioned to a completed state; false if it was already completed. + public bool Fail(Exception exception) + { + UnregisterCancellation(); + return TrySetException(exception); + } + + public void UnregisterCancellation() => _registration.Dispose(); + } + + /// The representation of an asynchronous operation that has a result value and carries additional data with it. + /// Specifies the type of data being written. + internal sealed class VoidAsyncOperationWithData : AsyncOperation + { + /// Initializes the interactor. + /// true if continuations should be forced to run asynchronously; otherwise, false. + /// The cancellation token used to cancel the operation. + public VoidAsyncOperationWithData(bool runContinuationsAsynchronously, CancellationToken cancellationToken = default) : + base(runContinuationsAsynchronously, cancellationToken) + { + } + + /// The item being written. + public TData Item { get; set; } + } +} diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs b/src/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs index c80d92d2cab4..df81bb778835 100644 --- a/src/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs +++ b/src/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs @@ -22,13 +22,13 @@ internal sealed class BoundedChannel : Channel, IDebugEnumerable /// Items currently stored in the channel waiting to be read. private readonly Dequeue _items = new Dequeue(); /// Readers waiting to read from the channel. - private readonly Dequeue> _blockedReaders = new Dequeue>(); + private readonly Dequeue> _blockedReaders = new Dequeue>(); /// Writers waiting to write to the channel. - private readonly Dequeue> _blockedWriters = new Dequeue>(); - /// Task signaled when any WaitToReadAsync waiters should be woken up. - private ReaderInteractor _waitingReaders; - /// Task signaled when any WaitToWriteAsync waiters should be woken up. - private ReaderInteractor _waitingWriters; + private readonly Dequeue> _blockedWriters = new Dequeue>(); + /// Linked list of WaitToReadAsync waiters. + private AsyncOperation _waitingReadersTail; + /// Linked list of WaitToWriteAsync waiters. + private AsyncOperation _waitingWritersTail; /// Whether to force continuations to be executed asynchronously from producer writes. private readonly bool _runContinuationsAsynchronously; /// Set to non-null once Complete has been called. @@ -105,17 +105,17 @@ public override ValueTask ReadAsync(CancellationToken cancellationToken) } // Otherwise, queue the reader. - var reader = ReaderInteractor.Create(parent._runContinuationsAsynchronously, cancellationToken); + var reader = new AsyncOperation(parent._runContinuationsAsynchronously, cancellationToken); parent._blockedReaders.EnqueueTail(reader); - return new ValueTask(reader.Task); + return new ValueTask(reader); } } - public override Task WaitToReadAsync(CancellationToken cancellationToken) + public override ValueTask WaitToReadAsync(CancellationToken cancellationToken) { if (cancellationToken.IsCancellationRequested) { - return Task.FromCanceled(cancellationToken); + return new ValueTask(Task.FromCanceled(cancellationToken)); } BoundedChannel parent = _parent; @@ -126,20 +126,22 @@ public override Task WaitToReadAsync(CancellationToken cancellationToken) // If there are any items available, a read is possible. if (!parent._items.IsEmpty) { - return ChannelUtilities.s_trueTask; + return new ValueTask(true); } // There were no items available, so if we're done writing, a read will never be possible. if (parent._doneWriting != null) { return parent._doneWriting != ChannelUtilities.s_doneWritingSentinel ? - Task.FromException(parent._doneWriting) : - ChannelUtilities.s_falseTask; + new ValueTask(Task.FromException(parent._doneWriting)) : + new ValueTask(false); } // There were no items available, but there could be in the future, so ensure // there's a blocked reader task and return it. - return ChannelUtilities.GetOrCreateWaiter(ref parent._waitingReaders, parent._runContinuationsAsynchronously, cancellationToken); + var waiter = new AsyncOperation(parent._runContinuationsAsynchronously, cancellationToken); + ChannelUtilities.QueueWaiter(ref _parent._waitingReadersTail, waiter); + return new ValueTask(waiter); } } @@ -177,7 +179,7 @@ private T DequeueItemAndPostProcess() while (!parent._blockedWriters.IsEmpty) { - WriterInteractor w = parent._blockedWriters.DequeueHead(); + VoidAsyncOperationWithData w = parent._blockedWriters.DequeueHead(); if (w.Success(default)) { parent._items.EnqueueTail(w.Item); @@ -187,7 +189,7 @@ private T DequeueItemAndPostProcess() // There was no blocked writer, so see if there's a WaitToWriteAsync // we should wake up. - ChannelUtilities.WakeUpWaiters(ref parent._waitingWriters, result: true); + ChannelUtilities.WakeUpWaiters(ref parent._waitingWritersTail, result: true); } // Return the item @@ -242,10 +244,10 @@ public override bool TryComplete(Exception error) // We also know that only one thread (this one) will ever get here, as only that thread // will be the one to transition from _doneWriting false to true. As such, we can // freely manipulate them without any concurrency concerns. - ChannelUtilities.FailInteractors, T>(parent._blockedReaders, ChannelUtilities.CreateInvalidCompletionException(error)); - ChannelUtilities.FailInteractors, VoidResult>(parent._blockedWriters, ChannelUtilities.CreateInvalidCompletionException(error)); - ChannelUtilities.WakeUpWaiters(ref parent._waitingReaders, result: false, error: error); - ChannelUtilities.WakeUpWaiters(ref parent._waitingWriters, result: false, error: error); + ChannelUtilities.FailOperations, T>(parent._blockedReaders, ChannelUtilities.CreateInvalidCompletionException(error)); + ChannelUtilities.FailOperations, VoidResult>(parent._blockedWriters, ChannelUtilities.CreateInvalidCompletionException(error)); + ChannelUtilities.WakeUpWaiters(ref parent._waitingReadersTail, result: false, error: error); + ChannelUtilities.WakeUpWaiters(ref parent._waitingWritersTail, result: false, error: error); // Successfully transitioned to completed. return true; @@ -253,8 +255,8 @@ public override bool TryComplete(Exception error) public override bool TryWrite(T item) { - ReaderInteractor blockedReader = null; - ReaderInteractor waitingReaders = null; + AsyncOperation blockedReader = null; + AsyncOperation waitingReadersTail = null; BoundedChannel parent = _parent; lock (parent.SyncObj) @@ -279,9 +281,9 @@ public override bool TryWrite(T item) // continuations that'll run synchronously while (!parent._blockedReaders.IsEmpty) { - ReaderInteractor r = parent._blockedReaders.DequeueHead(); + AsyncOperation r = parent._blockedReaders.DequeueHead(); r.UnregisterCancellation(); // ensure that once we grab it, we own its completion - if (!r.Task.IsCompleted) + if (!r.IsCompleted) { blockedReader = r; break; @@ -293,12 +295,12 @@ public override bool TryWrite(T item) // If there wasn't a blocked reader, then store the item. If no one's waiting // to be notified about a 0-to-1 transition, we're done. parent._items.EnqueueTail(item); - waitingReaders = parent._waitingReaders; - if (waitingReaders == null) + waitingReadersTail = parent._waitingReadersTail; + if (waitingReadersTail == null) { return true; } - parent._waitingReaders = null; + parent._waitingReadersTail = null; } } else if (count < parent._bufferedCapacity) @@ -333,7 +335,7 @@ public override bool TryWrite(T item) } } - // We either wrote the item already, or we're transfering it to the blocked reader we grabbed. + // We either wrote the item already, or we're transferring it to the blocked reader we grabbed. if (blockedReader != null) { // Transfer the written item to the blocked reader. @@ -346,17 +348,17 @@ public override bool TryWrite(T item) // any waiting readers that there may be something for them to consume. // Since we're no longer holding the lock, it's possible we'll end up // waking readers that have since come in. - waitingReaders.Success(item: true); + ChannelUtilities.WakeUpWaiters(ref waitingReadersTail, result: true); } return true; } - public override Task WaitToWriteAsync(CancellationToken cancellationToken) + public override ValueTask WaitToWriteAsync(CancellationToken cancellationToken) { if (cancellationToken.IsCancellationRequested) { - return Task.FromCanceled(cancellationToken); + return new ValueTask(Task.FromCanceled(cancellationToken)); } BoundedChannel parent = _parent; @@ -368,8 +370,8 @@ public override Task WaitToWriteAsync(CancellationToken cancellationToken) if (parent._doneWriting != null) { return parent._doneWriting != ChannelUtilities.s_doneWritingSentinel ? - Task.FromException(parent._doneWriting) : - ChannelUtilities.s_falseTask; + new ValueTask(Task.FromException(parent._doneWriting)) : + new ValueTask(false); } // If there's space to write, a write is possible. @@ -377,23 +379,25 @@ public override Task WaitToWriteAsync(CancellationToken cancellationToken) // full we'll just drop an element to make room. if (parent._items.Count < parent._bufferedCapacity || parent._mode != BoundedChannelFullMode.Wait) { - return ChannelUtilities.s_trueTask; + return new ValueTask(true); } // We're still allowed to write, but there's no space, so ensure a waiter is queued and return it. - return ChannelUtilities.GetOrCreateWaiter(ref parent._waitingWriters, runContinuationsAsynchronously: true, cancellationToken); + var waiter = new AsyncOperation(runContinuationsAsynchronously: true, cancellationToken); + ChannelUtilities.QueueWaiter(ref parent._waitingWritersTail, waiter); + return new ValueTask(waiter); } } - public override Task WriteAsync(T item, CancellationToken cancellationToken) + public override ValueTask WriteAsync(T item, CancellationToken cancellationToken) { if (cancellationToken.IsCancellationRequested) { - return Task.FromCanceled(cancellationToken); + return new ValueTask(Task.FromCanceled(cancellationToken)); } - ReaderInteractor blockedReader = null; - ReaderInteractor waitingReaders = null; + AsyncOperation blockedReader = null; + AsyncOperation waitingReadersTail = null; BoundedChannel parent = _parent; lock (parent.SyncObj) @@ -403,7 +407,7 @@ public override Task WriteAsync(T item, CancellationToken cancellationToken) // If we're done writing, trying to write is an error. if (parent._doneWriting != null) { - return Task.FromException(ChannelUtilities.CreateInvalidCompletionException(parent._doneWriting)); + return new ValueTask(Task.FromException(ChannelUtilities.CreateInvalidCompletionException(parent._doneWriting))); } // Get the number of items in the channel currently. @@ -418,9 +422,9 @@ public override Task WriteAsync(T item, CancellationToken cancellationToken) // continuations that'll run synchronously while (!parent._blockedReaders.IsEmpty) { - ReaderInteractor r = parent._blockedReaders.DequeueHead(); + AsyncOperation r = parent._blockedReaders.DequeueHead(); r.UnregisterCancellation(); // ensure that once we grab it, we own its completion - if (!r.Task.IsCompleted) + if (!r.IsCompleted) { blockedReader = r; break; @@ -432,12 +436,12 @@ public override Task WriteAsync(T item, CancellationToken cancellationToken) // If there wasn't a blocked reader, then store the item. If no one's waiting // to be notified about a 0-to-1 transition, we're done. parent._items.EnqueueTail(item); - waitingReaders = parent._waitingReaders; - if (waitingReaders == null) + waitingReadersTail = parent._waitingReadersTail; + if (waitingReadersTail == null) { - return ChannelUtilities.s_trueTask; + return default; } - parent._waitingReaders = null; + parent._waitingReadersTail = null; } } else if (count < parent._bufferedCapacity) @@ -446,21 +450,22 @@ public override Task WriteAsync(T item, CancellationToken cancellationToken) // since there's room, we can simply store the item and exit without having to // worry about blocked/waiting readers. parent._items.EnqueueTail(item); - return ChannelUtilities.s_trueTask; + return default; } else if (parent._mode == BoundedChannelFullMode.Wait) { // The channel is full and we're in a wait mode. // Queue the writer. - var writer = WriterInteractor.Create(runContinuationsAsynchronously: true, item, cancellationToken); + var writer = new VoidAsyncOperationWithData(runContinuationsAsynchronously: true, cancellationToken); + writer.Item = item; parent._blockedWriters.EnqueueTail(writer); - return writer.Task; + return new ValueTask(writer); } else if (parent._mode == BoundedChannelFullMode.DropWrite) { // The channel is full and we're in ignore mode. // Ignore the item but say we accepted it. - return ChannelUtilities.s_trueTask; + return default; } else { @@ -470,7 +475,7 @@ public override Task WriteAsync(T item, CancellationToken cancellationToken) parent._items.DequeueTail() : parent._items.DequeueHead(); parent._items.EnqueueTail(item); - return ChannelUtilities.s_trueTask; + return default; } } @@ -487,10 +492,10 @@ public override Task WriteAsync(T item, CancellationToken cancellationToken) // any waiting readers that there may be something for them to consume. // Since we're no longer holding the lock, it's possible we'll end up // waking readers that have since come in. - waitingReaders.Success(item: true); + ChannelUtilities.WakeUpWaiters(ref waitingReadersTail, result: true); } - return ChannelUtilities.s_trueTask; + return default; } /// Gets the number of items in the channel. This should only be used by the debugger. @@ -512,12 +517,12 @@ private void AssertInvariants() if (!_items.IsEmpty) { Debug.Assert(_blockedReaders.IsEmpty, "There are items available, so there shouldn't be any blocked readers."); - Debug.Assert(_waitingReaders == null, "There are items available, so there shouldn't be any waiting readers."); + Debug.Assert(_waitingReadersTail == null, "There are items available, so there shouldn't be any waiting readers."); } if (_items.Count < _bufferedCapacity) { Debug.Assert(_blockedWriters.IsEmpty, "There's space available, so there shouldn't be any blocked writers."); - Debug.Assert(_waitingWriters == null, "There's space available, so there shouldn't be any waiting writers."); + Debug.Assert(_waitingWritersTail == null, "There's space available, so there shouldn't be any waiting writers."); } if (!_blockedReaders.IsEmpty) { diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/ChannelReader.cs b/src/System.Threading.Channels/src/System/Threading/Channels/ChannelReader.cs index 1a4d830c6026..2a838e8231b0 100644 --- a/src/System.Threading.Channels/src/System/Threading/Channels/ChannelReader.cs +++ b/src/System.Threading.Channels/src/System/Threading/Channels/ChannelReader.cs @@ -29,7 +29,7 @@ public abstract class ChannelReader /// A that will complete with a true result when data is available to read /// or with a false result when no further data will ever be available to be read. /// - public abstract Task WaitToReadAsync(CancellationToken cancellationToken = default); + public abstract ValueTask WaitToReadAsync(CancellationToken cancellationToken = default); /// Asynchronously reads an item from the channel. /// A used to cancel the read operation. diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/ChannelUtilities.cs b/src/System.Threading.Channels/src/System/Threading/Channels/ChannelUtilities.cs index 5b4b39b2397c..711e97a1285e 100644 --- a/src/System.Threading.Channels/src/System/Threading/Channels/ChannelUtilities.cs +++ b/src/System.Threading.Channels/src/System/Threading/Channels/ChannelUtilities.cs @@ -60,76 +60,57 @@ internal static ValueTask GetInvalidCompletionValueTask(Exception error) return new ValueTask(t); } - /// Wake up all of the waiters and null out the field. - /// The waiters. - /// The value with which to complete each waiter. - internal static void WakeUpWaiters(ref ReaderInteractor waiters, bool result) + internal static ValueTask QueueWaiter(ref AsyncOperation tail, AsyncOperation waiter) { - ReaderInteractor w = waiters; - if (w != null) + AsyncOperation c = tail; + if (c == null) { - w.Success(result); - waiters = null; + waiter.Next = waiter; } + else + { + waiter.Next = c.Next; + c.Next = waiter; + } + tail = waiter; + return new ValueTask(waiter); } - /// Wake up all of the waiters and null out the field. - /// The waiters. - /// The success value with which to complete each waiter if error is null. - /// The failure with which to cmplete each waiter, if non-null. - internal static void WakeUpWaiters(ref ReaderInteractor waiters, bool result, Exception error = null) + internal static void WakeUpWaiters(ref AsyncOperation listTail, bool result, Exception error = null) { - ReaderInteractor w = waiters; - if (w != null) + AsyncOperation tail = listTail; + if (tail != null) { - if (error != null) - { - w.Fail(error); - } - else + listTail = null; + + AsyncOperation head = tail.Next; + AsyncOperation c = head; + do { - w.Success(result); + AsyncOperation next = c.Next; + c.Next = null; + + bool completed = error != null ? c.Fail(error) : c.Success(result); + Debug.Assert(completed || c.CancellationToken.CanBeCanceled); + + c = next; } - waiters = null; + while (c != head); } } - /// Removes all interactors from the queue, failing each. - /// The queue of interactors to complete. - /// The error with which to complete each interactor. - internal static void FailInteractors(Dequeue interactors, Exception error) where T : Interactor + /// Removes all operations from the queue, failing each. + /// The queue of operations to complete. + /// The error with which to complete each operations. + internal static void FailOperations(Dequeue operations, Exception error) where T : AsyncOperation { Debug.Assert(error != null); - while (!interactors.IsEmpty) + while (!operations.IsEmpty) { - interactors.DequeueHead().Fail(error); + operations.DequeueHead().Fail(error); } } - /// Gets or creates a "waiter" (e.g. WaitForRead/WriteAsync) interactor. - /// The field storing the waiter interactor. - /// true to force continuations to run asynchronously; otherwise, false. - /// The token to use to cancel the wait. - internal static Task GetOrCreateWaiter(ref ReaderInteractor waiter, bool runContinuationsAsynchronously, CancellationToken cancellationToken) - { - // Get the existing waiters interactor. - ReaderInteractor w = waiter; - - // If there isn't one, create one. This explicitly does not include the cancellation token, - // as we reuse it for any number of waiters that overlap. - if (w == null) - { - waiter = w = ReaderInteractor.Create(runContinuationsAsynchronously); - } - - // If the cancellation token can't be canceled, then just return the waiter task. - // If it can, we need to return a task that will complete when the waiter task does but that can also be canceled. - // Easiest way to do that is with a cancelable continuation. - return cancellationToken.CanBeCanceled ? - w.Task.ContinueWith(t => t.Result, cancellationToken, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default) : - w.Task; - } - /// Creates and returns an exception object to indicate that a channel has been closed. internal static Exception CreateInvalidCompletionException(Exception inner = null) => inner is OperationCanceledException ? inner : diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/ChannelWriter.cs b/src/System.Threading.Channels/src/System/Threading/Channels/ChannelWriter.cs index d09fa1b0d0af..2399c4187a92 100644 --- a/src/System.Threading.Channels/src/System/Threading/Channels/ChannelWriter.cs +++ b/src/System.Threading.Channels/src/System/Threading/Channels/ChannelWriter.cs @@ -31,38 +31,38 @@ public abstract class ChannelWriter /// A that will complete with a true result when space is available to write an item /// or with a false result when no further writing will be permitted. /// - public abstract Task WaitToWriteAsync(CancellationToken cancellationToken = default); + public abstract ValueTask WaitToWriteAsync(CancellationToken cancellationToken = default); /// Asynchronously writes an item to the channel. /// The value to write to the channel. /// A used to cancel the write operation. /// A that represents the asynchronous write operation. - public virtual Task WriteAsync(T item, CancellationToken cancellationToken = default) + public virtual ValueTask WriteAsync(T item, CancellationToken cancellationToken = default) { try { return - cancellationToken.IsCancellationRequested ? Task.FromCanceled(cancellationToken) : - TryWrite(item) ? Task.CompletedTask : - WriteAsyncCore(item, cancellationToken); + cancellationToken.IsCancellationRequested ? new ValueTask(Task.FromCanceled(cancellationToken)) : + TryWrite(item) ? default : + new ValueTask(WriteAsyncCore(item, cancellationToken)); } catch (Exception e) { - return Task.FromException(e); + return new ValueTask(Task.FromException(e)); } + } - async Task WriteAsyncCore(T innerItem, CancellationToken ct) + private async Task WriteAsyncCore(T innerItem, CancellationToken ct) + { + while (await WaitToWriteAsync(ct).ConfigureAwait(false)) { - while (await WaitToWriteAsync(ct).ConfigureAwait(false)) + if (TryWrite(innerItem)) { - if (TryWrite(innerItem)) - { - return; - } + return; } - - throw ChannelUtilities.CreateInvalidCompletionException(); } + + throw ChannelUtilities.CreateInvalidCompletionException(); } /// Mark the channel as being complete, meaning no more items will be written to it. diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/Interactor.cs b/src/System.Threading.Channels/src/System/Threading/Channels/Interactor.cs deleted file mode 100644 index f4e0c74767bb..000000000000 --- a/src/System.Threading.Channels/src/System/Threading/Channels/Interactor.cs +++ /dev/null @@ -1,149 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System.Threading.Tasks; - -namespace System.Threading.Channels -{ - /// A base class for a blocked or waiting reader or writer. - /// Specifies the type of data passed to the reader or writer. - internal abstract class Interactor : TaskCompletionSource - { - /// Initializes the interactor. - /// true if continuations should be forced to run asynchronously; otherwise, false. - protected Interactor(bool runContinuationsAsynchronously) : - base(runContinuationsAsynchronously ? TaskCreationOptions.RunContinuationsAsynchronously : TaskCreationOptions.None) { } - - /// Completes the interactor with a success state and the specified result. - /// The result value. - /// true if the interactor could be successfully transitioned to a completed state; false if it was already completed. - internal bool Success(T item) - { - UnregisterCancellation(); - return TrySetResult(item); - } - - /// Completes the interactor with a failed state and the specified error. - /// The error. - /// true if the interactor could be successfully transitioned to a completed state; false if it was already completed. - internal bool Fail(Exception exception) - { - UnregisterCancellation(); - return TrySetException(exception); - } - - /// Unregister cancellation in case cancellation was registered. - internal virtual void UnregisterCancellation() { } - } - - /// A blocked or waiting reader. - /// Specifies the type of data being read. - internal class ReaderInteractor : Interactor - { - /// Initializes the reader. - /// true if continuations should be forced to run asynchronously; otherwise, false. - protected ReaderInteractor(bool runContinuationsAsynchronously) : base(runContinuationsAsynchronously) { } - - /// Creates a reader. - /// true if continuations should be forced to run asynchronously; otherwise, false. - /// The reader. - public static ReaderInteractor Create(bool runContinuationsAsynchronously) => - new ReaderInteractor(runContinuationsAsynchronously); - - /// Creates a reader. - /// true if continuations should be forced to run asynchronously; otherwise, false. - /// A that can be used to cancel the read operation. - /// The reader. - public static ReaderInteractor Create(bool runContinuationsAsynchronously, CancellationToken cancellationToken) => - cancellationToken.CanBeCanceled ? - new CancelableReaderInteractor(runContinuationsAsynchronously, cancellationToken) : - new ReaderInteractor(runContinuationsAsynchronously); - } - - /// A blocked or waiting writer. - /// Specifies the type of data being written. - internal class WriterInteractor : Interactor - { - /// Initializes the writer. - /// true if continuations should be forced to run asynchronously; otherwise, false. - protected WriterInteractor(bool runContinuationsAsynchronously) : base(runContinuationsAsynchronously) { } - - /// The item being written. - internal T Item { get; private set; } - - /// Creates a writer. - /// true if continuations should be forced to run asynchronously; otherwise, false. - /// The item being written. - /// A that can be used to cancel the read operation. - /// The reader. - public static WriterInteractor Create(bool runContinuationsAsynchronously, T item, CancellationToken cancellationToken) - { - WriterInteractor w = cancellationToken.CanBeCanceled ? - new CancelableWriter(runContinuationsAsynchronously, cancellationToken) : - new WriterInteractor(runContinuationsAsynchronously); - w.Item = item; - return w; - } - } - - /// A blocked or waiting reader where the read can be canceled. - /// Specifies the type of data being read. - internal sealed class CancelableReaderInteractor : ReaderInteractor - { - /// The token used for cancellation. - private readonly CancellationToken _token; - /// Registration in that should be disposed of when the operation has completed. - private CancellationTokenRegistration _registration; - - /// Initializes the cancelable reader. - /// true if continuations should be forced to run asynchronously; otherwise, false. - /// A that can be used to cancel the read operation. - internal CancelableReaderInteractor(bool runContinuationsAsynchronously, CancellationToken cancellationToken) : base(runContinuationsAsynchronously) - { - _token = cancellationToken; - _registration = cancellationToken.Register(s => - { - var thisRef = (CancelableReaderInteractor)s; - thisRef.TrySetCanceled(thisRef._token); - }, this); - } - - /// Unregister cancellation in case cancellation was registered. - internal override void UnregisterCancellation() - { - _registration.Dispose(); - _registration = default; - } - } - - /// A blocked or waiting reader where the read can be canceled. - /// Specifies the type of data being read. - internal sealed class CancelableWriter : WriterInteractor - { - /// The token used for cancellation. - private CancellationToken _token; - /// Registration in that should be disposed of when the operation has completed. - private CancellationTokenRegistration _registration; - - /// Initializes the cancelable writer. - /// true if continuations should be forced to run asynchronously; otherwise, false. - /// A that can be used to cancel the read operation. - internal CancelableWriter(bool runContinuationsAsynchronously, CancellationToken cancellationToken) : base(runContinuationsAsynchronously) - { - _token = cancellationToken; - _registration = cancellationToken.Register(s => - { - var thisRef = (CancelableWriter)s; - thisRef.TrySetCanceled(thisRef._token); - }, this); - } - - /// Unregister cancellation in case cancellation was registered. - internal override void UnregisterCancellation() - { - _registration.Dispose(); - _registration = default; - } - } -} diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/SingleConsumerUnboundedChannel.cs b/src/System.Threading.Channels/src/System/Threading/Channels/SingleConsumerUnboundedChannel.cs index aef99047258b..3a8760966b54 100644 --- a/src/System.Threading.Channels/src/System/Threading/Channels/SingleConsumerUnboundedChannel.cs +++ b/src/System.Threading.Channels/src/System/Threading/Channels/SingleConsumerUnboundedChannel.cs @@ -6,7 +6,6 @@ using System.Collections.Generic; using System.Diagnostics; using System.Threading.Tasks; -using System.Runtime.CompilerServices; namespace System.Threading.Channels { @@ -32,11 +31,11 @@ internal sealed class SingleConsumerUnboundedChannel : Channel, IDebugEnum /// non-null if the channel has been marked as complete for writing. private volatile Exception _doneWriting; - /// A if there's a blocked reader. - private ReaderInteractor _blockedReader; + /// An if there's a blocked reader. + private AsyncOperation _blockedReader; /// A waiting reader (e.g. WaitForReadAsync) if there is one. - private ReaderInteractor _waitingReader; + private AsyncOperation _waitingReader; /// Initialize the channel. /// Whether to force continuations to be executed asynchronously. @@ -54,49 +53,69 @@ internal SingleConsumerUnboundedChannel(bool runContinuationsAsynchronously) private sealed class UnboundedChannelReader : ChannelReader, IDebugEnumerable { internal readonly SingleConsumerUnboundedChannel _parent; - internal UnboundedChannelReader(SingleConsumerUnboundedChannel parent) => _parent = parent; + private readonly AsyncOperation _readerSingleton; + private readonly AsyncOperation _waiterSingleton; + + internal UnboundedChannelReader(SingleConsumerUnboundedChannel parent) + { + _parent = parent; + _readerSingleton = new AsyncOperation(parent._runContinuationsAsynchronously) { UnsafeState = ResettableValueTaskSource.States.Released }; + _waiterSingleton = new AsyncOperation(parent._runContinuationsAsynchronously) { UnsafeState = ResettableValueTaskSource.States.Released }; + } public override Task Completion => _parent._completion.Task; public override ValueTask ReadAsync(CancellationToken cancellationToken) { + if (cancellationToken.IsCancellationRequested) + { + return new ValueTask(Task.FromCanceled(cancellationToken)); + } + + if (TryRead(out T item)) { - return TryRead(out T item) ? - new ValueTask(item) : - ReadAsyncCore(cancellationToken); + return new ValueTask(item); } - ValueTask ReadAsyncCore(CancellationToken ct) + SingleConsumerUnboundedChannel parent = _parent; + + AsyncOperation oldBlockedReader, newBlockedReader; + lock (parent.SyncObj) { - SingleConsumerUnboundedChannel parent = _parent; - if (ct.IsCancellationRequested) + // Now that we hold the lock, try reading again. + if (TryRead(out item)) { - return new ValueTask(Task.FromCanceled(ct)); + return new ValueTask(item); } - lock (parent.SyncObj) + // If no more items will be written, fail the read. + if (parent._doneWriting != null) { - // Now that we hold the lock, try reading again. - if (TryRead(out T item)) - { - return new ValueTask(item); - } + return ChannelUtilities.GetInvalidCompletionValueTask(parent._doneWriting); + } - // If no more items will be written, fail the read. - if (parent._doneWriting != null) + // Try to use the singleton reader. If it's currently being used, then the channel + // is being used erroneously, and we cancel the outstanding operation. + oldBlockedReader = parent._blockedReader; + if (!cancellationToken.CanBeCanceled && _readerSingleton.TryOwnAndReset()) + { + newBlockedReader = _readerSingleton; + if (newBlockedReader == oldBlockedReader) { - return ChannelUtilities.GetInvalidCompletionValueTask(parent._doneWriting); + // The previous operation completed, so null out the "old" reader + // so we don't end up canceling the new operation. + oldBlockedReader = null; } - - Debug.Assert(parent._blockedReader == null || parent._blockedReader.Task.IsCanceled, - "Incorrect usage; multiple outstanding reads were issued against this single-consumer channel"); - - // Store the reader to be completed by a writer. - var reader = ReaderInteractor.Create(parent._runContinuationsAsynchronously, ct); - parent._blockedReader = reader; - return new ValueTask(reader.Task); } + else + { + newBlockedReader = new AsyncOperation(_parent._runContinuationsAsynchronously, cancellationToken); + } + parent._blockedReader = newBlockedReader; } + + oldBlockedReader?.TrySetCanceled(); + return new ValueTask(newBlockedReader); } public override bool TryRead(out T item) @@ -113,44 +132,59 @@ public override bool TryRead(out T item) return false; } - public override Task WaitToReadAsync(CancellationToken cancellationToken) + public override ValueTask WaitToReadAsync(CancellationToken cancellationToken) { // Outside of the lock, check if there are any items waiting to be read. If there are, we're done. - return - cancellationToken.IsCancellationRequested ? Task.FromCanceled(cancellationToken) : - !_parent._items.IsEmpty ? ChannelUtilities.s_trueTask : - WaitToReadAsyncCore(cancellationToken); + if (cancellationToken.IsCancellationRequested) + { + return new ValueTask(Task.FromCanceled(cancellationToken)); + } - Task WaitToReadAsyncCore(CancellationToken ct) + if (!_parent._items.IsEmpty) { - SingleConsumerUnboundedChannel parent = _parent; - ReaderInteractor oldWaiter = null, newWaiter; - lock (parent.SyncObj) + return new ValueTask(true); + } + + SingleConsumerUnboundedChannel parent = _parent; + AsyncOperation oldWaitingReader = null, newWaitingReader; + lock (parent.SyncObj) + { + // Again while holding the lock, check to see if there are any items available. + if (!parent._items.IsEmpty) { - // Again while holding the lock, check to see if there are any items available. - if (!parent._items.IsEmpty) - { - return ChannelUtilities.s_trueTask; - } + return new ValueTask(true); + } - // There aren't any items; if we're done writing, there never will be more items. - if (parent._doneWriting != null) + // There aren't any items; if we're done writing, there never will be more items. + if (parent._doneWriting != null) + { + return parent._doneWriting != ChannelUtilities.s_doneWritingSentinel ? + new ValueTask(Task.FromException(parent._doneWriting)) : + new ValueTask(false); + } + + // Try to use the singleton waiter. If it's currently being used, then the channel + // is being used erroneously, and we cancel the outstanding operation. + oldWaitingReader = parent._waitingReader; + if (!cancellationToken.CanBeCanceled && _waiterSingleton.TryOwnAndReset()) + { + newWaitingReader = _waiterSingleton; + if (newWaitingReader == oldWaitingReader) { - return parent._doneWriting != ChannelUtilities.s_doneWritingSentinel ? - Task.FromException(parent._doneWriting) : - ChannelUtilities.s_falseTask; + // The previous operation completed, so null out the "old" waiter + // so we don't end up canceling the new operation. + oldWaitingReader = null; } - - // Create the new waiter. We're a bit more tolerant of a stray waiting reader - // than we are of a blocked reader, as with usage patterns it's easier to leave one - // behind, so we just cancel any that may have been waiting around. - oldWaiter = parent._waitingReader; - parent._waitingReader = newWaiter = ReaderInteractor.Create(parent._runContinuationsAsynchronously, ct); } - - oldWaiter?.TrySetCanceled(); - return newWaiter.Task; + else + { + newWaitingReader = new AsyncOperation(_parent._runContinuationsAsynchronously, cancellationToken); + } + parent._waitingReader = newWaitingReader; } + + oldWaitingReader?.TrySetCanceled(); + return new ValueTask(newWaitingReader); } /// Gets the number of items in the channel. This should only be used by the debugger. @@ -169,8 +203,8 @@ private sealed class UnboundedChannelWriter : ChannelWriter, IDebugEnumerable public override bool TryComplete(Exception error) { - ReaderInteractor blockedReader = null; - ReaderInteractor waitingReader = null; + AsyncOperation blockedReader = null; + AsyncOperation waitingReader = null; bool completeTask = false; SingleConsumerUnboundedChannel parent = _parent; @@ -244,8 +278,8 @@ public override bool TryWrite(T item) SingleConsumerUnboundedChannel parent = _parent; while (true) // in case a reader was canceled and we need to try again { - ReaderInteractor blockedReader = null; - ReaderInteractor waitingReader = null; + AsyncOperation blockedReader = null; + AsyncOperation waitingReader = null; lock (parent.SyncObj) { @@ -297,22 +331,22 @@ public override bool TryWrite(T item) } } - public override Task WaitToWriteAsync(CancellationToken cancellationToken) + public override ValueTask WaitToWriteAsync(CancellationToken cancellationToken) { Exception doneWriting = _parent._doneWriting; return - cancellationToken.IsCancellationRequested ? Task.FromCanceled(cancellationToken) : - doneWriting == null ? ChannelUtilities.s_trueTask : - doneWriting != ChannelUtilities.s_doneWritingSentinel ? Task.FromException(doneWriting) : - ChannelUtilities.s_falseTask; + cancellationToken.IsCancellationRequested ? new ValueTask(Task.FromCanceled(cancellationToken)) : + doneWriting == null ? new ValueTask(true) : + doneWriting != ChannelUtilities.s_doneWritingSentinel ? new ValueTask(Task.FromException(doneWriting)) : + new ValueTask(false); } - public override Task WriteAsync(T item, CancellationToken cancellationToken) => + public override ValueTask WriteAsync(T item, CancellationToken cancellationToken) => // Writing always succeeds (unless we've already completed writing or cancellation has been requested), // so just TryWrite and return a completed task. - cancellationToken.IsCancellationRequested ? Task.FromCanceled(cancellationToken) : - TryWrite(item) ? Task.CompletedTask : - Task.FromException(ChannelUtilities.CreateInvalidCompletionException(_parent._doneWriting)); + cancellationToken.IsCancellationRequested ? new ValueTask(Task.FromCanceled(cancellationToken)) : + TryWrite(item) ? default : + new ValueTask(Task.FromException(ChannelUtilities.CreateInvalidCompletionException(_parent._doneWriting))); /// Gets the number of items in the channel. This should only be used by the debugger. private int ItemsCountForDebugger => _parent._items.Count; diff --git a/src/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs b/src/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs index 4fe6e92f3a5f..e25a829d85c2 100644 --- a/src/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs +++ b/src/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs @@ -19,12 +19,12 @@ internal sealed class UnboundedChannel : Channel, IDebugEnumerable /// The items in the channel. private readonly ConcurrentQueue _items = new ConcurrentQueue(); /// Readers blocked reading from the channel. - private readonly Dequeue> _blockedReaders = new Dequeue>(); + private readonly Dequeue> _blockedReaders = new Dequeue>(); /// Whether to force continuations to be executed asynchronously from producer writes. private readonly bool _runContinuationsAsynchronously; /// Readers waiting for a notification that data is available. - private ReaderInteractor _waitingReaders; + private AsyncOperation _waitingReadersTail; /// Set to non-null once Complete has been called. private Exception _doneWriting; @@ -33,7 +33,7 @@ internal UnboundedChannel(bool runContinuationsAsynchronously) { _runContinuationsAsynchronously = runContinuationsAsynchronously; _completion = new TaskCompletionSource(runContinuationsAsynchronously ? TaskCreationOptions.RunContinuationsAsynchronously : TaskCreationOptions.None); - base.Reader = new UnboundedChannelReader(this); + Reader = new UnboundedChannelReader(this); Writer = new UnboundedChannelWriter(this); } @@ -42,7 +42,15 @@ internal UnboundedChannel(bool runContinuationsAsynchronously) private sealed class UnboundedChannelReader : ChannelReader, IDebugEnumerable { internal readonly UnboundedChannel _parent; - internal UnboundedChannelReader(UnboundedChannel parent) => _parent = parent; + private readonly AsyncOperation _readerSingleton; + private readonly AsyncOperation _waiterSingleton; + + internal UnboundedChannelReader(UnboundedChannel parent) + { + _parent = parent; + _readerSingleton = new AsyncOperation(parent._runContinuationsAsynchronously) { UnsafeState = ResettableValueTaskSource.States.Released }; + _waiterSingleton = new AsyncOperation(parent._runContinuationsAsynchronously) { UnsafeState = ResettableValueTaskSource.States.Released }; + } public override Task Completion => _parent._completion.Task; @@ -82,10 +90,21 @@ private ValueTask ReadAsyncCore(CancellationToken cancellationToken) return ChannelUtilities.GetInvalidCompletionValueTask(parent._doneWriting); } - // Otherwise, queue the reader. - var reader = ReaderInteractor.Create(parent._runContinuationsAsynchronously, cancellationToken); + // If we're able to use the singleton reader, do so. + if (!cancellationToken.CanBeCanceled) + { + AsyncOperation singleton = _readerSingleton; + if (singleton.TryOwnAndReset()) + { + parent._blockedReaders.EnqueueTail(singleton); + return new ValueTask(singleton); + } + } + + // Otherwise, create and queue a reader. + var reader = new AsyncOperation(parent._runContinuationsAsynchronously, cancellationToken); parent._blockedReaders.EnqueueTail(reader); - return new ValueTask(reader.Task); + return new ValueTask(reader); } } @@ -108,38 +127,53 @@ public override bool TryRead(out T item) return false; } - public override Task WaitToReadAsync(CancellationToken cancellationToken) + public override ValueTask WaitToReadAsync(CancellationToken cancellationToken) { - return - cancellationToken.IsCancellationRequested ? Task.FromCanceled(cancellationToken) : - !_parent._items.IsEmpty ? ChannelUtilities.s_trueTask : - WaitToReadAsyncCore(cancellationToken); + if (cancellationToken.IsCancellationRequested) + { + return new ValueTask(Task.FromCanceled(cancellationToken)); + } - Task WaitToReadAsyncCore(CancellationToken ct) + if (!_parent._items.IsEmpty) { - UnboundedChannel parent = _parent; + return new ValueTask(true); + } - lock (parent.SyncObj) + UnboundedChannel parent = _parent; + + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + // Try again to read now that we're synchronized with writers. + if (!parent._items.IsEmpty) { - parent.AssertInvariants(); + return new ValueTask(true); + } - // Try again to read now that we're synchronized with writers. - if (!parent._items.IsEmpty) - { - return ChannelUtilities.s_trueTask; - } + // There are no items, so if we're done writing, there's never going to be data available. + if (parent._doneWriting != null) + { + return parent._doneWriting != ChannelUtilities.s_doneWritingSentinel ? + new ValueTask(Task.FromException(parent._doneWriting)) : + new ValueTask(false); + } - // There are no items, so if we're done writing, there's never going to be data available. - if (parent._doneWriting != null) + // If we're able to use the singleton waiter, do so. + if (!cancellationToken.CanBeCanceled) + { + AsyncOperation singleton = _waiterSingleton; + if (singleton.TryOwnAndReset()) { - return parent._doneWriting != ChannelUtilities.s_doneWritingSentinel ? - Task.FromException(parent._doneWriting) : - ChannelUtilities.s_falseTask; + ChannelUtilities.QueueWaiter(ref parent._waitingReadersTail, singleton); + return new ValueTask(singleton); } - - // Queue the waiter - return ChannelUtilities.GetOrCreateWaiter(ref parent._waitingReaders, parent._runContinuationsAsynchronously, ct); } + + // Otherwise, create and queue a waiter. + var waiter = new AsyncOperation(parent._runContinuationsAsynchronously, cancellationToken); + ChannelUtilities.QueueWaiter(ref parent._waitingReadersTail, waiter); + return new ValueTask(waiter); } } @@ -190,8 +224,8 @@ public override bool TryComplete(Exception error) // At this point, _blockedReaders and _waitingReaders will not be mutated: // they're only mutated by readers while holding the lock, and only if _doneWriting is null. // freely manipulate _blockedReaders and _waitingReaders without any concurrency concerns. - ChannelUtilities.FailInteractors, T>(parent._blockedReaders, ChannelUtilities.CreateInvalidCompletionException(error)); - ChannelUtilities.WakeUpWaiters(ref parent._waitingReaders, result: false, error: error); + ChannelUtilities.FailOperations, T>(parent._blockedReaders, ChannelUtilities.CreateInvalidCompletionException(error)); + ChannelUtilities.WakeUpWaiters(ref parent._waitingReadersTail, result: false, error: error); // Successfully transitioned to completed. return true; @@ -202,8 +236,8 @@ public override bool TryWrite(T item) UnboundedChannel parent = _parent; while (true) { - ReaderInteractor blockedReader = null; - ReaderInteractor waitingReaders = null; + AsyncOperation blockedReader = null; + AsyncOperation waitingReadersTail = null; lock (parent.SyncObj) { // If writing has already been marked as done, fail the write. @@ -222,12 +256,12 @@ public override bool TryWrite(T item) if (parent._blockedReaders.IsEmpty) { parent._items.Enqueue(item); - waitingReaders = parent._waitingReaders; - if (waitingReaders == null) + waitingReadersTail = parent._waitingReadersTail; + if (waitingReadersTail == null) { return true; } - parent._waitingReaders = null; + parent._waitingReadersTail = null; } else { @@ -251,26 +285,26 @@ public override bool TryWrite(T item) // we could cause some spurious wake-ups here, if we tell a waiter there's // something available but all data has already been removed. It's a benign // race condition, though, as consumers already need to account for such things. - waitingReaders.Success(item: true); + ChannelUtilities.WakeUpWaiters(ref waitingReadersTail, result: true); return true; } } } - public override Task WaitToWriteAsync(CancellationToken cancellationToken) + public override ValueTask WaitToWriteAsync(CancellationToken cancellationToken) { Exception doneWriting = _parent._doneWriting; return - cancellationToken.IsCancellationRequested ? Task.FromCanceled(cancellationToken) : - doneWriting == null ? ChannelUtilities.s_trueTask : // unbounded writing can always be done if we haven't completed - doneWriting != ChannelUtilities.s_doneWritingSentinel ? Task.FromException(doneWriting) : - ChannelUtilities.s_falseTask; + cancellationToken.IsCancellationRequested ? new ValueTask(Task.FromCanceled(cancellationToken)) : + doneWriting == null ? new ValueTask(true) : // unbounded writing can always be done if we haven't completed + doneWriting != ChannelUtilities.s_doneWritingSentinel ? new ValueTask(Task.FromException(doneWriting)) : + new ValueTask(false); } - public override Task WriteAsync(T item, CancellationToken cancellationToken) => - cancellationToken.IsCancellationRequested ? Task.FromCanceled(cancellationToken) : - TryWrite(item) ? ChannelUtilities.s_trueTask : - Task.FromException(ChannelUtilities.CreateInvalidCompletionException(_parent._doneWriting)); + public override ValueTask WriteAsync(T item, CancellationToken cancellationToken) => + cancellationToken.IsCancellationRequested ? new ValueTask(Task.FromCanceled(cancellationToken)) : + TryWrite(item) ? default : + new ValueTask(Task.FromException(ChannelUtilities.CreateInvalidCompletionException(_parent._doneWriting))); /// Gets the number of items in the channel. This should only be used by the debugger. private int ItemsCountForDebugger => _parent._items.Count; @@ -293,11 +327,11 @@ private void AssertInvariants() if (_runContinuationsAsynchronously) { Debug.Assert(_blockedReaders.IsEmpty, "There's data available, so there shouldn't be any blocked readers."); - Debug.Assert(_waitingReaders == null, "There's data available, so there shouldn't be any waiting readers."); + Debug.Assert(_waitingReadersTail == null, "There's data available, so there shouldn't be any waiting readers."); } Debug.Assert(!_completion.Task.IsCompleted, "We still have data available, so shouldn't be completed."); } - if ((!_blockedReaders.IsEmpty || _waitingReaders != null) && _runContinuationsAsynchronously) + if ((!_blockedReaders.IsEmpty || _waitingReadersTail != null) && _runContinuationsAsynchronously) { Debug.Assert(_items.IsEmpty, "There are blocked/waiting readers, so there shouldn't be any data available."); } diff --git a/src/System.Threading.Channels/tests/BoundedChannelTests.cs b/src/System.Threading.Channels/tests/BoundedChannelTests.cs index fc264e8d8635..903d85c2a085 100644 --- a/src/System.Threading.Channels/tests/BoundedChannelTests.cs +++ b/src/System.Threading.Channels/tests/BoundedChannelTests.cs @@ -9,11 +9,11 @@ namespace System.Threading.Channels.Tests { public class BoundedChannelTests : ChannelTestBase { - protected override Channel CreateChannel() => Channel.CreateBounded(1); - protected override Channel CreateFullChannel() + protected override Channel CreateChannel() => Channel.CreateBounded(new BoundedChannelOptions(1) { AllowSynchronousContinuations = AllowSynchronousContinuations }); + protected override Channel CreateFullChannel() { - var c = Channel.CreateBounded(1); - c.Writer.WriteAsync(42).Wait(); + var c = Channel.CreateBounded(new BoundedChannelOptions(1) { AllowSynchronousContinuations = AllowSynchronousContinuations }); + c.Writer.WriteAsync(default).AsTask().Wait(); return c; } @@ -218,16 +218,16 @@ public void WriteAsync_TryRead_Many_Ignore(int bufferedCapacity) public async Task CancelPendingWrite_Reading_DataTransferredFromCorrectWriter() { var c = Channel.CreateBounded(1); - Assert.Equal(TaskStatus.RanToCompletion, c.Writer.WriteAsync(42).Status); + Assert.True(c.Writer.WriteAsync(42).IsCompletedSuccessfully); var cts = new CancellationTokenSource(); - Task write1 = c.Writer.WriteAsync(43, cts.Token); + Task write1 = c.Writer.WriteAsync(43, cts.Token).AsTask(); Assert.Equal(TaskStatus.WaitingForActivation, write1.Status); cts.Cancel(); - Task write2 = c.Writer.WriteAsync(44); + Task write2 = c.Writer.WriteAsync(44).AsTask(); Assert.Equal(42, await c.Reader.ReadAsync()); Assert.Equal(44, await c.Reader.ReadAsync()); @@ -342,10 +342,10 @@ public async Task WaitToWriteAsync_AfterFullThenRead_ReturnsTrue() var c = Channel.CreateBounded(1); Assert.True(c.Writer.TryWrite(1)); - Task write1 = c.Writer.WaitToWriteAsync(); + Task write1 = c.Writer.WaitToWriteAsync().AsTask(); Assert.False(write1.IsCompleted); - Task write2 = c.Writer.WaitToWriteAsync(); + Task write2 = c.Writer.WaitToWriteAsync().AsTask(); Assert.False(write2.IsCompleted); Assert.Equal(1, await c.Reader.ReadAsync()); @@ -362,12 +362,12 @@ public void AllowSynchronousContinuations_WaitToReadAsync_ContinuationsInvokedAc var c = Channel.CreateBounded(new BoundedChannelOptions(1) { AllowSynchronousContinuations = allowSynchronousContinuations }); int expectedId = Environment.CurrentManagedThreadId; - Task r = c.Reader.WaitToReadAsync().ContinueWith(_ => + Task r = c.Reader.WaitToReadAsync().AsTask().ContinueWith(_ => { Assert.Equal(allowSynchronousContinuations, expectedId == Environment.CurrentManagedThreadId); }, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); - Assert.Equal(TaskStatus.RanToCompletion, c.Writer.WriteAsync(42).Status); + Assert.True(c.Writer.WriteAsync(42).IsCompletedSuccessfully); ((IAsyncResult)r).AsyncWaitHandle.WaitOne(); // avoid inlining the continuation r.GetAwaiter().GetResult(); } @@ -391,13 +391,13 @@ public void AllowSynchronousContinuations_CompletionTask_ContinuationsInvokedAcc } [Fact] - public void TryWrite_NoBlockedReaders_WaitingReader_WaiterNotified() + public async Task TryWrite_NoBlockedReaders_WaitingReader_WaiterNotified() { Channel c = CreateChannel(); - Task r = c.Reader.WaitToReadAsync(); + Task r = c.Reader.WaitToReadAsync().AsTask(); Assert.True(c.Writer.TryWrite(42)); - AssertSynchronousTrue(r); + Assert.True(await r); } } } diff --git a/src/System.Threading.Channels/tests/ChannelTestBase.cs b/src/System.Threading.Channels/tests/ChannelTestBase.cs index 4719d938b22f..11879918004b 100644 --- a/src/System.Threading.Channels/tests/ChannelTestBase.cs +++ b/src/System.Threading.Channels/tests/ChannelTestBase.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using System.Runtime.CompilerServices; using System.Threading.Tasks; using Xunit; @@ -12,9 +13,13 @@ namespace System.Threading.Channels.Tests { public abstract class ChannelTestBase : TestBase { - protected abstract Channel CreateChannel(); - protected abstract Channel CreateFullChannel(); + protected Channel CreateChannel() => CreateChannel(); + protected abstract Channel CreateChannel(); + protected Channel CreateFullChannel() => CreateFullChannel(); + protected abstract Channel CreateFullChannel(); + + protected virtual bool AllowSynchronousContinuations => false; protected virtual bool RequiresSingleReader => false; protected virtual bool RequiresSingleWriter => false; protected virtual bool BuffersItems => true; @@ -79,10 +84,10 @@ public async Task Complete_AfterEmpty_WaitingReader_TriggersCompletion() public async Task Complete_BeforeEmpty_WaitingReaders_TriggersCompletion() { Channel c = CreateChannel(); - Task read = c.Reader.ReadAsync().AsTask(); + ValueTask read = c.Reader.ReadAsync(); c.Writer.Complete(); await c.Reader.Completion; - await Assert.ThrowsAnyAsync(() => read); + await Assert.ThrowsAnyAsync(async () => await read); } [Fact] @@ -247,18 +252,18 @@ public void ManyProducerConsumer_ConcurrentReadWrite_Success(int numReaders, int public void WaitToReadAsync_DataAvailableBefore_CompletesSynchronously() { Channel c = CreateChannel(); - Task write = c.Writer.WriteAsync(42); - Task read = c.Reader.WaitToReadAsync(); - Assert.Equal(TaskStatus.RanToCompletion, read.Status); + ValueTask write = c.Writer.WriteAsync(42); + ValueTask read = c.Reader.WaitToReadAsync(); + Assert.True(read.IsCompletedSuccessfully); } [Fact] public void WaitToReadAsync_DataAvailableAfter_CompletesAsynchronously() { Channel c = CreateChannel(); - Task read = c.Reader.WaitToReadAsync(); + ValueTask read = c.Reader.WaitToReadAsync(); Assert.False(read.IsCompleted); - Task write = c.Writer.WriteAsync(42); + ValueTask write = c.Writer.WriteAsync(42); Assert.True(read.Result); } @@ -267,8 +272,8 @@ public void WaitToReadAsync_AfterComplete_SynchronouslyCompletes() { Channel c = CreateChannel(); c.Writer.Complete(); - Task read = c.Reader.WaitToReadAsync(); - Assert.Equal(TaskStatus.RanToCompletion, read.Status); + ValueTask read = c.Reader.WaitToReadAsync(); + Assert.True(read.IsCompletedSuccessfully); Assert.False(read.Result); } @@ -276,7 +281,7 @@ public void WaitToReadAsync_AfterComplete_SynchronouslyCompletes() public void WaitToReadAsync_BeforeComplete_AsynchronouslyCompletes() { Channel c = CreateChannel(); - Task read = c.Reader.WaitToReadAsync(); + ValueTask read = c.Reader.WaitToReadAsync(); Assert.False(read.IsCompleted); c.Writer.Complete(); Assert.False(read.Result); @@ -287,8 +292,8 @@ public void WaitToWriteAsync_AfterComplete_SynchronouslyCompletes() { Channel c = CreateChannel(); c.Writer.Complete(); - Task write = c.Writer.WaitToWriteAsync(); - Assert.Equal(TaskStatus.RanToCompletion, write.Status); + ValueTask write = c.Writer.WaitToWriteAsync(); + Assert.True(write.IsCompletedSuccessfully); Assert.False(write.Result); } @@ -301,8 +306,8 @@ public void WaitToWriteAsync_EmptyChannel_SynchronouslyCompletes() } Channel c = CreateChannel(); - Task write = c.Writer.WaitToWriteAsync(); - Assert.Equal(TaskStatus.RanToCompletion, write.Status); + ValueTask write = c.Writer.WaitToWriteAsync(); + Assert.True(write.IsCompletedSuccessfully); Assert.True(write.Result); } @@ -316,7 +321,7 @@ public async Task WaitToWriteAsync_ManyConcurrent_SatisifedByReaders() Channel c = CreateChannel(); - Task[] writers = Enumerable.Range(0, 100).Select(_ => c.Writer.WaitToWriteAsync()).ToArray(); + Task[] writers = Enumerable.Range(0, 100).Select(_ => c.Writer.WaitToWriteAsync().AsTask()).ToArray(); Task[] readers = Enumerable.Range(0, 100).Select(_ => c.Reader.ReadAsync().AsTask()).ToArray(); await Task.WhenAll(writers); @@ -334,7 +339,7 @@ public void WaitToWriteAsync_BlockedReader_ReturnsTrue() public void TryRead_DataAvailable_Success() { Channel c = CreateChannel(); - Task write = c.Writer.WriteAsync(42); + ValueTask write = c.Writer.WriteAsync(42); Assert.True(c.Reader.TryRead(out int result)); Assert.Equal(42, result); } @@ -360,7 +365,7 @@ public async Task WriteAsync_AfterComplete_ThrowsException() { Channel c = CreateChannel(); c.Writer.Complete(); - await Assert.ThrowsAnyAsync(() => c.Writer.WriteAsync(42)); + await Assert.ThrowsAnyAsync(async () => await c.Writer.WriteAsync(42)); } [Fact] @@ -393,10 +398,10 @@ public async Task Complete_WithException_PropagatesToExistingWriter() Channel c = CreateFullChannel(); if (c != null) { - Task write = c.Writer.WriteAsync(42); + ValueTask write = c.Writer.WriteAsync(42); var exc = new FormatException(); c.Writer.Complete(exc); - Assert.Same(exc, (await Assert.ThrowsAsync(() => write)).InnerException); + Assert.Same(exc, (await Assert.ThrowsAsync(async () => await write)).InnerException); } } @@ -406,18 +411,18 @@ public async Task Complete_WithException_PropagatesToNewWriter() Channel c = CreateChannel(); var exc = new FormatException(); c.Writer.Complete(exc); - Task write = c.Writer.WriteAsync(42); - Assert.Same(exc, (await Assert.ThrowsAsync(() => write)).InnerException); + ValueTask write = c.Writer.WriteAsync(42); + Assert.Same(exc, (await Assert.ThrowsAsync(async () => await write)).InnerException); } [Fact] public async Task Complete_WithException_PropagatesToExistingWaitingReader() { Channel c = CreateChannel(); - Task read = c.Reader.WaitToReadAsync(); + ValueTask read = c.Reader.WaitToReadAsync(); var exc = new FormatException(); c.Writer.Complete(exc); - await Assert.ThrowsAsync(() => read); + await Assert.ThrowsAsync(async () => await read); } [Fact] @@ -426,8 +431,8 @@ public async Task Complete_WithException_PropagatesToNewWaitingReader() Channel c = CreateChannel(); var exc = new FormatException(); c.Writer.Complete(exc); - Task read = c.Reader.WaitToReadAsync(); - await Assert.ThrowsAsync(() => read); + ValueTask read = c.Reader.WaitToReadAsync(); + await Assert.ThrowsAsync(async () => await read); } [Fact] @@ -436,8 +441,8 @@ public async Task Complete_WithException_PropagatesToNewWaitingWriter() Channel c = CreateChannel(); var exc = new FormatException(); c.Writer.Complete(exc); - Task write = c.Writer.WaitToWriteAsync(); - await Assert.ThrowsAsync(() => write); + ValueTask write = c.Writer.WaitToWriteAsync(); + await Assert.ThrowsAsync(async () => await write); } [Theory] @@ -454,7 +459,7 @@ public void ManyWriteAsync_ThenManyTryRead_Success(int readMode) const int NumItems = 2000; - Task[] writers = new Task[NumItems]; + ValueTask[] writers = new ValueTask[NumItems]; for (int i = 0; i < writers.Length; i++) { writers[i] = c.Writer.WriteAsync(i); @@ -476,11 +481,11 @@ public void Precancellation_Writing_ReturnsImmediately() { Channel c = CreateChannel(); - Task writeTask = c.Writer.WriteAsync(42, new CancellationToken(true)); - Assert.Equal(TaskStatus.Canceled, writeTask.Status); + ValueTask writeTask = c.Writer.WriteAsync(42, new CancellationToken(true)); + Assert.True(writeTask.IsCanceled); - Task waitTask = c.Writer.WaitToWriteAsync(new CancellationToken(true)); - Assert.Equal(TaskStatus.Canceled, waitTask.Status); + ValueTask waitTask = c.Writer.WaitToWriteAsync(new CancellationToken(true)); + Assert.True(writeTask.IsCanceled); } [Fact] @@ -502,8 +507,8 @@ public void Precancellation_WaitToReadAsync_ReturnsImmediately(bool dataAvailabl Assert.True(c.Writer.TryWrite(42)); } - Task writeTask = c.Reader.WaitToReadAsync(new CancellationToken(true)); - Assert.Equal(TaskStatus.Canceled, writeTask.Status); + ValueTask writeTask = c.Reader.WaitToReadAsync(new CancellationToken(true)); + Assert.True(writeTask.IsCanceled); } [Theory] @@ -514,10 +519,10 @@ public async Task WaitToReadAsync_DataWritten_CompletesSuccessfully(bool cancela Channel c = CreateChannel(); CancellationToken token = cancelable ? new CancellationTokenSource().Token : default; - Task read = c.Reader.WaitToReadAsync(token); + ValueTask read = c.Reader.WaitToReadAsync(token); Assert.False(read.IsCompleted); - Task write = c.Writer.WriteAsync(42, token); + ValueTask write = c.Writer.WriteAsync(42, token); Assert.True(await read); } @@ -528,10 +533,10 @@ public async Task WaitToReadAsync_NoDataWritten_Canceled_CompletesAsCanceled() Channel c = CreateChannel(); var cts = new CancellationTokenSource(); - Task read = c.Reader.WaitToReadAsync(cts.Token); + ValueTask read = c.Reader.WaitToReadAsync(cts.Token); Assert.False(read.IsCompleted); cts.Cancel(); - await Assert.ThrowsAnyAsync(() => read); + await Assert.ThrowsAnyAsync(async () => await read); } [Fact] @@ -542,7 +547,7 @@ public async Task ReadAsync_ThenWriteAsync_Succeeds() ValueTask r = c.Reader.ReadAsync(); Assert.False(r.IsCompleted); - Task w = c.Writer.WriteAsync(42); + ValueTask w = c.Writer.WriteAsync(42); AssertSynchronousSuccess(w); Assert.Equal(42, await r); @@ -553,10 +558,10 @@ public async Task WriteAsync_ReadAsync_Succeeds() { Channel c = CreateChannel(); - Task w = c.Writer.WriteAsync(42); + ValueTask w = c.Writer.WriteAsync(42); ValueTask r = c.Reader.ReadAsync(); - await Task.WhenAll(w, r.AsTask()); + await Task.WhenAll(w.AsTask(), r.AsTask()); Assert.Equal(42, await r); } @@ -620,8 +625,8 @@ public async Task ReadAsync_TryWrite_ManyConcurrentReaders_SerializedWriters_Suc Channel c = CreateChannel(); const int Items = 100; - ValueTask[] readers = (from i in Enumerable.Range(0, Items) select c.Reader.ReadAsync()).ToArray(); - var remainingReaders = new List>(readers.Select(r => r.AsTask())); + Task[] readers = (from i in Enumerable.Range(0, Items) select c.Reader.ReadAsync().AsTask()).ToArray(); + var remainingReaders = new List>(readers); for (int i = 0; i < Items; i++) { @@ -631,7 +636,7 @@ public async Task ReadAsync_TryWrite_ManyConcurrentReaders_SerializedWriters_Suc remainingReaders.Remove(r); } - Assert.Equal((Items * (Items - 1)) / 2, Enumerable.Sum(await Task.WhenAll(readers.Select(r => r.AsTask())))); + Assert.Equal((Items * (Items - 1)) / 2, Enumerable.Sum(await Task.WhenAll(readers))); } [Fact] @@ -697,5 +702,284 @@ public async Task ReadAsync_Canceled_WriteAsyncCompletesNextReader() Assert.Equal(i, await r); } } + + [Fact] + public async Task ReadAsync_ConsecutiveReadsSucceed() + { + Channel c = CreateChannel(); + for (int i = 0; i < 5; i++) + { + ValueTask r = c.Reader.ReadAsync(); + await c.Writer.WriteAsync(i); + Assert.Equal(i, await r); + } + } + + [Fact] + public async Task WaitToReadAsync_ConsecutiveReadsSucceed() + { + Channel c = CreateChannel(); + for (int i = 0; i < 5; i++) + { + ValueTask r = c.Reader.WaitToReadAsync(); + await c.Writer.WriteAsync(i); + Assert.True(await r); + Assert.True(c.Reader.TryRead(out int item)); + Assert.Equal(i, item); + } + } + + public static IEnumerable Reader_ContinuesOnCurrentContextIfDesired_MemberData() => + from readOrWait in new[] { true, false } + from completeBeforeOnCompleted in new[] { true, false } + from flowExecutionContext in new[] { true, false } + from continueOnCapturedContext in new bool?[] { null, false, true } + select new object[] { readOrWait, completeBeforeOnCompleted, flowExecutionContext, continueOnCapturedContext }; + + [Theory] + [MemberData(nameof(Reader_ContinuesOnCurrentContextIfDesired_MemberData))] + public async Task Reader_ContinuesOnCurrentSynchronizationContextIfDesired( + bool readOrWait, bool completeBeforeOnCompleted, bool flowExecutionContext, bool? continueOnCapturedContext) + { + if (AllowSynchronousContinuations) + { + return; + } + + await Task.Run(async () => + { + Assert.Null(SynchronizationContext.Current); + + Channel c = CreateChannel(); + ValueTask vt = readOrWait ? + c.Reader.ReadAsync() : + c.Reader.WaitToReadAsync(); + + var continuationRan = new TaskCompletionSource(); + var asyncLocal = new AsyncLocal(); + bool schedulerWasFlowed = false; + bool executionContextWasFlowed = false; + Action continuation = () => + { + schedulerWasFlowed = SynchronizationContext.Current is CustomSynchronizationContext; + executionContextWasFlowed = 42 == asyncLocal.Value; + continuationRan.SetResult(true); + }; + + if (completeBeforeOnCompleted) + { + Assert.False(vt.IsCompleted); + Assert.False(vt.IsCompletedSuccessfully); + c.Writer.TryWrite(true); + } + + SynchronizationContext.SetSynchronizationContext(new CustomSynchronizationContext()); + asyncLocal.Value = 42; + switch (continueOnCapturedContext) + { + case null: + if (flowExecutionContext) + { + vt.GetAwaiter().OnCompleted(continuation); + } + else + { + vt.GetAwaiter().UnsafeOnCompleted(continuation); + } + break; + default: + if (flowExecutionContext) + { + vt.ConfigureAwait(continueOnCapturedContext.Value).GetAwaiter().OnCompleted(continuation); + } + else + { + vt.ConfigureAwait(continueOnCapturedContext.Value).GetAwaiter().UnsafeOnCompleted(continuation); + } + break; + } + asyncLocal.Value = 0; + SynchronizationContext.SetSynchronizationContext(null); + + if (!completeBeforeOnCompleted) + { + Assert.False(vt.IsCompleted); + Assert.False(vt.IsCompletedSuccessfully); + c.Writer.TryWrite(true); + } + + await continuationRan.Task; + Assert.True(vt.IsCompleted); + Assert.True(vt.IsCompletedSuccessfully); + + Assert.Equal(continueOnCapturedContext != false, schedulerWasFlowed); + if (completeBeforeOnCompleted) // OnCompleted will simply queue using a mechanism that happens to flow + { + Assert.True(executionContextWasFlowed); + } + else + { + Assert.Equal(flowExecutionContext, executionContextWasFlowed); + } + }); + } + + [Theory] + [MemberData(nameof(Reader_ContinuesOnCurrentContextIfDesired_MemberData))] + public async Task Reader_ContinuesOnCurrentTaskSchedulerIfDesired( + bool readOrWait, bool completeBeforeOnCompleted, bool flowExecutionContext, bool? continueOnCapturedContext) + { + if (AllowSynchronousContinuations) + { + return; + } + + await Task.Run(async () => + { + Assert.Null(SynchronizationContext.Current); + + Channel c = CreateChannel(); + ValueTask vt = readOrWait ? + c.Reader.ReadAsync() : + c.Reader.WaitToReadAsync(); + + var continuationRan = new TaskCompletionSource(); + var asyncLocal = new AsyncLocal(); + bool schedulerWasFlowed = false; + bool executionContextWasFlowed = false; + Action continuation = () => + { + schedulerWasFlowed = TaskScheduler.Current is CustomTaskScheduler; + executionContextWasFlowed = 42 == asyncLocal.Value; + continuationRan.SetResult(true); + }; + + if (completeBeforeOnCompleted) + { + Assert.False(vt.IsCompleted); + Assert.False(vt.IsCompletedSuccessfully); + c.Writer.TryWrite(true); + } + + await Task.Factory.StartNew(() => + { + Assert.IsType(TaskScheduler.Current); + asyncLocal.Value = 42; + switch (continueOnCapturedContext) + { + case null: + if (flowExecutionContext) + { + vt.GetAwaiter().OnCompleted(continuation); + } + else + { + vt.GetAwaiter().UnsafeOnCompleted(continuation); + } + break; + default: + if (flowExecutionContext) + { + vt.ConfigureAwait(continueOnCapturedContext.Value).GetAwaiter().OnCompleted(continuation); + } + else + { + vt.ConfigureAwait(continueOnCapturedContext.Value).GetAwaiter().UnsafeOnCompleted(continuation); + } + break; + } + asyncLocal.Value = 0; + }, CancellationToken.None, TaskCreationOptions.None, new CustomTaskScheduler()); + + if (!completeBeforeOnCompleted) + { + Assert.False(vt.IsCompleted); + Assert.False(vt.IsCompletedSuccessfully); + c.Writer.TryWrite(true); + } + + await continuationRan.Task; + Assert.True(vt.IsCompleted); + Assert.True(vt.IsCompletedSuccessfully); + + Assert.Equal(continueOnCapturedContext != false, schedulerWasFlowed); + if (completeBeforeOnCompleted) // OnCompleted will simply queue using a mechanism that happens to flow + { + Assert.True(executionContextWasFlowed); + } + else + { + Assert.Equal(flowExecutionContext, executionContextWasFlowed); + } + }); + } + + [Fact] + public void ValueTask_GetResultWhenNotCompleted_Throws() + { + ValueTaskAwaiter readVt = CreateChannel().Reader.ReadAsync().GetAwaiter(); + Assert.Throws(() => readVt.GetResult()); + + ValueTaskAwaiter waitReadVt = CreateChannel().Reader.WaitToReadAsync().GetAwaiter(); + Assert.Throws(() => waitReadVt.GetResult()); + + if (CreateFullChannel() != null) + { + ValueTaskAwaiter writeVt = CreateFullChannel().Writer.WriteAsync(42).GetAwaiter(); + Assert.Throws(() => writeVt.GetResult()); + + ValueTaskAwaiter waitWriteVt = CreateFullChannel().Writer.WaitToWriteAsync().GetAwaiter(); + Assert.Throws(() => waitWriteVt.GetResult()); + } + } + + [Fact] + public void ValueTask_MultipleContinuations_Throws() + { + ValueTaskAwaiter readVt = CreateChannel().Reader.ReadAsync().GetAwaiter(); + readVt.OnCompleted(() => { }); + Assert.Throws(() => readVt.OnCompleted(() => { })); + + ValueTaskAwaiter waitReadVt = CreateChannel().Reader.WaitToReadAsync().GetAwaiter(); + waitReadVt.OnCompleted(() => { }); + Assert.Throws(() => waitReadVt.OnCompleted(() => { })); + + if (CreateFullChannel() != null) + { + ValueTaskAwaiter writeVt = CreateFullChannel().Writer.WriteAsync(42).GetAwaiter(); + writeVt.OnCompleted(() => { }); + Assert.Throws(() => writeVt.OnCompleted(() => { })); + + ValueTaskAwaiter waitWriteVt = CreateFullChannel().Writer.WaitToWriteAsync().GetAwaiter(); + waitWriteVt.OnCompleted(() => { }); + Assert.Throws(() => waitWriteVt.OnCompleted(() => { })); + } + } + + private sealed class CustomSynchronizationContext : SynchronizationContext + { + public override void Post(SendOrPostCallback d, object state) + { + ThreadPool.QueueUserWorkItem(delegate + { + SetSynchronizationContext(this); + try + { + d(state); + } + finally + { + SetSynchronizationContext(null); + } + }, null); + } + } + + private sealed class CustomTaskScheduler : TaskScheduler + { + protected override void QueueTask(Task task) => ThreadPool.QueueUserWorkItem(_ => TryExecuteTask(task)); + protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued) => false; + protected override IEnumerable GetScheduledTasks() => null; + } } } diff --git a/src/System.Threading.Channels/tests/ChannelTests.cs b/src/System.Threading.Channels/tests/ChannelTests.cs index be8b7fb5d6e2..c9147750fcb4 100644 --- a/src/System.Threading.Channels/tests/ChannelTests.cs +++ b/src/System.Threading.Channels/tests/ChannelTests.cs @@ -92,7 +92,7 @@ public async Task DefaultWriteAsync_UsesWaitToWriteAsyncAndTryWrite() { var c = new TestChannelWriter(10); Assert.False(c.TryComplete()); - Assert.Equal(TaskStatus.Canceled, c.WriteAsync(42, new CancellationToken(true)).Status); + Assert.Equal(TaskStatus.Canceled, c.WriteAsync(42, new CancellationToken(true)).AsTask().Status); int count = 0; try @@ -117,9 +117,9 @@ public void DefaultCompletion_NeverCompletes() public async Task DefaultWriteAsync_CatchesTryWriteExceptions() { var w = new TryWriteThrowingWriter(); - Task t = w.WriteAsync(42); - Assert.Equal(TaskStatus.Faulted, t.Status); - await Assert.ThrowsAsync(() => t); + ValueTask t = w.WriteAsync(42); + Assert.True(t.IsFaulted); + await Assert.ThrowsAsync(async () => await t); } [Fact] @@ -181,7 +181,7 @@ public override bool TryRead(out T item) return _reader.TryRead(out item); } - public override Task WaitToReadAsync(CancellationToken cancellationToken) + public override ValueTask WaitToReadAsync(CancellationToken cancellationToken) { return _reader.WaitToReadAsync(cancellationToken); } @@ -207,10 +207,10 @@ private sealed class TestChannelWriter : ChannelWriter public override bool TryWrite(T item) => _rand.Next(0, 2) == 0 && _count++ < _max; // succeed if we're under our limit, and add random failures - public override Task WaitToWriteAsync(CancellationToken cancellationToken) => - _count >= _max ? Task.FromResult(false) : - _rand.Next(0, 2) == 0 ? Task.Delay(1).ContinueWith(_ => true) : // randomly introduce delays - Task.FromResult(true); + public override ValueTask WaitToWriteAsync(CancellationToken cancellationToken) => + _count >= _max ? new ValueTask(Task.FromResult(false)) : + _rand.Next(0, 2) == 0 ? new ValueTask(Task.Delay(1).ContinueWith(_ => true)) : // randomly introduce delays + new ValueTask(Task.FromResult(true)); } private sealed class TestChannelReader : ChannelReader @@ -246,22 +246,22 @@ public override bool TryRead(out T item) return true; } - public override Task WaitToReadAsync(CancellationToken cancellationToken) => + public override ValueTask WaitToReadAsync(CancellationToken cancellationToken) => new ValueTask( _closed ? Task.FromResult(false) : _rand.Next(0, 2) == 0 ? Task.Delay(1).ContinueWith(_ => true) : // randomly introduce delays - Task.FromResult(true); + Task.FromResult(true)); } private sealed class TryWriteThrowingWriter : ChannelWriter { public override bool TryWrite(T item) => throw new FormatException(); - public override Task WaitToWriteAsync(CancellationToken cancellationToken = default) => throw new InvalidDataException(); + public override ValueTask WaitToWriteAsync(CancellationToken cancellationToken = default) => throw new InvalidDataException(); } private sealed class TryReadThrowingReader : ChannelReader { public override bool TryRead(out T item) => throw new FieldAccessException(); - public override Task WaitToReadAsync(CancellationToken cancellationToken = default) => throw new DriveNotFoundException(); + public override ValueTask WaitToReadAsync(CancellationToken cancellationToken = default) => throw new DriveNotFoundException(); } private sealed class CanReadFalseStream : MemoryStream diff --git a/src/System.Threading.Channels/tests/Performance/Perf.Channel.cs b/src/System.Threading.Channels/tests/Performance/Perf.Channel.cs index 0c9825f2142a..78b7ab1f6c18 100644 --- a/src/System.Threading.Channels/tests/Performance/Perf.Channel.cs +++ b/src/System.Threading.Channels/tests/Performance/Perf.Channel.cs @@ -7,23 +7,25 @@ namespace System.Threading.Channels.Tests { - public sealed class Perf_UnboundedChannelTests : Perf_BufferingTests + public sealed class UnboundedChannelPerfTests : PerfTests { public override Channel CreateChannel() => Channel.CreateUnbounded(); } - public sealed class Perf_UnboundedSpscChannelTests : Perf_BufferingTests + public sealed class SpscUnboundedChannelPerfTests : PerfTests { public override Channel CreateChannel() => Channel.CreateUnbounded(new UnboundedChannelOptions { SingleReader = true, SingleWriter = true }); } - public sealed class Perf_BoundedChannelTests : Perf_BufferingTests + public sealed class BoundedChannelPerfTests : PerfTests { public override Channel CreateChannel() => Channel.CreateBounded(10); } - public abstract class Perf_BufferingTests : Perf_Tests + public abstract class PerfTests { + public abstract Channel CreateChannel(); + [Benchmark(InnerIterationCount = 1_000_000), MeasureGCAllocations] public void TryWriteThenTryRead() { @@ -87,16 +89,12 @@ public async Task ReadAsyncThenWriteAsync() } } } - } - - public abstract class Perf_Tests - { - public abstract Channel CreateChannel(); [Benchmark(InnerIterationCount = 1_000_000), MeasureGCAllocations] - public async Task ConcurrentReadAsyncWriteAsync() + public async Task PingPong() { - Channel channel = CreateChannel(); + Channel channel1 = CreateChannel(); + Channel channel2 = CreateChannel(); foreach (BenchmarkIteration iteration in Benchmark.Iterations) { @@ -106,17 +104,21 @@ public async Task ConcurrentReadAsyncWriteAsync() await Task.WhenAll( Task.Run(async () => { - ChannelReader reader = channel.Reader; + ChannelReader reader = channel1.Reader; + ChannelWriter writer = channel2.Writer; for (int i = 0; i < iters; i++) { + await writer.WriteAsync(i); await reader.ReadAsync(); } }), Task.Run(async () => { - ChannelWriter writer = channel.Writer; + ChannelWriter writer = channel1.Writer; + ChannelReader reader = channel2.Reader; for (int i = 0; i < iters; i++) { + await reader.ReadAsync(); await writer.WriteAsync(i); } })); diff --git a/src/System.Threading.Channels/tests/System.Threading.Channels.Tests.csproj b/src/System.Threading.Channels/tests/System.Threading.Channels.Tests.csproj index 7f837843745e..8f91bde9230b 100644 --- a/src/System.Threading.Channels/tests/System.Threading.Channels.Tests.csproj +++ b/src/System.Threading.Channels/tests/System.Threading.Channels.Tests.csproj @@ -19,5 +19,8 @@ Common\System\Diagnostics\DebuggerAttributes.cs + + + \ No newline at end of file diff --git a/src/System.Threading.Channels/tests/TestBase.cs b/src/System.Threading.Channels/tests/TestBase.cs index d3af1ba7eb41..9a84a8ecd69b 100644 --- a/src/System.Threading.Channels/tests/TestBase.cs +++ b/src/System.Threading.Channels/tests/TestBase.cs @@ -24,6 +24,8 @@ protected async Task AssertCanceled(Task task, CancellationToken token) AssertSynchronouslyCanceled(task, token); } + protected void AssertSynchronousSuccess(ValueTask task) => Assert.True(task.IsCompletedSuccessfully); + protected void AssertSynchronousSuccess(ValueTask task) => Assert.True(task.IsCompletedSuccessfully); protected void AssertSynchronousSuccess(Task task) => Assert.Equal(TaskStatus.RanToCompletion, task.Status); protected void AssertSynchronousTrue(Task task) @@ -32,6 +34,12 @@ protected void AssertSynchronousTrue(Task task) Assert.True(task.Result); } + protected void AssertSynchronousTrue(ValueTask task) + { + AssertSynchronousSuccess(task); + Assert.True(task.Result); + } + internal sealed class DelegateObserver : IObserver { public Action OnNextDelegate = null; diff --git a/src/System.Threading.Channels/tests/UnboundedChannelTests.cs b/src/System.Threading.Channels/tests/UnboundedChannelTests.cs index 71d8e91f43f4..dec1e50f3ee1 100644 --- a/src/System.Threading.Channels/tests/UnboundedChannelTests.cs +++ b/src/System.Threading.Channels/tests/UnboundedChannelTests.cs @@ -10,14 +10,13 @@ namespace System.Threading.Channels.Tests { public abstract class UnboundedChannelTests : ChannelTestBase { - protected abstract bool AllowSynchronousContinuations { get; } - protected override Channel CreateChannel() => Channel.CreateUnbounded( + protected override Channel CreateChannel() => Channel.CreateUnbounded( new UnboundedChannelOptions { SingleReader = RequiresSingleReader, AllowSynchronousContinuations = AllowSynchronousContinuations }); - protected override Channel CreateFullChannel() => null; + protected override Channel CreateFullChannel() => null; [Fact] public async Task Complete_BeforeEmpty_NoWaiters_TriggersCompletion() @@ -107,12 +106,12 @@ public void AllowSynchronousContinuations_WaitToReadAsync_ContinuationsInvokedAc Channel c = CreateChannel(); int expectedId = Environment.CurrentManagedThreadId; - Task r = c.Reader.WaitToReadAsync().ContinueWith(_ => + Task r = c.Reader.WaitToReadAsync().AsTask().ContinueWith(_ => { Assert.Equal(AllowSynchronousContinuations, expectedId == Environment.CurrentManagedThreadId); }, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); - Assert.Equal(TaskStatus.RanToCompletion, c.Writer.WriteAsync(42).Status); + Assert.True(c.Writer.WriteAsync(42).IsCompletedSuccessfully); ((IAsyncResult)r).AsyncWaitHandle.WaitOne(); // avoid inlining the continuation r.GetAwaiter().GetResult(); } @@ -155,13 +154,24 @@ public void ValidateInternalDebuggerAttributes() public async Task MultipleWaiters_CancelsPreviousWaiter() { Channel c = CreateChannel(); - Task t1 = c.Reader.WaitToReadAsync(); - Task t2 = c.Reader.WaitToReadAsync(); - await Assert.ThrowsAnyAsync(() => t1); + ValueTask t1 = c.Reader.WaitToReadAsync(); + ValueTask t2 = c.Reader.WaitToReadAsync(); + await Assert.ThrowsAnyAsync(async () => await t1); Assert.True(c.Writer.TryWrite(42)); Assert.True(await t2); } + [Fact] + public async Task MultipleReaders_CancelsPreviousReader() + { + Channel c = CreateChannel(); + ValueTask t1 = c.Reader.ReadAsync(); + ValueTask t2 = c.Reader.ReadAsync(); + await Assert.ThrowsAnyAsync(async () => await t1); + Assert.True(c.Writer.TryWrite(42)); + Assert.Equal(42, await t2); + } + [Fact] public void Stress_TryWrite_TryRead() { diff --git a/src/System.Threading.Tasks.Extensions/ref/System.Threading.Tasks.Extensions.cs b/src/System.Threading.Tasks.Extensions/ref/System.Threading.Tasks.Extensions.cs index 577fb2ed3c85..36f4afcbf63f 100644 --- a/src/System.Threading.Tasks.Extensions/ref/System.Threading.Tasks.Extensions.cs +++ b/src/System.Threading.Tasks.Extensions/ref/System.Threading.Tasks.Extensions.cs @@ -13,6 +13,18 @@ public sealed partial class AsyncMethodBuilderAttribute : System.Attribute public AsyncMethodBuilderAttribute(System.Type builderType) { } public System.Type BuilderType { get { throw null; } } } + public partial struct AsyncValueTaskMethodBuilder + { + private object _dummy; + public System.Threading.Tasks.ValueTask Task { get { throw null; } } + public void AwaitOnCompleted(ref TAwaiter awaiter, ref TStateMachine stateMachine) where TAwaiter : System.Runtime.CompilerServices.INotifyCompletion where TStateMachine : System.Runtime.CompilerServices.IAsyncStateMachine { } + public void AwaitUnsafeOnCompleted(ref TAwaiter awaiter, ref TStateMachine stateMachine) where TAwaiter : System.Runtime.CompilerServices.ICriticalNotifyCompletion where TStateMachine : System.Runtime.CompilerServices.IAsyncStateMachine { } + public static System.Runtime.CompilerServices.AsyncValueTaskMethodBuilder Create() { throw null; } + public void SetException(System.Exception exception) { } + public void SetResult() { } + public void SetStateMachine(System.Runtime.CompilerServices.IAsyncStateMachine stateMachine) { } + public void Start(ref TStateMachine stateMachine) where TStateMachine : System.Runtime.CompilerServices.IAsyncStateMachine { } + } public partial struct AsyncValueTaskMethodBuilder { private TResult _result; @@ -25,6 +37,19 @@ public void SetResult(TResult result) { } public void SetStateMachine(System.Runtime.CompilerServices.IAsyncStateMachine stateMachine) { } public void Start(ref TStateMachine stateMachine) where TStateMachine : System.Runtime.CompilerServices.IAsyncStateMachine { } } + public readonly partial struct ConfiguredValueTaskAwaitable + { + private readonly object _dummy; + public System.Runtime.CompilerServices.ConfiguredValueTaskAwaitable.ConfiguredValueTaskAwaiter GetAwaiter() { throw null; } + public readonly partial struct ConfiguredValueTaskAwaiter : System.Runtime.CompilerServices.ICriticalNotifyCompletion, System.Runtime.CompilerServices.INotifyCompletion + { + private readonly object _dummy; + public bool IsCompleted { get { throw null; } } + public void GetResult() { } + public void OnCompleted(System.Action continuation) { } + public void UnsafeOnCompleted(System.Action continuation) { } + } + } public readonly partial struct ConfiguredValueTaskAwaitable { private readonly object _dummy; @@ -38,6 +63,14 @@ public void OnCompleted(System.Action continuation) { } public void UnsafeOnCompleted(System.Action continuation) { } } } + public readonly partial struct ValueTaskAwaiter : System.Runtime.CompilerServices.ICriticalNotifyCompletion, System.Runtime.CompilerServices.INotifyCompletion + { + private readonly object _dummy; + public bool IsCompleted { get { throw null; } } + public void GetResult() { } + public void OnCompleted(System.Action continuation) { } + public void UnsafeOnCompleted(System.Action continuation) { } + } public readonly partial struct ValueTaskAwaiter : System.Runtime.CompilerServices.ICriticalNotifyCompletion, System.Runtime.CompilerServices.INotifyCompletion { private readonly object _dummy; @@ -49,11 +82,58 @@ public void UnsafeOnCompleted(System.Action continuation) { } } namespace System.Threading.Tasks { + [System.Flags] + public enum ValueTaskSourceOnCompletedFlags + { + None, + UseSchedulingContext = 0x1, + FlowExecutionContext = 0x2, + } + public enum ValueTaskSourceStatus + { + Pending = 0, + Succeeded = 1, + Faulted = 2, + Canceled = 3 + } + public interface IValueTaskSource + { + System.Threading.Tasks.ValueTaskSourceStatus Status { get; } + void OnCompleted(System.Action continuation, object state, System.Threading.Tasks.ValueTaskSourceOnCompletedFlags flags); + void GetResult(); + } + public interface IValueTaskSource + { + System.Threading.Tasks.ValueTaskSourceStatus Status { get; } + void OnCompleted(System.Action continuation, object state, System.Threading.Tasks.ValueTaskSourceOnCompletedFlags flags); + TResult GetResult(); + } + [System.Runtime.CompilerServices.AsyncMethodBuilderAttribute(typeof(System.Runtime.CompilerServices.AsyncValueTaskMethodBuilder))] + public readonly partial struct ValueTask : System.IEquatable + { + internal readonly object _dummy; + public ValueTask(System.Threading.Tasks.Task task) { throw null; } + public ValueTask(System.Threading.Tasks.IValueTaskSource source) { throw null; } + public bool IsCanceled { get { throw null; } } + public bool IsCompleted { get { throw null; } } + public bool IsCompletedSuccessfully { get { throw null; } } + public bool IsFaulted { get { throw null; } } + public System.Threading.Tasks.Task AsTask() { throw null; } + public System.Runtime.CompilerServices.ConfiguredValueTaskAwaitable ConfigureAwait(bool continueOnCapturedContext) { throw null; } + public override bool Equals(object obj) { throw null; } + public bool Equals(System.Threading.Tasks.ValueTask other) { throw null; } + public System.Runtime.CompilerServices.ValueTaskAwaiter GetAwaiter() { throw null; } + public override int GetHashCode() { throw null; } + public System.Threading.Tasks.ValueTask Preserve() { throw null; } + public static bool operator ==(System.Threading.Tasks.ValueTask left, System.Threading.Tasks.ValueTask right) { throw null; } + public static bool operator !=(System.Threading.Tasks.ValueTask left, System.Threading.Tasks.ValueTask right) { throw null; } + } [System.Runtime.CompilerServices.AsyncMethodBuilderAttribute(typeof(System.Runtime.CompilerServices.AsyncValueTaskMethodBuilder<>))] public readonly partial struct ValueTask : System.IEquatable> { internal readonly TResult _result; public ValueTask(System.Threading.Tasks.Task task) { throw null; } + public ValueTask(System.Threading.Tasks.IValueTaskSource source) { throw null; } public ValueTask(TResult result) { throw null; } public bool IsCanceled { get { throw null; } } public bool IsCompleted { get { throw null; } } @@ -66,6 +146,7 @@ namespace System.Threading.Tasks public bool Equals(System.Threading.Tasks.ValueTask other) { throw null; } public System.Runtime.CompilerServices.ValueTaskAwaiter GetAwaiter() { throw null; } public override int GetHashCode() { throw null; } + public System.Threading.Tasks.ValueTask Preserve() { throw null; } public static bool operator ==(System.Threading.Tasks.ValueTask left, System.Threading.Tasks.ValueTask right) { throw null; } public static bool operator !=(System.Threading.Tasks.ValueTask left, System.Threading.Tasks.ValueTask right) { throw null; } public override string ToString() { throw null; } diff --git a/src/System.Threading.Tasks.Extensions/src/System.Threading.Tasks.Extensions.csproj b/src/System.Threading.Tasks.Extensions/src/System.Threading.Tasks.Extensions.csproj index dc43433c90f3..5a04050f280a 100644 --- a/src/System.Threading.Tasks.Extensions/src/System.Threading.Tasks.Extensions.csproj +++ b/src/System.Threading.Tasks.Extensions/src/System.Threading.Tasks.Extensions.csproj @@ -37,6 +37,9 @@ Common\CoreLib\System\Runtime\CompilerServices\ConfiguredValueTaskAwaitable.cs + + Common\CoreLib\System\Threading\Tasks\IValueTaskSource.cs + Common\CoreLib\System\Runtime\CompilerServices\ValueTaskAwaiter.cs @@ -47,6 +50,7 @@ + diff --git a/src/System.Threading.Tasks.Extensions/src/System/ThrowHelper.cs b/src/System.Threading.Tasks.Extensions/src/System/ThrowHelper.cs index 0757c616bb01..5824a165ec1e 100644 --- a/src/System.Threading.Tasks.Extensions/src/System/ThrowHelper.cs +++ b/src/System.Threading.Tasks.Extensions/src/System/ThrowHelper.cs @@ -12,9 +12,15 @@ internal static class ThrowHelper internal static void ThrowArgumentNullException(ExceptionArgument argument) => throw GetArgumentNullException(argument); + internal static void ThrowArgumentOutOfRangeException(ExceptionArgument argument) => + throw GetArgumentOutOfRangeException(argument); + private static ArgumentNullException GetArgumentNullException(ExceptionArgument argument) => new ArgumentNullException(GetArgumentName(argument)); + private static ArgumentOutOfRangeException GetArgumentOutOfRangeException(ExceptionArgument argument) => + new ArgumentOutOfRangeException(GetArgumentName(argument)); + [MethodImpl(MethodImplOptions.NoInlining)] private static string GetArgumentName(ExceptionArgument argument) { @@ -27,6 +33,8 @@ private static string GetArgumentName(ExceptionArgument argument) internal enum ExceptionArgument { - task + task, + source, + state } } diff --git a/src/System.Threading.Tasks.Extensions/tests/AsyncMethodBuilderAttributeTests.cs b/src/System.Threading.Tasks.Extensions/tests/AsyncMethodBuilderAttributeTests.cs index 29dfa236f4fb..02f1d8b4cdcb 100644 --- a/src/System.Threading.Tasks.Extensions/tests/AsyncMethodBuilderAttributeTests.cs +++ b/src/System.Threading.Tasks.Extensions/tests/AsyncMethodBuilderAttributeTests.cs @@ -11,6 +11,7 @@ public class AsyncMethodBuilderAttributeTests [Theory] [InlineData(typeof(string))] [InlineData(typeof(int))] + [InlineData(typeof(AsyncValueTaskMethodBuilder))] [InlineData(typeof(AsyncValueTaskMethodBuilder<>))] [InlineData(typeof(AsyncValueTaskMethodBuilder))] [InlineData(typeof(AsyncValueTaskMethodBuilder))] diff --git a/src/System.Threading.Tasks.Extensions/tests/AsyncValueTaskMethodBuilderTests.cs b/src/System.Threading.Tasks.Extensions/tests/AsyncValueTaskMethodBuilderTests.cs index 1ecefdf6fa4a..096338523109 100644 --- a/src/System.Threading.Tasks.Extensions/tests/AsyncValueTaskMethodBuilderTests.cs +++ b/src/System.Threading.Tasks.Extensions/tests/AsyncValueTaskMethodBuilderTests.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using System.Runtime.CompilerServices; using Xunit; @@ -11,14 +10,30 @@ namespace System.Threading.Tasks.Tests public class AsyncValueTaskMethodBuilderTests { [Fact] - public void Create_ReturnsDefaultInstance() + public void NonGeneric_Create_ReturnsDefaultInstance() + { + AsyncValueTaskMethodBuilder b = default; + Assert.Equal(default, b); // implementation detail being verified + } + + [Fact] + public void Generic_Create_ReturnsDefaultInstance() { AsyncValueTaskMethodBuilder b = default; - Assert.Equal(default(AsyncValueTaskMethodBuilder), b); // implementation detail being verified + Assert.Equal(default, b); // implementation detail being verified + } + + [Fact] + public void NonGeneric_SetResult_BeforeAccessTask_ValueTaskIsDefault() + { + AsyncValueTaskMethodBuilder b = default; + b.SetResult(); + ValueTask vt = b.Task; + Assert.True(vt == default); } [Fact] - public void SetResult_BeforeAccessTask_ValueTaskContainsValue() + public void Generic_SetResult_BeforeAccessTask_ValueTaskContainsValue() { AsyncValueTaskMethodBuilder b = default; b.SetResult(42); @@ -29,7 +44,18 @@ public void SetResult_BeforeAccessTask_ValueTaskContainsValue() } [Fact] - public void SetResult_AfterAccessTask_ValueTaskContainsValue() + public void NonGeneric_SetResult_AfterAccessTask_ValueTaskContainsValue() + { + AsyncValueTaskMethodBuilder b = default; + ValueTask vt = b.Task; + b.SetResult(); + Assert.False(vt == default); + Assert.True(vt.IsCompletedSuccessfully); + Assert.True(WrapsTask(vt)); + } + + [Fact] + public void Generic_SetResult_AfterAccessTask_ValueTaskContainsValue() { AsyncValueTaskMethodBuilder b = default; ValueTask vt = b.Task; @@ -40,7 +66,18 @@ public void SetResult_AfterAccessTask_ValueTaskContainsValue() } [Fact] - public void SetException_BeforeAccessTask_FaultsTask() + public void NonGeneric_SetException_BeforeAccessTask_FaultsTask() + { + AsyncValueTaskMethodBuilder b = default; + var e = new FormatException(); + b.SetException(e); + ValueTask vt = b.Task; + Assert.True(vt.IsFaulted); + Assert.Same(e, Assert.Throws(() => vt.GetAwaiter().GetResult())); + } + + [Fact] + public void Generic_SetException_BeforeAccessTask_FaultsTask() { AsyncValueTaskMethodBuilder b = default; var e = new FormatException(); @@ -51,7 +88,18 @@ public void SetException_BeforeAccessTask_FaultsTask() } [Fact] - public void SetException_AfterAccessTask_FaultsTask() + public void NonGeneric_SetException_AfterAccessTask_FaultsTask() + { + AsyncValueTaskMethodBuilder b = default; + var e = new FormatException(); + ValueTask vt = b.Task; + b.SetException(e); + Assert.True(vt.IsFaulted); + Assert.Same(e, Assert.Throws(() => vt.GetAwaiter().GetResult())); + } + + [Fact] + public void Generic_SetException_AfterAccessTask_FaultsTask() { AsyncValueTaskMethodBuilder b = default; var e = new FormatException(); @@ -62,7 +110,18 @@ public void SetException_AfterAccessTask_FaultsTask() } [Fact] - public void SetException_OperationCanceledException_CancelsTask() + public void NonGeneric_SetException_OperationCanceledException_CancelsTask() + { + AsyncValueTaskMethodBuilder b = default; + var e = new OperationCanceledException(); + ValueTask vt = b.Task; + b.SetException(e); + Assert.True(vt.IsCanceled); + Assert.Same(e, Assert.Throws(() => vt.GetAwaiter().GetResult())); + } + + [Fact] + public void Generic_SetException_OperationCanceledException_CancelsTask() { AsyncValueTaskMethodBuilder b = default; var e = new OperationCanceledException(); @@ -73,7 +132,17 @@ public void SetException_OperationCanceledException_CancelsTask() } [Fact] - public void Start_InvokesMoveNext() + public void NonGeneric_Start_InvokesMoveNext() + { + AsyncValueTaskMethodBuilder b = default; + int invokes = 0; + var dsm = new DelegateStateMachine { MoveNextDelegate = () => invokes++ }; + b.Start(ref dsm); + Assert.Equal(1, invokes); + } + + [Fact] + public void Generic_Start_InvokesMoveNext() { AsyncValueTaskMethodBuilder b = default; int invokes = 0; @@ -87,7 +156,38 @@ public void Start_InvokesMoveNext() [InlineData(2, false)] [InlineData(1, true)] [InlineData(2, true)] - public void AwaitOnCompleted_ForcesTaskCreation(int numAwaits, bool awaitUnsafe) + public void NonGeneric_AwaitOnCompleted_ForcesTaskCreation(int numAwaits, bool awaitUnsafe) + { + AsyncValueTaskMethodBuilder b = default; + + var dsm = new DelegateStateMachine(); + TaskAwaiter t = new TaskCompletionSource().Task.GetAwaiter(); + + Assert.InRange(numAwaits, 1, int.MaxValue); + for (int i = 1; i <= numAwaits; i++) + { + if (awaitUnsafe) + { + b.AwaitUnsafeOnCompleted(ref t, ref dsm); + } + else + { + b.AwaitOnCompleted(ref t, ref dsm); + } + } + + b.SetResult(); + + Assert.True(WrapsTask(b.Task)); + Assert.True(b.Task.IsCompletedSuccessfully); + } + + [Theory] + [InlineData(1, false)] + [InlineData(2, false)] + [InlineData(1, true)] + [InlineData(2, true)] + public void Generic_AwaitOnCompleted_ForcesTaskCreation(int numAwaits, bool awaitUnsafe) { AsyncValueTaskMethodBuilder b = default; @@ -115,14 +215,56 @@ public void AwaitOnCompleted_ForcesTaskCreation(int numAwaits, bool awaitUnsafe) [Fact] [ActiveIssue("https://github.com/dotnet/corefx/issues/22506", TargetFrameworkMonikers.UapAot)] - public void SetStateMachine_InvalidArgument_ThrowsException() + public void NonGeneric_SetStateMachine_InvalidArgument_ThrowsException() + { + AsyncValueTaskMethodBuilder b = default; + AssertExtensions.Throws("stateMachine", () => b.SetStateMachine(null)); + } + + [Fact] + [ActiveIssue("https://github.com/dotnet/corefx/issues/22506", TargetFrameworkMonikers.UapAot)] + public void Generic_SetStateMachine_InvalidArgument_ThrowsException() { AsyncValueTaskMethodBuilder b = default; AssertExtensions.Throws("stateMachine", () => b.SetStateMachine(null)); } [Fact] - public void Start_ExecutionContextChangesInMoveNextDontFlowOut() + public void NonGeneric_Start_ExecutionContextChangesInMoveNextDontFlowOut() + { + var al = new AsyncLocal { Value = 0 }; + int calls = 0; + + var dsm = new DelegateStateMachine + { + MoveNextDelegate = () => + { + al.Value++; + calls++; + } + }; + + dsm.MoveNext(); + Assert.Equal(1, al.Value); + Assert.Equal(1, calls); + + dsm.MoveNext(); + Assert.Equal(2, al.Value); + Assert.Equal(2, calls); + + AsyncValueTaskMethodBuilder b = default; + b.Start(ref dsm); + Assert.Equal(2, al.Value); // change should not be visible + Assert.Equal(3, calls); + + // Make sure we've not caused the Task to be allocated + b.SetResult(); + ValueTask vt = b.Task; + Assert.False(WrapsTask(vt)); + } + + [Fact] + public void Generic_Start_ExecutionContextChangesInMoveNextDontFlowOut() { var al = new AsyncLocal { Value = 0 }; int calls = 0; @@ -160,7 +302,25 @@ public void Start_ExecutionContextChangesInMoveNextDontFlowOut() [InlineData(1)] [InlineData(2)] [InlineData(10)] - public static async Task UsedWithAsyncMethod_CompletesSuccessfully(int yields) + public static async Task NonGeneric_UsedWithAsyncMethod_CompletesSuccessfully(int yields) + { + await ValueTaskReturningAsyncMethod(42); + + ValueTask vt = ValueTaskReturningAsyncMethod(84); + Assert.Equal(yields > 0, WrapsTask(vt)); + + async ValueTask ValueTaskReturningAsyncMethod(int result) + { + for (int i = 0; i < yields; i++) await Task.Yield(); + } + } + + [Theory] + [InlineData(0)] + [InlineData(1)] + [InlineData(2)] + [InlineData(10)] + public static async Task Generic_UsedWithAsyncMethod_CompletesSuccessfully(int yields) { Assert.Equal(42, await ValueTaskReturningAsyncMethod(42)); @@ -175,7 +335,129 @@ async ValueTask ValueTaskReturningAsyncMethod(int result) } } - /// Gets whether the ValueTask has a non-null Task. + [Fact] + public static async Task AwaitTasksAndValueTasks_InTaskAndValueTaskMethods() + { + for (int i = 0; i < 2; i++) + { + await TaskReturningMethod(); + Assert.Equal(17, await TaskInt32ReturningMethod()); + await ValueTaskReturningMethod(); + Assert.Equal(18, await ValueTaskInt32ReturningMethod()); + } + + async Task TaskReturningMethod() + { + for (int i = 0; i < 3; i++) + { + // Complete + await Task.CompletedTask; + await Task.FromResult(42); + await new ValueTask(); + await Assert.ThrowsAsync(async () => await new ValueTask(Task.FromException(new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Completed(0, new FormatException()))); + Assert.Equal(42, await new ValueTask(42)); + Assert.Equal(42, await new ValueTask(Task.FromResult(42))); + Assert.Equal(42, await new ValueTask(ManualResetValueTaskSource.Completed(42, null))); + await Assert.ThrowsAsync(async () => await new ValueTask(Task.FromException(new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Completed(0, new FormatException()))); + + // Incomplete + await Assert.ThrowsAsync(async () => await new ValueTask(Task.Delay(1).ContinueWith(_ => throw new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Delay(1, 0, new FormatException()))); + Assert.Equal(42, await new ValueTask(Task.Delay(1).ContinueWith(_ => 42))); + Assert.Equal(42, await new ValueTask(ManualResetValueTaskSource.Delay(1, 42, null))); + await Assert.ThrowsAsync(async () => await new ValueTask(Task.Delay(1).ContinueWith(_ => throw new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Delay(1, 0, new FormatException()))); + await Task.Yield(); + } + } + + async Task TaskInt32ReturningMethod() + { + for (int i = 0; i < 3; i++) + { + // Complete + await Task.CompletedTask; + await Task.FromResult(42); + await new ValueTask(); + await Assert.ThrowsAsync(async () => await new ValueTask(Task.FromException(new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Completed(0, new FormatException()))); + Assert.Equal(42, await new ValueTask(42)); + Assert.Equal(42, await new ValueTask(Task.FromResult(42))); + Assert.Equal(42, await new ValueTask(ManualResetValueTaskSource.Completed(42, null))); + await Assert.ThrowsAsync(async () => await new ValueTask(Task.FromException(new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Completed(0, new FormatException()))); + + // Incomplete + await Assert.ThrowsAsync(async () => await new ValueTask(Task.Delay(1).ContinueWith(_ => throw new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Delay(1, 0, new FormatException()))); + Assert.Equal(42, await new ValueTask(Task.Delay(1).ContinueWith(_ => 42))); + Assert.Equal(42, await new ValueTask(ManualResetValueTaskSource.Delay(1, 42, null))); + await Assert.ThrowsAsync(async () => await new ValueTask(Task.Delay(1).ContinueWith(_ => throw new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Delay(1, 0, new FormatException()))); + await Task.Yield(); + } + return 17; + } + + async ValueTask ValueTaskReturningMethod() + { + for (int i = 0; i < 3; i++) + { + // Complete + await Task.CompletedTask; + await Task.FromResult(42); + await new ValueTask(); + await Assert.ThrowsAsync(async () => await new ValueTask(Task.FromException(new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Completed(0, new FormatException()))); + Assert.Equal(42, await new ValueTask(42)); + Assert.Equal(42, await new ValueTask(Task.FromResult(42))); + Assert.Equal(42, await new ValueTask(ManualResetValueTaskSource.Completed(42, null))); + await Assert.ThrowsAsync(async () => await new ValueTask(Task.FromException(new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Completed(0, new FormatException()))); + + // Incomplete + await Assert.ThrowsAsync(async () => await new ValueTask(Task.Delay(1).ContinueWith(_ => throw new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Delay(1, 0, new FormatException()))); + Assert.Equal(42, await new ValueTask(Task.Delay(1).ContinueWith(_ => 42))); + Assert.Equal(42, await new ValueTask(ManualResetValueTaskSource.Delay(1, 42, null))); + await Assert.ThrowsAsync(async () => await new ValueTask(Task.Delay(1).ContinueWith(_ => throw new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Delay(1, 0, new FormatException()))); + await Task.Yield(); + } + } + + async ValueTask ValueTaskInt32ReturningMethod() + { + for (int i = 0; i < 3; i++) + { + // Complete + await Task.CompletedTask; + await Task.FromResult(42); + await new ValueTask(); + await Assert.ThrowsAsync(async () => await new ValueTask(Task.FromException(new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Completed(0, new FormatException()))); + Assert.Equal(42, await new ValueTask(42)); + Assert.Equal(42, await new ValueTask(Task.FromResult(42))); + Assert.Equal(42, await new ValueTask(ManualResetValueTaskSource.Completed(42, null))); + await Assert.ThrowsAsync(async () => await new ValueTask(Task.FromException(new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Completed(0, new FormatException()))); + + // Incomplete + await Assert.ThrowsAsync(async () => await new ValueTask(Task.Delay(1).ContinueWith(_ => throw new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Delay(1, 0, new FormatException()))); + Assert.Equal(42, await new ValueTask(Task.Delay(1).ContinueWith(_ => 42))); + Assert.Equal(42, await new ValueTask(ManualResetValueTaskSource.Delay(1, 42, null))); + await Assert.ThrowsAsync(async () => await new ValueTask(Task.Delay(1).ContinueWith(_ => throw new FormatException()))); + await Assert.ThrowsAsync(async () => await new ValueTask(ManualResetValueTaskSource.Delay(1, 0, new FormatException()))); + await Task.Yield(); + } + return 18; + } + } + + private static bool WrapsTask(ValueTask vt) => vt != default; private static bool WrapsTask(ValueTask vt) => ReferenceEquals(vt.AsTask(), vt.AsTask()); private struct DelegateStateMachine : IAsyncStateMachine diff --git a/src/System.Threading.Tasks.Extensions/tests/Configurations.props b/src/System.Threading.Tasks.Extensions/tests/Configurations.props index c70175586371..7de008759811 100644 --- a/src/System.Threading.Tasks.Extensions/tests/Configurations.props +++ b/src/System.Threading.Tasks.Extensions/tests/Configurations.props @@ -4,6 +4,7 @@ netcoreapp; uap; + netfx; \ No newline at end of file diff --git a/src/System.Threading.Tasks.Extensions/tests/ManualResetValueTaskSource.cs b/src/System.Threading.Tasks.Extensions/tests/ManualResetValueTaskSource.cs new file mode 100644 index 000000000000..e49be1918ba5 --- /dev/null +++ b/src/System.Threading.Tasks.Extensions/tests/ManualResetValueTaskSource.cs @@ -0,0 +1,170 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Runtime.ExceptionServices; + +namespace System.Threading.Tasks.Tests +{ + internal static class ManualResetValueTaskSource + { + public static ManualResetValueTaskSource Completed(T result, Exception error = null) + { + var vts = new ManualResetValueTaskSource(); + if (error != null) + { + vts.SetException(error); + } + else + { + vts.SetResult(result); + } + return vts; + } + + public static ManualResetValueTaskSource Delay(int delayMs, T result, Exception error = null) + { + var vts = new ManualResetValueTaskSource(); + Task.Delay(delayMs).ContinueWith(_ => + { + if (error != null) + { + vts.SetException(error); + } + else + { + vts.SetResult(result); + } + }); + return vts; + } + } + + internal sealed class ManualResetValueTaskSource : IValueTaskSource, IValueTaskSource + { + private static readonly Action s_sentinel = new Action(s => { }); + private Action _continuation; + private object _continuationState; + private SynchronizationContext _capturedContext; + private ExecutionContext _executionContext; + private bool _completed; + private T _result; + private ExceptionDispatchInfo _error; + + public ValueTaskSourceStatus Status => + !_completed ? ValueTaskSourceStatus.Pending : + _error == null ? ValueTaskSourceStatus.Succeeded : + _error.SourceException is OperationCanceledException ? ValueTaskSourceStatus.Canceled : + ValueTaskSourceStatus.Faulted; + + public T GetResult() + { + if (!_completed) + { + throw new Exception("Not completed"); + } + + ExecutionContext ctx = _executionContext; + if (ctx != null) + { + ctx.Dispose(); + _executionContext = null; + } + + _error?.Throw(); + return _result; + } + + void IValueTaskSource.GetResult() + { + GetResult(); + } + + public void Reset() + { + _completed = false; + _continuation = null; + _continuationState = null; + _result = default; + _error = null; + } + + public void OnCompleted(Action continuation, object state, ValueTaskSourceOnCompletedFlags flags) + { + if ((flags & ValueTaskSourceOnCompletedFlags.FlowExecutionContext) != 0) + { + _executionContext = ExecutionContext.Capture(); + } + + if ((flags & ValueTaskSourceOnCompletedFlags.UseSchedulingContext) != 0) + { + _capturedContext = SynchronizationContext.Current; + } + + _continuationState = state; + if (Interlocked.CompareExchange(ref _continuation, continuation, null) != null) + { + SynchronizationContext sc = _capturedContext; + if (sc != null) + { + _capturedContext = null; + sc.Post(s => + { + var tuple = (Tuple, object>)s; + tuple.Item1(tuple.Item2); + }, Tuple.Create(continuation, state)); + } + else + { + Task.Factory.StartNew(continuation, state, CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskScheduler.Default); + } + } + } + + public void SetResult(T result) + { + _result = result; + SignalCompletion(); + } + + public void SetException(Exception error) + { + _error = ExceptionDispatchInfo.Capture(error); + SignalCompletion(); + } + + private void SignalCompletion() + { + _completed = true; + if (Interlocked.CompareExchange(ref _continuation, s_sentinel, null) != null) + { + if (_executionContext != null) + { + ExecutionContext.Run(_executionContext, s => ((ManualResetValueTaskSource)s).InvokeContinuation(), this); + } + else + { + InvokeContinuation(); + } + } + } + + private void InvokeContinuation() + { + SynchronizationContext sc = _capturedContext; + if (sc != null) + { + _capturedContext = null; + sc.Post(s => + { + var thisRef = (ManualResetValueTaskSource)s; + thisRef._continuation(thisRef._continuationState); + }, this); + } + else + { + _continuation(_continuationState); + } + } + } +} diff --git a/src/System.Threading.Tasks.Extensions/tests/System.Threading.Tasks.Extensions.Tests.csproj b/src/System.Threading.Tasks.Extensions/tests/System.Threading.Tasks.Extensions.Tests.csproj index 9d8135bccc40..720fc370b2cd 100644 --- a/src/System.Threading.Tasks.Extensions/tests/System.Threading.Tasks.Extensions.Tests.csproj +++ b/src/System.Threading.Tasks.Extensions/tests/System.Threading.Tasks.Extensions.Tests.csproj @@ -11,10 +11,11 @@ + - \ No newline at end of file + diff --git a/src/System.Threading.Tasks.Extensions/tests/ValueTaskTests.cs b/src/System.Threading.Tasks.Extensions/tests/ValueTaskTests.cs index af9ef3be4b49..82fdfa09be87 100644 --- a/src/System.Threading.Tasks.Extensions/tests/ValueTaskTests.cs +++ b/src/System.Threading.Tasks.Extensions/tests/ValueTaskTests.cs @@ -11,8 +11,24 @@ namespace System.Threading.Tasks.Tests { public class ValueTaskTests { + public enum CtorMode + { + Result, + Task, + ValueTaskSource + } + + [Fact] + public void NonGeneric_DefaultValueTask_DefaultValue() + { + Assert.True(default(ValueTask).IsCompleted); + Assert.True(default(ValueTask).IsCompletedSuccessfully); + Assert.False(default(ValueTask).IsFaulted); + Assert.False(default(ValueTask).IsCanceled); + } + [Fact] - public void DefaultValueTask_ValueType_DefaultValue() + public void Generic_DefaultValueTask_DefaultValue() { Assert.True(default(ValueTask).IsCompleted); Assert.True(default(ValueTask).IsCompletedSuccessfully); @@ -27,21 +43,32 @@ public void DefaultValueTask_ValueType_DefaultValue() Assert.Equal(null, default(ValueTask).Result); } - [Fact] - public void CreateFromValue_IsRanToCompletion() + [Theory] + [InlineData(CtorMode.Result)] + [InlineData(CtorMode.Task)] + [InlineData(CtorMode.ValueTaskSource)] + public void NonGeneric_CreateFromSuccessfullyCompleted_IsCompletedSuccessfully(CtorMode mode) { - ValueTask t = new ValueTask(42); + ValueTask t = + mode == CtorMode.Result ? default : + mode == CtorMode.Task ? new ValueTask(Task.CompletedTask) : + new ValueTask(ManualResetValueTaskSource.Completed(0, null)); Assert.True(t.IsCompleted); Assert.True(t.IsCompletedSuccessfully); Assert.False(t.IsFaulted); Assert.False(t.IsCanceled); - Assert.Equal(42, t.Result); } - [Fact] - public void CreateFromCompletedTask_IsRanToCompletion() + [Theory] + [InlineData(CtorMode.Result)] + [InlineData(CtorMode.Task)] + [InlineData(CtorMode.ValueTaskSource)] + public void Generic_CreateFromSuccessfullyCompleted_IsCompletedSuccessfully(CtorMode mode) { - ValueTask t = new ValueTask(Task.FromResult(42)); + ValueTask t = + mode == CtorMode.Result ? new ValueTask(42) : + mode == CtorMode.Task ? new ValueTask(Task.FromResult(42)) : + new ValueTask(ManualResetValueTaskSource.Completed(42, null)); Assert.True(t.IsCompleted); Assert.True(t.IsCompletedSuccessfully); Assert.False(t.IsFaulted); @@ -49,18 +76,87 @@ public void CreateFromCompletedTask_IsRanToCompletion() Assert.Equal(42, t.Result); } - [Fact] - public void CreateFromNotCompletedTask_IsNotRanToCompletion() + [Theory] + [InlineData(CtorMode.Task)] + [InlineData(CtorMode.ValueTaskSource)] + public void NonGeneric_CreateFromNotCompleted_ThenCompleteSuccessfully(CtorMode mode) + { + object completer = null; + ValueTask t = default; + switch (mode) + { + case CtorMode.Task: + var tcs = new TaskCompletionSource(); + t = new ValueTask(tcs.Task); + completer = tcs; + break; + + case CtorMode.ValueTaskSource: + var mre = new ManualResetValueTaskSource(); + t = new ValueTask(mre); + completer = mre; + break; + } + + Assert.False(t.IsCompleted); + Assert.False(t.IsCompletedSuccessfully); + Assert.False(t.IsFaulted); + Assert.False(t.IsCanceled); + + switch (mode) + { + case CtorMode.Task: + ((TaskCompletionSource)completer).SetResult(42); + break; + + case CtorMode.ValueTaskSource: + ((ManualResetValueTaskSource)completer).SetResult(42); + break; + } + + Assert.True(t.IsCompleted); + Assert.True(t.IsCompletedSuccessfully); + Assert.False(t.IsFaulted); + Assert.False(t.IsCanceled); + } + + [Theory] + [InlineData(CtorMode.Task)] + [InlineData(CtorMode.ValueTaskSource)] + public void Generic_CreateFromNotCompleted_ThenCompleteSuccessfully(CtorMode mode) { - var tcs = new TaskCompletionSource(); - ValueTask t = new ValueTask(tcs.Task); + object completer = null; + ValueTask t = default; + switch (mode) + { + case CtorMode.Task: + var tcs = new TaskCompletionSource(); + t = new ValueTask(tcs.Task); + completer = tcs; + break; + + case CtorMode.ValueTaskSource: + var mre = new ManualResetValueTaskSource(); + t = new ValueTask(mre); + completer = mre; + break; + } Assert.False(t.IsCompleted); Assert.False(t.IsCompletedSuccessfully); Assert.False(t.IsFaulted); Assert.False(t.IsCanceled); - tcs.SetResult(42); + switch (mode) + { + case CtorMode.Task: + ((TaskCompletionSource)completer).SetResult(42); + break; + + case CtorMode.ValueTaskSource: + ((ManualResetValueTaskSource)completer).SetResult(42); + break; + } Assert.Equal(42, t.Result); Assert.True(t.IsCompleted); @@ -69,15 +165,164 @@ public void CreateFromNotCompletedTask_IsNotRanToCompletion() Assert.False(t.IsCanceled); } + [Theory] + [InlineData(CtorMode.Task)] + [InlineData(CtorMode.ValueTaskSource)] + public void NonGeneric_CreateFromNotCompleted_ThenFault(CtorMode mode) + { + object completer = null; + ValueTask t = default; + switch (mode) + { + case CtorMode.Task: + var tcs = new TaskCompletionSource(); + t = new ValueTask(tcs.Task); + completer = tcs; + break; + + case CtorMode.ValueTaskSource: + var mre = new ManualResetValueTaskSource(); + t = new ValueTask(mre); + completer = mre; + break; + } + + Assert.False(t.IsCompleted); + Assert.False(t.IsCompletedSuccessfully); + Assert.False(t.IsFaulted); + Assert.False(t.IsCanceled); + + Exception e = new InvalidOperationException(); + + switch (mode) + { + case CtorMode.Task: + ((TaskCompletionSource)completer).SetException(e); + break; + + case CtorMode.ValueTaskSource: + ((ManualResetValueTaskSource)completer).SetException(e); + break; + } + + Assert.True(t.IsCompleted); + Assert.False(t.IsCompletedSuccessfully); + Assert.True(t.IsFaulted); + Assert.False(t.IsCanceled); + + Assert.Same(e, Assert.Throws(() => t.GetAwaiter().GetResult())); + } + + [Theory] + [InlineData(CtorMode.Task)] + [InlineData(CtorMode.ValueTaskSource)] + public void Generic_CreateFromNotCompleted_ThenFault(CtorMode mode) + { + object completer = null; + ValueTask t = default; + switch (mode) + { + case CtorMode.Task: + var tcs = new TaskCompletionSource(); + t = new ValueTask(tcs.Task); + completer = tcs; + break; + + case CtorMode.ValueTaskSource: + var mre = new ManualResetValueTaskSource(); + t = new ValueTask(mre); + completer = mre; + break; + } + + Assert.False(t.IsCompleted); + Assert.False(t.IsCompletedSuccessfully); + Assert.False(t.IsFaulted); + Assert.False(t.IsCanceled); + + Exception e = new InvalidOperationException(); + + switch (mode) + { + case CtorMode.Task: + ((TaskCompletionSource)completer).SetException(e); + break; + + case CtorMode.ValueTaskSource: + ((ManualResetValueTaskSource)completer).SetException(e); + break; + } + + Assert.True(t.IsCompleted); + Assert.False(t.IsCompletedSuccessfully); + Assert.True(t.IsFaulted); + Assert.False(t.IsCanceled); + + Assert.Same(e, Assert.Throws(() => t.Result)); + Assert.Same(e, Assert.Throws(() => t.GetAwaiter().GetResult())); + } + + [Theory] + [InlineData(CtorMode.Task)] + [InlineData(CtorMode.ValueTaskSource)] + public void NonGeneric_CreateFromFaulted_IsFaulted(CtorMode mode) + { + InvalidOperationException e = new InvalidOperationException(); + ValueTask t = mode == CtorMode.Task ? new ValueTask(Task.FromException(e)) : new ValueTask(ManualResetValueTaskSource.Completed(0, e)); + + Assert.True(t.IsCompleted); + Assert.False(t.IsCompletedSuccessfully); + Assert.True(t.IsFaulted); + Assert.False(t.IsCanceled); + + Assert.Same(e, Assert.Throws(() => t.GetAwaiter().GetResult())); + } + + [Theory] + [InlineData(CtorMode.Task)] + [InlineData(CtorMode.ValueTaskSource)] + public void Generic_CreateFromFaulted_IsFaulted(CtorMode mode) + { + InvalidOperationException e = new InvalidOperationException(); + ValueTask t = mode == CtorMode.Task ? new ValueTask(Task.FromException(e)) : new ValueTask(ManualResetValueTaskSource.Completed(0, e)); + + Assert.True(t.IsCompleted); + Assert.False(t.IsCompletedSuccessfully); + Assert.True(t.IsFaulted); + Assert.False(t.IsCanceled); + + Assert.Same(e, Assert.Throws(() => t.Result)); + Assert.Same(e, Assert.Throws(() => t.GetAwaiter().GetResult())); + } + [Fact] - public void CreateFromNullTask_Throws() + public void NonGeneric_CreateFromNullTask_Throws() { - Assert.Throws(() => new ValueTask((Task)null)); - Assert.Throws(() => new ValueTask((Task)null)); + AssertExtensions.Throws("task", () => new ValueTask((Task)null)); + AssertExtensions.Throws("source", () => new ValueTask((IValueTaskSource)null)); } [Fact] - public void CreateFromTask_AsTaskIdempotent() + public void Generic_CreateFromNullTask_Throws() + { + AssertExtensions.Throws("task", () => new ValueTask((Task)null)); + AssertExtensions.Throws("task", () => new ValueTask((Task)null)); + + AssertExtensions.Throws("source", () => new ValueTask((IValueTaskSource)null)); + AssertExtensions.Throws("source", () => new ValueTask((IValueTaskSource)null)); + } + + [Fact] + public void NonGeneric_CreateFromTask_AsTaskIdempotent() + { + Task source = Task.FromResult(42); + ValueTask t = new ValueTask(source); + Assert.Same(source, t.AsTask()); + Assert.Same(t.AsTask(), t.AsTask()); + } + + [Fact] + public void Generic_CreateFromTask_AsTaskIdempotent() { Task source = Task.FromResult(42); ValueTask t = new ValueTask(source); @@ -86,7 +331,14 @@ public void CreateFromTask_AsTaskIdempotent() } [Fact] - public void CreateFromValue_AsTaskNotIdempotent() + public void NonGeneric_CreateFromDefault_AsTaskIdempotent() + { + ValueTask t = new ValueTask(); + Assert.Same(t.AsTask(), t.AsTask()); + } + + [Fact] + public void Generic_CreateFromValue_AsTaskNotIdempotent() { ValueTask t = new ValueTask(42); Assert.NotSame(Task.FromResult(42), t.AsTask()); @@ -94,68 +346,428 @@ public void CreateFromValue_AsTaskNotIdempotent() } [Fact] - public async Task CreateFromValue_Await() + public void NonGeneric_CreateFromValueTaskSource_AsTaskIdempotent() // validates unsupported behavior specific to the backing IValueTaskSource { - ValueTask t = new ValueTask(42); - Assert.Equal(42, await t); - Assert.Equal(42, await t.ConfigureAwait(false)); - Assert.Equal(42, await t.ConfigureAwait(true)); + ValueTask vt = new ValueTask(ManualResetValueTaskSource.Completed(42, null)); + Task t = vt.AsTask(); + Assert.NotNull(t); + Assert.Same(t, vt.AsTask()); + Assert.Same(Task.CompletedTask, vt.AsTask()); } [Fact] - public async Task CreateFromTask_Await_Normal() + public void Generic_CreateFromValueTaskSource_AsTaskNotIdempotent() // validates unsupported behavior specific to the backing IValueTaskSource { - Task source = Task.Delay(1).ContinueWith(_ => 42); - ValueTask t = new ValueTask(source); + ValueTask t = new ValueTask(ManualResetValueTaskSource.Completed(42, null)); + Assert.NotSame(Task.FromResult(42), t.AsTask()); + Assert.NotSame(t.AsTask(), t.AsTask()); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task NonGeneric_CreateFromValueTaskSource_Success(bool sync) + { + ValueTask vt = new ValueTask(sync ? ManualResetValueTaskSource.Completed(0) : ManualResetValueTaskSource.Delay(1, 0)); + Task t = vt.AsTask(); + if (sync) + { + Assert.True(t.Status == TaskStatus.RanToCompletion); + } + await t; + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task Generic_CreateFromValueTaskSource_Success(bool sync) + { + ValueTask vt = new ValueTask(sync ? ManualResetValueTaskSource.Completed(42) : ManualResetValueTaskSource.Delay(1, 42)); + Task t = vt.AsTask(); + if (sync) + { + Assert.True(t.Status == TaskStatus.RanToCompletion); + } Assert.Equal(42, await t); } + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task NonGeneric_CreateFromValueTaskSource_Faulted(bool sync) + { + ValueTask vt = new ValueTask(sync ? ManualResetValueTaskSource.Completed(0, new FormatException()) : ManualResetValueTaskSource.Delay(1, 0, new FormatException())); + Task t = vt.AsTask(); + if (sync) + { + Assert.True(t.IsFaulted); + Assert.IsType(t.Exception.InnerException); + } + else + { + await Assert.ThrowsAsync(() => t); + } + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task Generic_CreateFromValueTaskSource_Faulted(bool sync) + { + ValueTask vt = new ValueTask(sync ? ManualResetValueTaskSource.Completed(0, new FormatException()) : ManualResetValueTaskSource.Delay(1, 0, new FormatException())); + Task t = vt.AsTask(); + if (sync) + { + Assert.True(t.IsFaulted); + Assert.IsType(t.Exception.InnerException); + } + else + { + await Assert.ThrowsAsync(() => t); + } + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task NonGeneric_CreateFromValueTaskSource_Canceled(bool sync) + { + ValueTask vt = new ValueTask(sync ? ManualResetValueTaskSource.Completed(0, new OperationCanceledException()) : ManualResetValueTaskSource.Delay(1, 0, new OperationCanceledException())); + Task t = vt.AsTask(); + if (sync) + { + Assert.True(t.IsCanceled); + } + else + { + await Assert.ThrowsAnyAsync(() => t); + Assert.True(t.IsCanceled); + } + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task Generic_CreateFromValueTaskSource_Canceled(bool sync) + { + ValueTask vt = new ValueTask(sync ? ManualResetValueTaskSource.Completed(0, new OperationCanceledException()) : ManualResetValueTaskSource.Delay(1, 0, new OperationCanceledException())); + Task t = vt.AsTask(); + if (sync) + { + Assert.True(t.IsCanceled); + } + else + { + await Assert.ThrowsAnyAsync(() => t); + Assert.True(t.IsCanceled); + } + } + [Fact] - public async Task CreateFromTask_Await_ConfigureAwaitFalse() + public void NonGeneric_Preserve_FromResult_NoChanges() { - Task source = Task.Delay(1).ContinueWith(_ => 42); - ValueTask t = new ValueTask(source); - Assert.Equal(42, await t.ConfigureAwait(false)); + ValueTask vt1 = default; + ValueTask vt2 = vt1.Preserve(); + Assert.True(vt1 == vt2); } [Fact] - public async Task CreateFromTask_Await_ConfigureAwaitTrue() + public void NonGeneric_Preserve_FromTask_EqualityMaintained() { - Task source = Task.Delay(1).ContinueWith(_ => 42); - ValueTask t = new ValueTask(source); - Assert.Equal(42, await t.ConfigureAwait(true)); + ValueTask vt1 = new ValueTask(Task.FromResult(42)); + ValueTask vt2 = vt1.Preserve(); + Assert.True(vt1 == vt2); } [Fact] - public async Task Awaiter_OnCompleted() + public void NonGeneric_Preserve_FromValueTaskSource_TransitionedToTask() { - // Since ValueTask implements both OnCompleted and UnsafeOnCompleted, - // OnCompleted typically won't be used by await, so we add an explicit test - // for it here. + ValueTask vt1 = new ValueTask(ManualResetValueTaskSource.Completed(42)); + ValueTask vt2 = vt1.Preserve(); + ValueTask vt3 = vt2.Preserve(); + Assert.True(vt1 != vt2); + Assert.True(vt2 == vt3); + Assert.Same(vt2.AsTask(), vt2.AsTask()); + } + + [Fact] + public void Generic_Preserve_FromResult_EqualityMaintained() + { + ValueTask vt1 = new ValueTask(42); + ValueTask vt2 = vt1.Preserve(); + Assert.True(vt1 == vt2); + } + + [Fact] + public void Generic_Preserve_FromTask_EqualityMaintained() + { + ValueTask vt1 = new ValueTask(Task.FromResult(42)); + ValueTask vt2 = vt1.Preserve(); + Assert.True(vt1 == vt2); + } + + [Fact] + public void Generic_Preserve_FromValueTaskSource_TransitionedToTask() + { + ValueTask vt1 = new ValueTask(ManualResetValueTaskSource.Completed(42)); + ValueTask vt2 = vt1.Preserve(); + ValueTask vt3 = vt2.Preserve(); + Assert.True(vt1 != vt2); + Assert.True(vt2 == vt3); + Assert.Same(vt2.AsTask(), vt2.AsTask()); + } + + [Theory] + [InlineData(CtorMode.Result)] + [InlineData(CtorMode.Task)] + [InlineData(CtorMode.ValueTaskSource)] + public async Task NonGeneric_CreateFromCompleted_Await(CtorMode mode) + { + ValueTask Create() => + mode == CtorMode.Result ? new ValueTask() : + mode == CtorMode.Task ? new ValueTask(Task.FromResult(42)) : + new ValueTask(ManualResetValueTaskSource.Completed(0, null)); + + int thread = Environment.CurrentManagedThreadId; + + await Create(); + Assert.Equal(thread, Environment.CurrentManagedThreadId); + + await Create().ConfigureAwait(false); + Assert.Equal(thread, Environment.CurrentManagedThreadId); + + await Create().ConfigureAwait(true); + Assert.Equal(thread, Environment.CurrentManagedThreadId); + } + + [Theory] + [InlineData(CtorMode.Result)] + [InlineData(CtorMode.Task)] + [InlineData(CtorMode.ValueTaskSource)] + public async Task Generic_CreateFromCompleted_Await(CtorMode mode) + { + ValueTask Create() => + mode == CtorMode.Result ? new ValueTask(42) : + mode == CtorMode.Task ? new ValueTask(Task.FromResult(42)) : + new ValueTask(ManualResetValueTaskSource.Completed(42, null)); + + int thread = Environment.CurrentManagedThreadId; + + Assert.Equal(42, await Create()); + Assert.Equal(thread, Environment.CurrentManagedThreadId); + + Assert.Equal(42, await Create().ConfigureAwait(false)); + Assert.Equal(thread, Environment.CurrentManagedThreadId); + + Assert.Equal(42, await Create().ConfigureAwait(true)); + Assert.Equal(thread, Environment.CurrentManagedThreadId); + } + + [Theory] + [InlineData(null)] + [InlineData(false)] + [InlineData(true)] + public async Task NonGeneric_CreateFromTask_Await_Normal(bool? continueOnCapturedContext) + { + var t = new ValueTask(Task.Delay(1)); + switch (continueOnCapturedContext) + { + case null: await t; break; + default: await t.ConfigureAwait(continueOnCapturedContext.Value); break; + } + } + + [Theory] + [InlineData(null)] + [InlineData(false)] + [InlineData(true)] + public async Task Generic_CreateFromTask_Await_Normal(bool? continueOnCapturedContext) + { + var t = new ValueTask(Task.Delay(1).ContinueWith(_ => 42)); + switch (continueOnCapturedContext) + { + case null: Assert.Equal(42, await t); break; + default: Assert.Equal(42, await t.ConfigureAwait(continueOnCapturedContext.Value)); break; + } + } + + [Theory] + [InlineData(null)] + [InlineData(false)] + [InlineData(true)] + public async Task CreateFromValueTaskSource_Await_Normal(bool? continueOnCapturedContext) + { + var mre = new ManualResetValueTaskSource(); + ValueTask t = new ValueTask(mre); + var ignored = Task.Delay(1).ContinueWith(_ => mre.SetResult(42)); + switch (continueOnCapturedContext) + { + case null: await t; break; + default: await t.ConfigureAwait(continueOnCapturedContext.Value); break; + } + } + + [Theory] + [InlineData(null)] + [InlineData(false)] + [InlineData(true)] + public async Task Generic_CreateFromValueTaskSource_Await_Normal(bool? continueOnCapturedContext) + { + var mre = new ManualResetValueTaskSource(); + ValueTask t = new ValueTask(mre); + var ignored = Task.Delay(1).ContinueWith(_ => mre.SetResult(42)); + switch (continueOnCapturedContext) + { + case null: Assert.Equal(42, await t); break; + default: Assert.Equal(42, await t.ConfigureAwait(continueOnCapturedContext.Value)); break; + } + } + + [Theory] + [InlineData(CtorMode.Result)] + [InlineData(CtorMode.Task)] + [InlineData(CtorMode.ValueTaskSource)] + public async Task NonGeneric_Awaiter_OnCompleted(CtorMode mode) + { + ValueTask t = + mode == CtorMode.Result ? new ValueTask() : + mode == CtorMode.Task ? new ValueTask(Task.CompletedTask) : + new ValueTask(ManualResetValueTaskSource.Completed(0, null)); - ValueTask t = new ValueTask(42); var tcs = new TaskCompletionSource(); t.GetAwaiter().OnCompleted(() => tcs.SetResult(true)); await tcs.Task; } [Theory] - [InlineData(true)] - [InlineData(false)] - public async Task ConfiguredAwaiter_OnCompleted(bool continueOnCapturedContext) + [InlineData(CtorMode.Result)] + [InlineData(CtorMode.Task)] + [InlineData(CtorMode.ValueTaskSource)] + public async Task NonGeneric_Awaiter_UnsafeOnCompleted(CtorMode mode) { - // Since ValueTask implements both OnCompleted and UnsafeOnCompleted, - // OnCompleted typically won't be used by await, so we add an explicit test - // for it here. + ValueTask t = + mode == CtorMode.Result ? new ValueTask() : + mode == CtorMode.Task ? new ValueTask(Task.CompletedTask) : + new ValueTask(ManualResetValueTaskSource.Completed(0, null)); + + var tcs = new TaskCompletionSource(); + t.GetAwaiter().UnsafeOnCompleted(() => tcs.SetResult(true)); + await tcs.Task; + } + + [Theory] + [InlineData(CtorMode.Result)] + [InlineData(CtorMode.Task)] + [InlineData(CtorMode.ValueTaskSource)] + public async Task Generic_Awaiter_OnCompleted(CtorMode mode) + { + ValueTask t = + mode == CtorMode.Result ? new ValueTask(42) : + mode == CtorMode.Task ? new ValueTask(Task.FromResult(42)) : + new ValueTask(ManualResetValueTaskSource.Completed(42, null)); + + var tcs = new TaskCompletionSource(); + t.GetAwaiter().OnCompleted(() => tcs.SetResult(true)); + await tcs.Task; + } + + [Theory] + [InlineData(CtorMode.Result)] + [InlineData(CtorMode.Task)] + [InlineData(CtorMode.ValueTaskSource)] + public async Task Generic_Awaiter_UnsafeOnCompleted(CtorMode mode) + { + ValueTask t = + mode == CtorMode.Result ? new ValueTask(42) : + mode == CtorMode.Task ? new ValueTask(Task.FromResult(42)) : + new ValueTask(ManualResetValueTaskSource.Completed(42, null)); + + var tcs = new TaskCompletionSource(); + t.GetAwaiter().UnsafeOnCompleted(() => tcs.SetResult(true)); + await tcs.Task; + } + + [Theory] + [InlineData(CtorMode.Result, true)] + [InlineData(CtorMode.Task, true)] + [InlineData(CtorMode.ValueTaskSource, true)] + [InlineData(CtorMode.Result, false)] + [InlineData(CtorMode.Task, false)] + [InlineData(CtorMode.ValueTaskSource, false)] + public async Task NonGeneric_ConfiguredAwaiter_OnCompleted(CtorMode mode, bool continueOnCapturedContext) + { + ValueTask t = + mode == CtorMode.Result ? new ValueTask() : + mode == CtorMode.Task ? new ValueTask(Task.CompletedTask) : + new ValueTask(ManualResetValueTaskSource.Completed(0, null)); - ValueTask t = new ValueTask(42); var tcs = new TaskCompletionSource(); t.ConfigureAwait(continueOnCapturedContext).GetAwaiter().OnCompleted(() => tcs.SetResult(true)); await tcs.Task; } - [Fact] - public async Task Awaiter_ContinuesOnCapturedContext() + [Theory] + [InlineData(CtorMode.Result, true)] + [InlineData(CtorMode.Task, true)] + [InlineData(CtorMode.ValueTaskSource, true)] + [InlineData(CtorMode.Result, false)] + [InlineData(CtorMode.Task, false)] + [InlineData(CtorMode.ValueTaskSource, false)] + public async Task NonGeneric_ConfiguredAwaiter_UnsafeOnCompleted(CtorMode mode, bool continueOnCapturedContext) + { + ValueTask t = + mode == CtorMode.Result ? new ValueTask() : + mode == CtorMode.Task ? new ValueTask(Task.CompletedTask) : + new ValueTask(ManualResetValueTaskSource.Completed(0, null)); + + var tcs = new TaskCompletionSource(); + t.ConfigureAwait(continueOnCapturedContext).GetAwaiter().UnsafeOnCompleted(() => tcs.SetResult(true)); + await tcs.Task; + } + + [Theory] + [InlineData(CtorMode.Result, true)] + [InlineData(CtorMode.Task, true)] + [InlineData(CtorMode.ValueTaskSource, true)] + [InlineData(CtorMode.Result, false)] + [InlineData(CtorMode.Task, false)] + [InlineData(CtorMode.ValueTaskSource, false)] + public async Task Generic_ConfiguredAwaiter_OnCompleted(CtorMode mode, bool continueOnCapturedContext) + { + ValueTask t = + mode == CtorMode.Result ? new ValueTask(42) : + mode == CtorMode.Task ? new ValueTask(Task.FromResult(42)) : + new ValueTask(ManualResetValueTaskSource.Completed(42, null)); + + var tcs = new TaskCompletionSource(); + t.ConfigureAwait(continueOnCapturedContext).GetAwaiter().OnCompleted(() => tcs.SetResult(true)); + await tcs.Task; + } + + [Theory] + [InlineData(CtorMode.Result, true)] + [InlineData(CtorMode.Task, true)] + [InlineData(CtorMode.ValueTaskSource, true)] + [InlineData(CtorMode.Result, false)] + [InlineData(CtorMode.Task, false)] + [InlineData(CtorMode.ValueTaskSource, false)] + public async Task Generic_ConfiguredAwaiter_UnsafeOnCompleted(CtorMode mode, bool continueOnCapturedContext) + { + ValueTask t = + mode == CtorMode.Result ? new ValueTask(42) : + mode == CtorMode.Task ? new ValueTask(Task.FromResult(42)) : + new ValueTask(ManualResetValueTaskSource.Completed(42, null)); + + var tcs = new TaskCompletionSource(); + t.ConfigureAwait(continueOnCapturedContext).GetAwaiter().UnsafeOnCompleted(() => tcs.SetResult(true)); + await tcs.Task; + } + + [Theory] + [InlineData(CtorMode.Result)] + [InlineData(CtorMode.Task)] + [InlineData(CtorMode.ValueTaskSource)] + public async Task NonGeneric_Awaiter_ContinuesOnCapturedContext(CtorMode mode) { await Task.Run(() => { @@ -163,7 +775,11 @@ await Task.Run(() => SynchronizationContext.SetSynchronizationContext(tsc); try { - ValueTask t = new ValueTask(42); + ValueTask t = + mode == CtorMode.Result ? new ValueTask() : + mode == CtorMode.Task ? new ValueTask(Task.CompletedTask) : + new ValueTask(ManualResetValueTaskSource.Completed(0, null)); + var mres = new ManualResetEventSlim(); t.GetAwaiter().OnCompleted(() => mres.Set()); Assert.True(mres.Wait(10000)); @@ -177,9 +793,84 @@ await Task.Run(() => } [Theory] - [InlineData(true)] - [InlineData(false)] - public async Task ConfiguredAwaiter_ContinuesOnCapturedContext(bool continueOnCapturedContext) + [InlineData(CtorMode.Task, false)] + [InlineData(CtorMode.ValueTaskSource, false)] + [InlineData(CtorMode.Result, true)] + [InlineData(CtorMode.Task, true)] + [InlineData(CtorMode.ValueTaskSource, true)] + public async Task Generic_Awaiter_ContinuesOnCapturedContext(CtorMode mode, bool sync) + { + await Task.Run(() => + { + var tsc = new TrackingSynchronizationContext(); + SynchronizationContext.SetSynchronizationContext(tsc); + try + { + ValueTask t = + mode == CtorMode.Result ? new ValueTask(42) : + mode == CtorMode.Task ? new ValueTask(sync ? Task.FromResult(42) : Task.Delay(1).ContinueWith(_ => 42)) : + new ValueTask(sync ? ManualResetValueTaskSource.Completed(42, null) : ManualResetValueTaskSource.Delay(1, 42, null)); + + var mres = new ManualResetEventSlim(); + t.GetAwaiter().OnCompleted(() => mres.Set()); + Assert.True(mres.Wait(10000)); + Assert.Equal(1, tsc.Posts); + } + finally + { + SynchronizationContext.SetSynchronizationContext(null); + } + }); + } + + [Theory] + [InlineData(CtorMode.Task, true, false)] + [InlineData(CtorMode.ValueTaskSource, true, false)] + [InlineData(CtorMode.Task, false, false)] + [InlineData(CtorMode.ValueTaskSource, false, false)] + [InlineData(CtorMode.Result, true, true)] + [InlineData(CtorMode.Task, true, true)] + [InlineData(CtorMode.ValueTaskSource, true, true)] + [InlineData(CtorMode.Result, false, true)] + [InlineData(CtorMode.Task, false, true)] + [InlineData(CtorMode.ValueTaskSource, false, true)] + public async Task NonGeneric_ConfiguredAwaiter_ContinuesOnCapturedContext(CtorMode mode, bool continueOnCapturedContext, bool sync) + { + await Task.Run(() => + { + var tsc = new TrackingSynchronizationContext(); + SynchronizationContext.SetSynchronizationContext(tsc); + try + { + ValueTask t = + mode == CtorMode.Result ? new ValueTask() : + mode == CtorMode.Task ? new ValueTask(sync ? Task.CompletedTask : Task.Delay(1)) : + new ValueTask(sync ? ManualResetValueTaskSource.Completed(0, null) : ManualResetValueTaskSource.Delay(42, 0, null)); + + var mres = new ManualResetEventSlim(); + t.ConfigureAwait(continueOnCapturedContext).GetAwaiter().OnCompleted(() => mres.Set()); + Assert.True(mres.Wait(10000)); + Assert.Equal(continueOnCapturedContext ? 1 : 0, tsc.Posts); + } + finally + { + SynchronizationContext.SetSynchronizationContext(null); + } + }); + } + + [Theory] + [InlineData(CtorMode.Task, true, false)] + [InlineData(CtorMode.ValueTaskSource, true, false)] + [InlineData(CtorMode.Task, false, false)] + [InlineData(CtorMode.ValueTaskSource, false, false)] + [InlineData(CtorMode.Result, true, true)] + [InlineData(CtorMode.Task, true, true)] + [InlineData(CtorMode.ValueTaskSource, true, true)] + [InlineData(CtorMode.Result, false, true)] + [InlineData(CtorMode.Task, false, true)] + [InlineData(CtorMode.ValueTaskSource, false, true)] + public async Task Generic_ConfiguredAwaiter_ContinuesOnCapturedContext(CtorMode mode, bool continueOnCapturedContext, bool sync) { await Task.Run(() => { @@ -187,7 +878,11 @@ await Task.Run(() => SynchronizationContext.SetSynchronizationContext(tsc); try { - ValueTask t = new ValueTask(42); + ValueTask t = + mode == CtorMode.Result ? new ValueTask(42) : + mode == CtorMode.Task ? new ValueTask(sync ? Task.FromResult(42) : Task.Delay(1).ContinueWith(_ => 42)) : + new ValueTask(sync ? ManualResetValueTaskSource.Completed(42, null) : ManualResetValueTaskSource.Delay(1, 42, null)); + var mres = new ManualResetEventSlim(); t.ConfigureAwait(continueOnCapturedContext).GetAwaiter().OnCompleted(() => mres.Set()); Assert.True(mres.Wait(10000)); @@ -201,60 +896,166 @@ await Task.Run(() => } [Fact] - public void GetHashCode_ContainsResult() + public void NonGeneric_GetHashCode_FromDefault_0() { - ValueTask t = new ValueTask(42); - Assert.Equal(t.Result.GetHashCode(), t.GetHashCode()); + Assert.Equal(0, new ValueTask().GetHashCode()); } [Fact] - public void GetHashCode_ContainsTask() + public void Generic_GetHashCode_FromResult_ContainsResult() + { + var vt = new ValueTask(42); + Assert.Equal(vt.Result.GetHashCode(), vt.GetHashCode()); + + var rt = new ValueTask((string)null); + Assert.Equal(0, rt.GetHashCode()); + rt = new ValueTask("12345"); + Assert.Equal(rt.Result.GetHashCode(), rt.GetHashCode()); + } + + [Theory] + [InlineData(CtorMode.Task)] + [InlineData(CtorMode.ValueTaskSource)] + public void NonGeneric_GetHashCode_FromObject_MatchesObjectHashCode(CtorMode mode) { - ValueTask t = new ValueTask(Task.FromResult("42")); - Assert.Equal(t.AsTask().GetHashCode(), t.GetHashCode()); + object obj; + ValueTask vt; + if (mode == CtorMode.Task) + { + Task t = Task.CompletedTask; + vt = new ValueTask(t); + obj = t; + } + else + { + var t = ManualResetValueTaskSource.Completed(42, null); + vt = new ValueTask(t); + obj = t; + } + + Assert.Equal(obj.GetHashCode(), vt.GetHashCode()); + } + + [Theory] + [InlineData(CtorMode.Task)] + [InlineData(CtorMode.ValueTaskSource)] + public void Generic_GetHashCode_FromObject_MatchesObjectHashCode(CtorMode mode) + { + object obj; + ValueTask vt; + if (mode == CtorMode.Task) + { + Task t = Task.FromResult(42); + vt = new ValueTask(t); + obj = t; + } + else + { + ManualResetValueTaskSource t = ManualResetValueTaskSource.Completed(42, null); + vt = new ValueTask(t); + obj = t; + } + + Assert.Equal(obj.GetHashCode(), vt.GetHashCode()); } [Fact] - public void GetHashCode_ContainsNull() + public void NonGeneric_OperatorEquals() { - ValueTask t = new ValueTask((string)null); - Assert.Equal(0, t.GetHashCode()); + var completedTcs = new TaskCompletionSource(); + completedTcs.SetResult(42); + + var completedVts = ManualResetValueTaskSource.Completed(42, null); + + Assert.True(new ValueTask() == new ValueTask()); + Assert.True(new ValueTask(Task.CompletedTask) == new ValueTask(Task.CompletedTask)); + Assert.True(new ValueTask(completedTcs.Task) == new ValueTask(completedTcs.Task)); + Assert.True(new ValueTask(completedVts) == new ValueTask(completedVts)); + + Assert.False(new ValueTask(Task.CompletedTask) == new ValueTask(completedTcs.Task)); + Assert.False(new ValueTask(Task.CompletedTask) == new ValueTask(completedVts)); + Assert.False(new ValueTask(completedTcs.Task) == new ValueTask(completedVts)); } [Fact] - public void OperatorEquals() + public void Generic_OperatorEquals() { + var completedTask = Task.FromResult(42); + var completedVts = ManualResetValueTaskSource.Completed(42, null); + Assert.True(new ValueTask(42) == new ValueTask(42)); - Assert.False(new ValueTask(42) == new ValueTask(43)); + Assert.True(new ValueTask(completedTask) == new ValueTask(completedTask)); + Assert.True(new ValueTask(completedVts) == new ValueTask(completedVts)); Assert.True(new ValueTask("42") == new ValueTask("42")); Assert.True(new ValueTask((string)null) == new ValueTask((string)null)); + Assert.False(new ValueTask(42) == new ValueTask(43)); Assert.False(new ValueTask("42") == new ValueTask((string)null)); Assert.False(new ValueTask((string)null) == new ValueTask("42")); Assert.False(new ValueTask(42) == new ValueTask(Task.FromResult(42))); Assert.False(new ValueTask(Task.FromResult(42)) == new ValueTask(42)); + Assert.False(new ValueTask(ManualResetValueTaskSource.Completed(42, null)) == new ValueTask(42)); + Assert.False(new ValueTask(completedTask) == new ValueTask(completedVts)); } [Fact] - public void OperatorNotEquals() + public void NonGeneric_OperatorNotEquals() { + var completedTcs = new TaskCompletionSource(); + completedTcs.SetResult(42); + + var completedVts = ManualResetValueTaskSource.Completed(42, null); + + Assert.False(new ValueTask() != new ValueTask()); + Assert.False(new ValueTask(Task.CompletedTask) != new ValueTask(Task.CompletedTask)); + Assert.False(new ValueTask(completedTcs.Task) != new ValueTask(completedTcs.Task)); + Assert.False(new ValueTask(completedVts) != new ValueTask(completedVts)); + + Assert.True(new ValueTask(Task.CompletedTask) != new ValueTask(completedTcs.Task)); + Assert.True(new ValueTask(Task.CompletedTask) != new ValueTask(completedVts)); + Assert.True(new ValueTask(completedTcs.Task) != new ValueTask(completedVts)); + } + + [Fact] + public void Generic_OperatorNotEquals() + { + var completedTask = Task.FromResult(42); + var completedVts = ManualResetValueTaskSource.Completed(42, null); + Assert.False(new ValueTask(42) != new ValueTask(42)); - Assert.True(new ValueTask(42) != new ValueTask(43)); + Assert.False(new ValueTask(completedTask) != new ValueTask(completedTask)); + Assert.False(new ValueTask(completedVts) != new ValueTask(completedVts)); Assert.False(new ValueTask("42") != new ValueTask("42")); Assert.False(new ValueTask((string)null) != new ValueTask((string)null)); + Assert.True(new ValueTask(42) != new ValueTask(43)); Assert.True(new ValueTask("42") != new ValueTask((string)null)); Assert.True(new ValueTask((string)null) != new ValueTask("42")); Assert.True(new ValueTask(42) != new ValueTask(Task.FromResult(42))); Assert.True(new ValueTask(Task.FromResult(42)) != new ValueTask(42)); + Assert.True(new ValueTask(ManualResetValueTaskSource.Completed(42, null)) != new ValueTask(42)); + Assert.True(new ValueTask(completedTask) != new ValueTask(completedVts)); } [Fact] - public void Equals_ValueTask() + public void NonGeneric_Equals_ValueTask() + { + Assert.True(new ValueTask().Equals(new ValueTask())); + + Assert.False(new ValueTask().Equals(new ValueTask(Task.CompletedTask))); + Assert.False(new ValueTask(Task.CompletedTask).Equals(new ValueTask())); + Assert.False(new ValueTask(ManualResetValueTaskSource.Completed(42, null)).Equals(new ValueTask())); + Assert.False(new ValueTask().Equals(new ValueTask(ManualResetValueTaskSource.Completed(42, null)))); + Assert.False(new ValueTask(Task.CompletedTask).Equals(new ValueTask(ManualResetValueTaskSource.Completed(42, null)))); + Assert.False(new ValueTask(ManualResetValueTaskSource.Completed(42, null)).Equals(new ValueTask(Task.CompletedTask))); + } + + [Fact] + public void Generic_Equals_ValueTask() { Assert.True(new ValueTask(42).Equals(new ValueTask(42))); Assert.False(new ValueTask(42).Equals(new ValueTask(43))); @@ -267,10 +1068,29 @@ public void Equals_ValueTask() Assert.False(new ValueTask(42).Equals(new ValueTask(Task.FromResult(42)))); Assert.False(new ValueTask(Task.FromResult(42)).Equals(new ValueTask(42))); + Assert.False(new ValueTask(ManualResetValueTaskSource.Completed(42, null)).Equals(new ValueTask(42))); + } + + [Fact] + public void NonGeneric_Equals_Object() + { + Assert.True(new ValueTask().Equals((object)new ValueTask())); + + Assert.False(new ValueTask().Equals((object)new ValueTask(Task.CompletedTask))); + Assert.False(new ValueTask(Task.CompletedTask).Equals((object)new ValueTask())); + Assert.False(new ValueTask(ManualResetValueTaskSource.Completed(42, null)).Equals((object)new ValueTask())); + Assert.False(new ValueTask().Equals((object)new ValueTask(ManualResetValueTaskSource.Completed(42, null)))); + Assert.False(new ValueTask(Task.CompletedTask).Equals((object)new ValueTask(ManualResetValueTaskSource.Completed(42, null)))); + Assert.False(new ValueTask(ManualResetValueTaskSource.Completed(42, null)).Equals((object)new ValueTask(Task.CompletedTask))); + + Assert.False(new ValueTask().Equals(null)); + Assert.False(new ValueTask().Equals("12345")); + Assert.False(new ValueTask(Task.CompletedTask).Equals("12345")); + Assert.False(new ValueTask(ManualResetValueTaskSource.Completed(42, null)).Equals("12345")); } [Fact] - public void Equals_Object() + public void Generic_Equals_Object() { Assert.True(new ValueTask(42).Equals((object)new ValueTask(42))); Assert.False(new ValueTask(42).Equals((object)new ValueTask(43))); @@ -283,6 +1103,7 @@ public void Equals_Object() Assert.False(new ValueTask(42).Equals((object)new ValueTask(Task.FromResult(42)))); Assert.False(new ValueTask(Task.FromResult(42)).Equals((object)new ValueTask(42))); + Assert.False(new ValueTask(ManualResetValueTaskSource.Completed(42, null)).Equals((object)new ValueTask(42))); Assert.False(new ValueTask(42).Equals((object)null)); Assert.False(new ValueTask(42).Equals(new object())); @@ -290,19 +1111,31 @@ public void Equals_Object() } [Fact] - public void ToString_Success() + public void NonGeneric_ToString_Success() + { + Assert.Equal("System.Threading.Tasks.ValueTask", new ValueTask().ToString()); + Assert.Equal("System.Threading.Tasks.ValueTask", new ValueTask(Task.CompletedTask).ToString()); + Assert.Equal("System.Threading.Tasks.ValueTask", new ValueTask(ManualResetValueTaskSource.Completed(42, null)).ToString()); + } + + [Fact] + public void Generic_ToString_Success() { Assert.Equal("Hello", new ValueTask("Hello").ToString()); Assert.Equal("Hello", new ValueTask(Task.FromResult("Hello")).ToString()); + Assert.Equal("Hello", new ValueTask(ManualResetValueTaskSource.Completed("Hello", null)).ToString()); Assert.Equal("42", new ValueTask(42).ToString()); Assert.Equal("42", new ValueTask(Task.FromResult(42)).ToString()); + Assert.Equal("42", new ValueTask(ManualResetValueTaskSource.Completed(42, null)).ToString()); Assert.Same(string.Empty, new ValueTask(string.Empty).ToString()); Assert.Same(string.Empty, new ValueTask(Task.FromResult(string.Empty)).ToString()); + Assert.Same(string.Empty, new ValueTask(ManualResetValueTaskSource.Completed(string.Empty, null)).ToString()); Assert.Same(string.Empty, new ValueTask(Task.FromException(new InvalidOperationException())).ToString()); Assert.Same(string.Empty, new ValueTask(Task.FromException(new OperationCanceledException())).ToString()); + Assert.Same(string.Empty, new ValueTask(ManualResetValueTaskSource.Completed(null, new InvalidOperationException())).ToString()); Assert.Same(string.Empty, new ValueTask(Task.FromCanceled(new CancellationToken(true))).ToString()); @@ -310,15 +1143,28 @@ public void ToString_Success() Assert.Same(string.Empty, default(ValueTask).ToString()); Assert.Same(string.Empty, new ValueTask((string)null).ToString()); Assert.Same(string.Empty, new ValueTask(Task.FromResult(null)).ToString()); + Assert.Same(string.Empty, new ValueTask(ManualResetValueTaskSource.Completed(null, null)).ToString()); Assert.Same(string.Empty, new ValueTask(new TaskCompletionSource().Task).ToString()); } + [Theory] + [InlineData(typeof(ValueTask))] + public void NonGeneric_AsyncMethodBuilderAttribute_ValueTaskAttributed(Type valueTaskType) + { + CustomAttributeData cad = valueTaskType.GetTypeInfo().CustomAttributes.Single(attr => attr.AttributeType == typeof(AsyncMethodBuilderAttribute)); + Type builderTypeCtorArg = (Type)cad.ConstructorArguments[0].Value; + Assert.Equal(typeof(AsyncValueTaskMethodBuilder), builderTypeCtorArg); + + AsyncMethodBuilderAttribute amba = valueTaskType.GetTypeInfo().GetCustomAttribute(); + Assert.Equal(builderTypeCtorArg, amba.BuilderType); + } + [Theory] [InlineData(typeof(ValueTask<>))] [InlineData(typeof(ValueTask))] [InlineData(typeof(ValueTask))] - public void AsyncMethodBuilderAttribute_ValueTaskAttributed(Type valueTaskType) + public void Generic_AsyncMethodBuilderAttribute_ValueTaskAttributed(Type valueTaskType) { CustomAttributeData cad = valueTaskType.GetTypeInfo().CustomAttributes.Single(attr => attr.AttributeType == typeof(AsyncMethodBuilderAttribute)); Type builderTypeCtorArg = (Type)cad.ConstructorArguments[0].Value; @@ -328,6 +1174,88 @@ public void AsyncMethodBuilderAttribute_ValueTaskAttributed(Type valueTaskType) Assert.Equal(builderTypeCtorArg, amba.BuilderType); } + [Fact] + public void NonGeneric_AsTask_ValueTaskSourcePassesInvalidStateToOnCompleted_Throws() + { + void Validate(IValueTaskSource vts) + { + var vt = new ValueTask(vts); + Assert.Throws(() => { vt.AsTask(); }); + } + + Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, flags) => continuation(null) }); + Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, flags) => continuation(new object()) }); + Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, flags) => { continuation(state); continuation(state); } }); + } + + [Fact] + public void Generic_AsTask_ValueTaskSourcePassesInvalidStateToOnCompleted_Throws() + { + void Validate(IValueTaskSource vts) + { + var vt = new ValueTask(vts); + Assert.Throws(() => { vt.AsTask(); }); + } + + Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, flags) => continuation(null) }); + Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, flags) => continuation(new object()) }); + Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, flags) => { continuation(state); continuation(state); } }); + } + + [Fact] + public void NonGeneric_OnCompleted_ValueTaskSourcePassesInvalidStateToOnCompleted_Throws() + { + void Validate(IValueTaskSource vts) + { + var vt = new ValueTask(vts); + Assert.Throws(() => vt.GetAwaiter().OnCompleted(() => { })); + Assert.Throws(() => vt.GetAwaiter().UnsafeOnCompleted(() => { })); + foreach (bool continueOnCapturedContext in new[] { true, false }) + { + Assert.Throws(() => vt.ConfigureAwait(false).GetAwaiter().OnCompleted(() => { })); + Assert.Throws(() => vt.ConfigureAwait(false).GetAwaiter().UnsafeOnCompleted(() => { })); + } + } + + Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, flags) => continuation(null) }); + Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, flags) => continuation(new object()) }); + } + + [Fact] + public void Generic_OnCompleted_ValueTaskSourcePassesInvalidStateToOnCompleted_Throws() + { + void Validate(IValueTaskSource vts) + { + var vt = new ValueTask(vts); + Assert.Throws(() => vt.GetAwaiter().OnCompleted(() => { })); + Assert.Throws(() => vt.GetAwaiter().UnsafeOnCompleted(() => { })); + foreach (bool continueOnCapturedContext in new[] { true, false }) + { + Assert.Throws(() => vt.ConfigureAwait(false).GetAwaiter().OnCompleted(() => { })); + Assert.Throws(() => vt.ConfigureAwait(false).GetAwaiter().UnsafeOnCompleted(() => { })); + } + } + + Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, flags) => continuation(null) }); + Validate(new DelegateValueTaskSource { OnCompletedFunc = (continuation, state, flags) => continuation(new object()) }); + } + + private sealed class DelegateValueTaskSource : IValueTaskSource, IValueTaskSource + { + public Func StatusFunc = null; + public Action GetResultAction = null; + public Func GetResultFunc = null; + public Action, object, ValueTaskSourceOnCompletedFlags> OnCompletedFunc; + + public ValueTaskSourceStatus Status => StatusFunc?.Invoke() ?? ValueTaskSourceStatus.Pending; + + public void GetResult() => GetResultAction?.Invoke(); + T IValueTaskSource.GetResult() => GetResultFunc != null ? GetResultFunc() : default; + + public void OnCompleted(Action continuation, object state, ValueTaskSourceOnCompletedFlags flags) => + OnCompletedFunc?.Invoke(continuation, state, flags); + } + private sealed class TrackingSynchronizationContext : SynchronizationContext { internal int Posts { get; set; }