diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs index 4672f9f8be37e..b93c0970a3560 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs @@ -51,10 +51,9 @@ private sealed class State // Buffers to hold during a call to send. public MemoryHandle[] BufferArrays = new MemoryHandle[1]; - public QuicBuffer[] SendQuicBuffers = new QuicBuffer[1]; - - // Handle to pinned SendQuicBuffers. - public GCHandle SendHandle; + public IntPtr SendQuicBuffers; + public int SendBufferMaxCount; + public int SendBufferCount; // Resettable completions to be used for multiple calls to send, start, and shutdown. public readonly ResettableCompletionSource SendResettableCompletionSource = new ResettableCompletionSource(); @@ -176,14 +175,12 @@ internal override async ValueTask WriteAsync(ReadOnlyMemory using CancellationTokenRegistration registration = await HandleWriteStartState(cancellationToken).ConfigureAwait(false); await SendReadOnlyMemoryListAsync(buffers, endStream ? QUIC_SEND_FLAGS.FIN : QUIC_SEND_FLAGS.NONE).ConfigureAwait(false); - HandleWriteCompletedState(); } internal override async ValueTask WriteAsync(ReadOnlyMemory buffer, bool endStream, CancellationToken cancellationToken = default) { ThrowIfDisposed(); - using CancellationTokenRegistration registration = await HandleWriteStartState(cancellationToken).ConfigureAwait(false); await SendReadOnlyMemoryAsync(buffer, endStream ? QUIC_SEND_FLAGS.FIN : QUIC_SEND_FLAGS.NONE).ConfigureAwait(false); @@ -212,7 +209,7 @@ private async ValueTask HandleWriteStartState(Can bool shouldComplete = false; lock (state) { - if (state.SendState == SendState.None) + if (state.SendState == SendState.None || state.SendState == SendState.Pending) { state.SendState = SendState.Aborted; shouldComplete = true; @@ -240,7 +237,7 @@ private void HandleWriteCompletedState() { lock (_state) { - if (_state.SendState == SendState.Finished) + if (_state.SendState == SendState.Finished || _state.SendState == SendState.Aborted) { _state.SendState = SendState.None; } @@ -501,11 +498,11 @@ private void Dispose(bool disposing) return; } + _disposed = true; _state.Handle.Dispose(); + Marshal.FreeHGlobal(_state.SendQuicBuffers); if (_stateHandle.IsAllocated) _stateHandle.Free(); CleanupSendState(_state); - - _disposed = true; } private void EnableReceive() @@ -602,7 +599,7 @@ private static uint HandleEventPeerRecvAborted(State state, ref StreamEvent evt) bool shouldComplete = false; lock (state) { - if (state.SendState == SendState.None) + if (state.SendState == SendState.None || state.SendState == SendState.Pending) { shouldComplete = true; } @@ -761,7 +758,7 @@ private static uint HandleEventSendComplete(State state, ref StreamEvent evt) lock (state) { - if (state.SendState == SendState.None) + if (state.SendState == SendState.Pending) { state.SendState = SendState.Finished; complete = true; @@ -771,7 +768,6 @@ private static uint HandleEventSendComplete(State state, ref StreamEvent evt) if (complete) { CleanupSendState(state); - // TODO throw if a write was canceled. state.SendResettableCompletionSource.Complete(MsQuicStatusCodes.Success); } @@ -781,15 +777,15 @@ private static uint HandleEventSendComplete(State state, ref StreamEvent evt) private static void CleanupSendState(State state) { - if (state.SendHandle.IsAllocated) + lock (state) { - state.SendHandle.Free(); - } + Debug.Assert(state.SendState != SendState.Pending); + Debug.Assert(state.SendBufferCount <= state.BufferArrays.Length); - // Callings dispose twice on a memory handle should be okay - foreach (MemoryHandle buffer in state.BufferArrays) - { - buffer.Dispose(); + for (int i = 0; i < state.SendBufferCount; i++) + { + state.BufferArrays[i].Dispose(); + } } } @@ -798,6 +794,12 @@ private unsafe ValueTask SendReadOnlyMemoryAsync( ReadOnlyMemory buffer, QUIC_SEND_FLAGS flags) { + lock (_state) + { + Debug.Assert(_state.SendState != SendState.Pending); + _state.SendState = buffer.IsEmpty ? SendState.Finished : SendState.Pending; + } + if (buffer.IsEmpty) { if ((flags & QUIC_SEND_FLAGS.FIN) == QUIC_SEND_FLAGS.FIN) @@ -809,18 +811,22 @@ private unsafe ValueTask SendReadOnlyMemoryAsync( } MemoryHandle handle = buffer.Pin(); - _state.SendQuicBuffers[0].Length = (uint)buffer.Length; - _state.SendQuicBuffers[0].Buffer = (byte*)handle.Pointer; - - _state.BufferArrays[0] = handle; + if (_state.SendQuicBuffers == IntPtr.Zero) + { + _state.SendQuicBuffers = Marshal.AllocHGlobal(sizeof(QuicBuffer)); + _state.SendBufferMaxCount = 1; + } - _state.SendHandle = GCHandle.Alloc(_state.SendQuicBuffers, GCHandleType.Pinned); + QuicBuffer* quicBuffers = (QuicBuffer*)_state.SendQuicBuffers; + quicBuffers->Length = (uint)buffer.Length; + quicBuffers->Buffer = (byte*)handle.Pointer; - var quicBufferPointer = (QuicBuffer*)Marshal.UnsafeAddrOfPinnedArrayElement(_state.SendQuicBuffers, 0); + _state.BufferArrays[0] = handle; + _state.SendBufferCount = 1; uint status = MsQuicApi.Api.StreamSendDelegate( _state.Handle, - quicBufferPointer, + quicBuffers, bufferCount: 1, flags, IntPtr.Zero); @@ -841,6 +847,13 @@ private unsafe ValueTask SendReadOnlySequenceAsync( ReadOnlySequence buffers, QUIC_SEND_FLAGS flags) { + + lock (_state) + { + Debug.Assert(_state.SendState != SendState.Pending); + _state.SendState = buffers.IsEmpty ? SendState.Finished : SendState.Pending; + } + if (buffers.IsEmpty) { if ((flags & QUIC_SEND_FLAGS.FIN) == QUIC_SEND_FLAGS.FIN) @@ -851,38 +864,39 @@ private unsafe ValueTask SendReadOnlySequenceAsync( return default; } - uint count = 0; + int count = 0; foreach (ReadOnlyMemory buffer in buffers) { ++count; } - if (_state.SendQuicBuffers.Length < count) + if (_state.SendBufferMaxCount < count) { - _state.SendQuicBuffers = new QuicBuffer[count]; + Marshal.FreeHGlobal(_state.SendQuicBuffers); + _state.SendQuicBuffers = IntPtr.Zero; + _state.SendQuicBuffers = Marshal.AllocHGlobal(sizeof(QuicBuffer) * count); + _state.SendBufferMaxCount = count; _state.BufferArrays = new MemoryHandle[count]; } + _state.SendBufferCount = count; count = 0; + QuicBuffer* quicBuffers = (QuicBuffer*)_state.SendQuicBuffers; foreach (ReadOnlyMemory buffer in buffers) { MemoryHandle handle = buffer.Pin(); - _state.SendQuicBuffers[count].Length = (uint)buffer.Length; - _state.SendQuicBuffers[count].Buffer = (byte*)handle.Pointer; + quicBuffers[count].Length = (uint)buffer.Length; + quicBuffers[count].Buffer = (byte*)handle.Pointer; _state.BufferArrays[count] = handle; ++count; } - _state.SendHandle = GCHandle.Alloc(_state.SendQuicBuffers, GCHandleType.Pinned); - - var quicBufferPointer = (QuicBuffer*)Marshal.UnsafeAddrOfPinnedArrayElement(_state.SendQuicBuffers, 0); - uint status = MsQuicApi.Api.StreamSendDelegate( _state.Handle, - quicBufferPointer, - count, + quicBuffers, + (uint)count, flags, IntPtr.Zero); @@ -902,6 +916,12 @@ private unsafe ValueTask SendReadOnlyMemoryListAsync( ReadOnlyMemory> buffers, QUIC_SEND_FLAGS flags) { + lock (_state) + { + Debug.Assert(_state.SendState != SendState.Pending); + _state.SendState = buffers.IsEmpty ? SendState.Finished : SendState.Pending; + } + if (buffers.IsEmpty) { if ((flags & QUIC_SEND_FLAGS.FIN) == QUIC_SEND_FLAGS.FIN) @@ -916,28 +936,31 @@ private unsafe ValueTask SendReadOnlyMemoryListAsync( uint length = (uint)array.Length; - if (_state.SendQuicBuffers.Length < length) + if (_state.SendBufferMaxCount < array.Length) { - _state.SendQuicBuffers = new QuicBuffer[length]; - _state.BufferArrays = new MemoryHandle[length]; + Marshal.FreeHGlobal(_state.SendQuicBuffers); + _state.SendQuicBuffers = IntPtr.Zero; + _state.SendQuicBuffers = Marshal.AllocHGlobal(sizeof(QuicBuffer) * array.Length); + _state.SendBufferMaxCount = array.Length; + _state.BufferArrays = new MemoryHandle[array.Length]; } + _state.SendBufferCount = array.Length; + QuicBuffer* quicBuffers = (QuicBuffer*)_state.SendQuicBuffers; for (int i = 0; i < length; i++) { ReadOnlyMemory buffer = array[i]; MemoryHandle handle = buffer.Pin(); - _state.SendQuicBuffers[i].Length = (uint)buffer.Length; - _state.SendQuicBuffers[i].Buffer = (byte*)handle.Pointer; - _state.BufferArrays[i] = handle; - } - _state.SendHandle = GCHandle.Alloc(_state.SendQuicBuffers, GCHandleType.Pinned); + quicBuffers[i].Length = (uint)buffer.Length; + quicBuffers[i].Buffer = (byte*)handle.Pointer; - var quicBufferPointer = (QuicBuffer*)Marshal.UnsafeAddrOfPinnedArrayElement(_state.SendQuicBuffers, 0); + _state.BufferArrays[i] = handle; + } uint status = MsQuicApi.Api.StreamSendDelegate( _state.Handle, - quicBufferPointer, + quicBuffers, length, flags, IntPtr.Zero); @@ -1014,6 +1037,7 @@ private enum ShutdownState private enum SendState { None, + Pending, Aborted, Finished } diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs index 3c937baeec5f1..446daf8021d8b 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs @@ -294,7 +294,6 @@ public BufferSegment Append(ReadOnlyMemory memory) } } - [ActiveIssue("https://github.com/dotnet/runtime/issues/52047")] [Fact] public async Task ByteMixingOrNativeAVE_MinimalFailingTest() { diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs index 6daf51cfd4d55..b08f93f94486e 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs @@ -173,7 +173,6 @@ public async Task GetStreamIdWithoutStartWorks() await clientConnection.CloseAsync(0); } - [ActiveIssue("https://github.com/dotnet/runtime/issues/52047")] [Fact] public async Task LargeDataSentAndReceived() { @@ -348,7 +347,6 @@ private static async Task SendAndReceiveEOFAsync(QuicStream s1, QuicStream s2) Assert.Equal(0, bytesRead); } - [ActiveIssue("https://github.com/dotnet/runtime/issues/52047")] [Theory] [MemberData(nameof(ReadWrite_Random_Success_Data))] public async Task ReadWrite_Random_Success(int readSize, int writeSize) @@ -434,7 +432,7 @@ await Task.Run(async () => byte[] buffer = new byte[100]; QuicStreamAbortedException ex = await Assert.ThrowsAsync(() => serverStream.ReadAsync(buffer).AsTask()); Assert.Equal(ExpectedErrorCode, ex.ErrorCode); - }).WaitAsync(TimeSpan.FromSeconds(5)); + }).WaitAsync(TimeSpan.FromSeconds(15)); } [ActiveIssue("https://github.com/dotnet/runtime/issues/32050")]