diff --git a/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs b/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs index 778919eb4863d..3834d92be94c5 100644 --- a/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs +++ b/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs @@ -286,7 +286,7 @@ public async Task SendResponseBodyAsync(byte[] content, bool isFinal = true) if (isFinal) { - _stream.CompleteWrites(); + await _stream.CompleteWritesAsync().ConfigureAwait(false); } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs index c411556644565..a63df8259235c 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs @@ -445,7 +445,7 @@ private async Task SendContentAsync(HttpContent content, CancellationToken cance } else { - _stream.CompleteWrites(); + await _stream.CompleteWritesAsync().ConfigureAwait(false); } if (HttpTelemetry.Log.IsEnabled()) HttpTelemetry.Log.RequestContentStop(bytesWritten); diff --git a/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs b/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs index 847dc74b56a95..1ae4783da95c8 100644 --- a/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs +++ b/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs @@ -126,7 +126,9 @@ internal QuicStream() { } public void Abort(System.Net.Quic.QuicAbortDirection abortDirection, long errorCode) { } public override System.IAsyncResult BeginRead(byte[] buffer, int offset, int count, System.AsyncCallback? callback, object? state) { throw null; } public override System.IAsyncResult BeginWrite(byte[] buffer, int offset, int count, System.AsyncCallback? callback, object? state) { throw null; } + [System.ObsoleteAttribute("Will be removed soon, use CompleteWritesAsync instead.")] public void CompleteWrites() { } + public System.Threading.Tasks.ValueTask CompleteWritesAsync() { throw null; } protected override void Dispose(bool disposing) { } public override System.Threading.Tasks.ValueTask DisposeAsync() { throw null; } public override int EndRead(System.IAsyncResult asyncResult) { throw null; } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs index a0713d3b8f9bb..20c97de80127f 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs @@ -35,7 +35,7 @@ namespace System.Net.Quic; /// Allows to close the writing side of the stream as a single operation with the write itself. /// /// -/// +/// /// Close the writing side of the stream. /// /// @@ -147,7 +147,7 @@ public sealed partial class QuicStream /// /// A that will get completed once writing side has been closed. - /// Which might be by closing the write side via + /// Which might be by closing the write side via /// or with completeWrites: true and getting acknowledgement from the peer for it, /// or when for is called, /// or when the peer called for . @@ -360,16 +360,13 @@ public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationTo /// The region of memory to write data from. /// The token to monitor for cancellation requests. The default value is . /// Notifies the peer about gracefully closing the write side, i.e.: sends FIN flag with the data. - public ValueTask WriteAsync(ReadOnlyMemory buffer, bool completeWrites, CancellationToken cancellationToken = default) + public async ValueTask WriteAsync(ReadOnlyMemory buffer, bool completeWrites, CancellationToken cancellationToken = default) { - if (_disposed == 1) - { - return ValueTask.FromException(ExceptionDispatchInfo.SetCurrentStackTrace(new ObjectDisposedException(nameof(QuicStream)))); - } + ObjectDisposedException.ThrowIf(_disposed == 1, this); if (!_canWrite) { - return ValueTask.FromException(ExceptionDispatchInfo.SetCurrentStackTrace(new InvalidOperationException(SR.net_quic_writing_notallowed))); + throw new InvalidOperationException(SR.net_quic_writing_notallowed); } if (NetEventSource.Log.IsEnabled()) @@ -377,34 +374,36 @@ public ValueTask WriteAsync(ReadOnlyMemory buffer, bool completeWrites, Ca NetEventSource.Info(this, $"{this} Stream writing memory of '{buffer.Length}' bytes while {(completeWrites ? "completing" : "not completing")} writes."); } - if (_sendTcs.IsCompleted && cancellationToken.IsCancellationRequested) + if (_sendTcs.IsCompleted) { // Special case exception type for pre-canceled token while we've already transitioned to a final state and don't need to abort write. // It must happen before we try to get the value task, since the task source is versioned and each instance must be awaited. - return ValueTask.FromCanceled(cancellationToken); + cancellationToken.ThrowIfCancellationRequested(); } // Concurrent call, this one lost the race. if (!_sendTcs.TryGetValueTask(out ValueTask valueTask, this, cancellationToken)) { - return ValueTask.FromException(ExceptionDispatchInfo.SetCurrentStackTrace(new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, "write")))); + throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, "write")); } // No need to call anything since we already have a result, most likely an exception. if (valueTask.IsCompleted) { - return valueTask; + await valueTask.ConfigureAwait(false); + return; } // For an empty buffer complete immediately, close the writing side of the stream if necessary. if (buffer.IsEmpty) { _sendTcs.TrySetResult(); + await valueTask.ConfigureAwait(false); if (completeWrites) { - CompleteWrites(); + await CompleteWritesAsync().ConfigureAwait(false); } - return valueTask; + return; } // We own the lock, abort might happen, but exception will get stored instead. @@ -440,7 +439,11 @@ public ValueTask WriteAsync(ReadOnlyMemory buffer, bool completeWrites, Ca } } - return valueTask; + await valueTask.ConfigureAwait(false); + if (completeWrites) + { + await _sendTcs.GetFinalTask(this).ConfigureAwait(false); + } } /// @@ -511,6 +514,7 @@ public void Abort(QuicAbortDirection abortDirection, long errorCode) /// /// Corresponds to an empty STREAM frame with FIN flag set to true. /// + [Obsolete("Will be removed soon, use CompleteWritesAsync instead.")] public void CompleteWrites() { ObjectDisposedException.ThrowIf(_disposed == 1, this); @@ -535,6 +539,41 @@ public void CompleteWrites() } } + /// + /// Gracefully completes the writing side of the stream. + /// Equivalent to using with completeWrites: true. + /// + /// + /// Corresponds to an empty STREAM frame with FIN flag set to true. + /// + public ValueTask CompleteWritesAsync() + { + ObjectDisposedException.ThrowIf(_disposed == 1, this); + + // Nothing to complete, the writing side is already closed. + if (_sendTcs.IsCompleted) + { + return ValueTask.CompletedTask; + } + + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Info(this, $"{this} Completing writes."); + } + unsafe + { + int status = MsQuicApi.Api.StreamShutdown( + _handle, + QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, + default); + if (StatusFailed(status)) + { + return ValueTask.FromException(ExceptionDispatchInfo.SetCurrentStackTrace(ThrowHelper.GetExceptionForMsQuicStatus(status, message: "StreamShutdown failed"))); + } + } + return new ValueTask(_sendTcs.GetFinalTask(this)); + } + private unsafe int HandleEventStartComplete(ref START_COMPLETE_DATA data) { Debug.Assert(_decrementStreamCapacity is not null); @@ -709,7 +748,7 @@ private static unsafe int NativeCallback(QUIC_HANDLE* stream, void* context, QUI /// /// If the read side is not fully consumed, i.e.: is not completed and/or hasn't returned 0, /// dispose will abort the read side with provided . - /// If the write side hasn't been closed, it'll be closed gracefully as if was called. + /// If the write side hasn't been closed, it'll be closed gracefully as if was called. /// Finally, all resources associated with the stream will be released. /// /// A task that represents the asynchronous dispose operation. diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs index 41ac5e41da24d..959e9430236e4 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs @@ -1029,7 +1029,7 @@ await RunClientServer( } } - stream.CompleteWrites(); + await stream.CompleteWritesAsync(); }, async serverConnection => { @@ -1046,7 +1046,7 @@ await RunClientServer( int expectedTotalBytes = writes.SelectMany(x => x).Sum(); Assert.Equal(expectedTotalBytes, totalBytes); - stream.CompleteWrites(); + await stream.CompleteWritesAsync(); }); } @@ -1339,7 +1339,7 @@ public async Task BigWrite_SmallRead_Success(bool closeWithData) if (!closeWithData) { - serverStream.CompleteWrites(); + await serverStream.CompleteWritesAsync(); } readLength = await clientStream.ReadAsync(actual); diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs index 8f455d61f6a7c..585d3a02d5ef6 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs @@ -469,7 +469,7 @@ await RunBidirectionalClientServer( { await clientStream.WriteAsync(new byte[1]); sem.Release(); - clientStream.CompleteWrites(); + await clientStream.CompleteWritesAsync(); sem.Release(); }, async serverStream => @@ -761,7 +761,6 @@ async ValueTask ReleaseOnWritesClosedAsync() }); } - [Fact] public async Task WriteAsync_LocalAbort_Throws() { @@ -1014,7 +1013,7 @@ await RunBidirectionalClientServer( Assert.False(writesClosedTask.IsCompleted, "Server is still writing."); - serverStream.CompleteWrites(); + await serverStream.CompleteWritesAsync(); await writesClosedTask; }); @@ -1035,7 +1034,7 @@ await RunBidirectionalClientServer( await clientStream.WriteAsync(new byte[1], completeWrites: !extraCall); if (extraCall) { - clientStream.CompleteWrites(); + await clientStream.CompleteWritesAsync(); } }, async serverStream => @@ -1237,12 +1236,23 @@ async ValueTask ReleaseOnReadsClosedAsync() private const int BufferPlusPayload = 64 * 1024 + 1; private const int BigPayload = 1024 * 1024 * 1024; - public static IEnumerable PayloadSizeAndTwoBools() + public static IEnumerable BigPayloadSizeAndTwoBools() { var boolValues = new[] { true, false }; var payloadValues = !PlatformDetection.IsInHelix ? - new[] { SmallestPayload, SmallPayload, BufferPayload, BufferPlusPayload, BigPayload } : - new[] { SmallestPayload, SmallPayload, BufferPayload, BufferPlusPayload }; + new[] { BufferPlusPayload, BigPayload } : + new[] { BufferPlusPayload }; + return + from payload in payloadValues + from bool1 in boolValues + from bool2 in boolValues + select new object[] { payload, bool1, bool2 }; + } + + public static IEnumerable PayloadSizeAndTwoBools() + { + var boolValues = new[] { true, false }; + var payloadValues = new[] { SmallestPayload, SmallPayload, BufferPayload }; return from payload in payloadValues from bool1 in boolValues @@ -1265,13 +1275,10 @@ await RunClientServer( await using QuicStream stream = await connection.AcceptInboundStreamAsync(); await stream.WriteAsync(new byte[payloadSize], completeWrites: true); - // Make sure the data gets received by the peer if we expect the reading side to get buffered including FIN. - if (payloadSize <= BufferPayload) - { - await stream.WritesClosed; - } + // Make sure the data gets received by the peer as we expect the reading side to get buffered including FIN. + await stream.WritesClosed; - var _ = await stream.ReadAsync(new byte[0]); + int _ = await stream.ReadAsync(new byte[0]); serverSem.Release(); await clientSem.WaitAsync(); @@ -1299,12 +1306,10 @@ await RunClientServer( await using QuicStream stream = await connection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); await stream.WriteAsync(new byte[payloadSize], completeWrites: true); - if (payloadSize <= BufferPayload) - { - await stream.WritesClosed; - } + // Make sure the data gets received by the peer as we expect the reading side to get buffered including FIN. + await stream.WritesClosed; - var _ = await stream.ReadAsync(new byte[0]); + int _ = await stream.ReadAsync(new byte[0]); clientSem.Release(); await serverSem.WaitAsync(); @@ -1330,31 +1335,101 @@ await RunClientServer( async ValueTask CheckReadsClosed(QuicStream stream, QuicError expectedError, long expectedErrorCode) { // All data should be buffered if they fit in the internal buffer, reading should still pass. - if (payloadSize <= BufferPayload) - { - Assert.False(stream.ReadsClosed.IsCompleted); - var buffer = new byte[BufferPayload]; - var length = await ReadAll(stream, buffer); - Assert.True(stream.ReadsClosed.IsCompletedSuccessfully); - Assert.Equal(payloadSize, length); - } - else + Assert.False(stream.ReadsClosed.IsCompleted); + var buffer = new byte[BufferPayload]; + var length = await ReadAll(stream, buffer); + Assert.True(stream.ReadsClosed.IsCompletedSuccessfully); + Assert.Equal(payloadSize, length); + } + } + + [Theory] + [MemberData(nameof(BigPayloadSizeAndTwoBools))] + public async Task ReadsClosedFinishes_ConnectionClose_BigData(int payloadSize, bool closeServer, bool useDispose) + { + using SemaphoreSlim serverSem = new SemaphoreSlim(0); + using SemaphoreSlim clientSem = new SemaphoreSlim(0); + + await RunClientServer( + serverFunction: async connection => { - var ex = await AssertThrowsQuicExceptionAsync(expectedError, () => stream.ReadsClosed); - if (expectedError == QuicError.OperationAborted) + QuicError expectedError = QuicError.ConnectionAborted; + long expectedErrorCode = DefaultCloseErrorCodeClient; + + await using QuicStream stream = await connection.AcceptInboundStreamAsync(); + await stream.WriteAsync(new byte[payloadSize]); + ValueTask writesCompleted = stream.CompleteWritesAsync(); + + int _ = await stream.ReadAsync(new byte[0]); + + serverSem.Release(); + await clientSem.WaitAsync(); + + if (closeServer) { - Assert.Null(ex.ApplicationErrorCode); + expectedError = QuicError.OperationAborted; + expectedErrorCode = DefaultCloseErrorCodeServer; + if (useDispose) + { + await connection.DisposeAsync(); + } + else + { + await connection.CloseAsync(DefaultCloseErrorCodeServer); + } } - else + + await CheckReadsClosed(stream, expectedError, expectedErrorCode); + }, + clientFunction: async connection => + { + QuicError expectedError = QuicError.ConnectionAborted; + long expectedErrorCode = DefaultCloseErrorCodeServer; + + await using QuicStream stream = await connection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); + await stream.WriteAsync(new byte[payloadSize]); + ValueTask writesCompleted = stream.CompleteWritesAsync(); + + int _ = await stream.ReadAsync(new byte[0]); + + clientSem.Release(); + await serverSem.WaitAsync(); + + if (!closeServer) { - Assert.Equal(expectedErrorCode, ex.ApplicationErrorCode); + expectedError = QuicError.OperationAborted; + expectedErrorCode = DefaultCloseErrorCodeClient; + if (useDispose) + { + await connection.DisposeAsync(); + } + else + { + await connection.CloseAsync(DefaultCloseErrorCodeClient); + } } + + await CheckReadsClosed(stream, expectedError, expectedErrorCode); + } + ); + + async ValueTask CheckReadsClosed(QuicStream stream, QuicError expectedError, long expectedErrorCode) + { + var ex = await AssertThrowsQuicExceptionAsync(expectedError, () => stream.ReadsClosed); + if (expectedError == QuicError.OperationAborted) + { + Assert.Null(ex.ApplicationErrorCode); + } + else + { + Assert.Equal(expectedErrorCode, ex.ApplicationErrorCode); } } } [Theory] [MemberData(nameof(PayloadSizeAndTwoBools))] + [MemberData(nameof(BigPayloadSizeAndTwoBools))] public async Task WritesClosedFinishes_ConnectionClose(int payloadSize, bool closeServer, bool useDispose) { using SemaphoreSlim serverSem = new SemaphoreSlim(0); diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs index c3e0e4e7372ab..0f24c74882fae 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs @@ -320,8 +320,6 @@ await RunClientServer( await stream.WriteAsync(buffer); await clientFunction(stream); - - stream.CompleteWrites(); }, serverFunction: async connection => { @@ -329,8 +327,6 @@ await RunClientServer( Assert.Equal(1, await stream.ReadAsync(buffer)); await serverFunction(stream); - - stream.CompleteWrites(); }, iterations, millisecondsTimeout