Skip to content

Commit

Permalink
[QUIC] Add ShutdownCompleted method and fix receive and shutdown beha…
Browse files Browse the repository at this point in the history
…vior in tests (#50930)

This PR fixed QuicStreamTests.BasicTest and QuicStreamTests.MultipleReadsAndWrites.

Contributes to #49157
  • Loading branch information
CarnaViire authored Apr 19, 2021
1 parent 8f4a12d commit 01d4dca
Show file tree
Hide file tree
Showing 14 changed files with 336 additions and 251 deletions.
1 change: 1 addition & 0 deletions src/libraries/System.Net.Quic/ref/System.Net.Quic.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ public override void Flush() { }
public override void SetLength(long value) { }
public void Shutdown() { }
public System.Threading.Tasks.ValueTask ShutdownWriteCompleted(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public System.Threading.Tasks.ValueTask ShutdownCompleted(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public override void Write(byte[] buffer, int offset, int count) { }
public override void Write(System.ReadOnlySpan<byte> buffer) { }
public System.Threading.Tasks.ValueTask WriteAsync(System.Buffers.ReadOnlySequence<byte> buffers, bool endStream, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,12 @@ internal override ValueTask CloseAsync(long errorCode, CancellationToken cancell
ConnectionState? state = _state;
if (state is not null)
{
if (state._closed)
{
return default;
}
state._closed = true;

if (_isClient)
{
state._clientErrorCode = errorCode;
Expand Down Expand Up @@ -272,6 +278,7 @@ internal sealed class ConnectionState
public Channel<MockStream.StreamState> _serverInitiatedStreamChannel;
public long _clientErrorCode;
public long _serverErrorCode;
public bool _closed;

public ConnectionState(SslApplicationProtocol applicationProtocol)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,14 @@ internal override ValueTask ShutdownWriteCompleted(CancellationToken cancellatio
return default;
}


internal override ValueTask ShutdownCompleted(CancellationToken cancellationToken = default)
{
CheckDisposed();

return default;
}

internal override void Shutdown()
{
CheckDisposed();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,6 @@ private ValueTask ShutdownAsync(
QUIC_CONNECTION_SHUTDOWN_FLAGS Flags,
long ErrorCode)
{
Debug.Assert(!_state.ShutdownTcs.Task.IsCompleted);

// Store the connection into the GCHandle'd state to prevent GC if user calls ShutdownAsync and gets rid of all references to the MsQuicConnection.
Debug.Assert(_state.Connection == null);
_state.Connection = this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Net.Quic.Implementations.MsQuic.Internal;
using System.Runtime.ExceptionServices;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -62,6 +63,12 @@ private sealed class State

// Set once writes have been shutdown.
public readonly TaskCompletionSource ShutdownWriteCompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);


public ShutdownState ShutdownState;

// Set once stream have been shutdown.
public readonly TaskCompletionSource ShutdownCompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
}

// inbound.
Expand Down Expand Up @@ -199,23 +206,25 @@ private async ValueTask<CancellationTokenRegistration> HandleWriteStartState(Can
}
}

CancellationTokenRegistration registration = cancellationToken.Register(() =>
CancellationTokenRegistration registration = cancellationToken.UnsafeRegister(static (s, token) =>
{
var state = (State)s!;
bool shouldComplete = false;
lock (_state)
lock (state)
{
if (_state.SendState == SendState.None)
if (state.SendState == SendState.None)
{
_state.SendState = SendState.Aborted;
state.SendState = SendState.Aborted;
shouldComplete = true;
}
}

if (shouldComplete)
{
_state.SendResettableCompletionSource.CompleteException(new OperationCanceledException("Write was canceled", cancellationToken));
state.SendResettableCompletionSource.CompleteException(
ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException("Write was canceled", token)));
}
});
}, _state);

// Make sure start has completed
if (!_started)
Expand Down Expand Up @@ -268,24 +277,26 @@ internal override async ValueTask<int> ReadAsync(Memory<byte> destination, Cance
}
}

using CancellationTokenRegistration registration = cancellationToken.Register(() =>
using CancellationTokenRegistration registration = cancellationToken.UnsafeRegister(static (s, token) =>
{
var state = (State)s!;
bool shouldComplete = false;
lock (_state)
lock (state)
{
if (_state.ReadState == ReadState.None)
if (state.ReadState == ReadState.None)
{
shouldComplete = true;
}

_state.ReadState = ReadState.Aborted;
state.ReadState = ReadState.Aborted;
}

if (shouldComplete)
{
_state.ReceiveResettableCompletionSource.CompleteException(new OperationCanceledException("Read was canceled", cancellationToken));
state.ReceiveResettableCompletionSource.CompleteException(
ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException("Read was canceled", token)));
}
});
}, _state);

// TODO there could potentially be a perf gain by storing the buffer from the inital read
// This reduces the amount of async calls, however it makes it so MsQuic holds onto the buffers
Expand Down Expand Up @@ -358,7 +369,8 @@ internal override void AbortWrite(long errorCode)

if (shouldComplete)
{
_state.ShutdownWriteCompletionSource.SetException(new QuicStreamAbortedException("Shutdown was aborted.", errorCode));
_state.ShutdownWriteCompletionSource.SetException(
ExceptionDispatchInfo.SetCurrentStackTrace(new QuicStreamAbortedException("Shutdown was aborted.", errorCode)));
}

StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_SEND, errorCode);
Expand All @@ -375,30 +387,61 @@ internal override async ValueTask ShutdownWriteCompleted(CancellationToken cance
ThrowIfDisposed();

// TODO do anything to stop writes?
using CancellationTokenRegistration registration = cancellationToken.Register(() =>
using CancellationTokenRegistration registration = cancellationToken.UnsafeRegister(static (s, token) =>
{
var state = (State)s!;
bool shouldComplete = false;
lock (_state)
lock (state)
{
if (_state.ShutdownWriteState == ShutdownWriteState.None)
if (state.ShutdownWriteState == ShutdownWriteState.None)
{
_state.ShutdownWriteState = ShutdownWriteState.Canceled;
state.ShutdownWriteState = ShutdownWriteState.Canceled; // TODO: should we separate states for cancelling here vs calling Abort?
shouldComplete = true;
}
}

if (shouldComplete)
{
_state.ShutdownWriteCompletionSource.SetException(new OperationCanceledException("Shutdown was canceled", cancellationToken));
state.ShutdownWriteCompletionSource.SetException(
ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException("Wait for shutdown write was canceled", token)));
}
});
}, _state);

await _state.ShutdownWriteCompletionSource.Task.ConfigureAwait(false);
}

internal override async ValueTask ShutdownCompleted(CancellationToken cancellationToken = default)
{
ThrowIfDisposed();

// TODO do anything to stop writes?
using CancellationTokenRegistration registration = cancellationToken.UnsafeRegister(static (s, token) =>
{
var state = (State)s!;
bool shouldComplete = false;
lock (state)
{
if (state.ShutdownState == ShutdownState.None)
{
state.ShutdownState = ShutdownState.Canceled;
shouldComplete = true;
}
}

if (shouldComplete)
{
state.ShutdownWriteCompletionSource.SetException(
ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException("Wait for shutdown was canceled", token)));
}
}, _state);

await _state.ShutdownCompletionSource.Task.ConfigureAwait(false);
}

internal override void Shutdown()
{
ThrowIfDisposed();
// it is ok to send shutdown several times, MsQuic will ignore it
StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0);
}

Expand Down Expand Up @@ -481,6 +524,11 @@ private static uint NativeCallbackHandler(

private static uint HandleEvent(State state, ref StreamEvent evt)
{
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(state, $"[{state.GetHashCode()}] received event {evt.Type}");
}

try
{
switch ((QUIC_STREAM_EVENT_TYPE)evt.Type)
Expand Down Expand Up @@ -564,7 +612,8 @@ private static uint HandleEventPeerRecvAborted(State state, ref StreamEvent evt)

if (shouldComplete)
{
state.SendResettableCompletionSource.CompleteException(new QuicStreamAbortedException(state.SendErrorCode));
state.SendResettableCompletionSource.CompleteException(
ExceptionDispatchInfo.SetCurrentStackTrace(new QuicStreamAbortedException(state.SendErrorCode)));
}

return MsQuicStatusCodes.Success;
Expand Down Expand Up @@ -604,7 +653,7 @@ private static uint HandleEventSendShutdownComplete(State state, ref StreamEvent

if (shouldComplete)
{
state.ShutdownWriteCompletionSource.TrySetResult();
state.ShutdownWriteCompletionSource.SetResult();
}

return MsQuicStatusCodes.Success;
Expand All @@ -614,6 +663,7 @@ private static uint HandleEventShutdownComplete(State state)
{
bool shouldReadComplete = false;
bool shouldShutdownWriteComplete = false;
bool shouldShutdownComplete = false;

lock (state)
{
Expand All @@ -632,6 +682,12 @@ private static uint HandleEventShutdownComplete(State state)
state.ShutdownWriteState = ShutdownWriteState.Finished;
shouldShutdownWriteComplete = true;
}

if (state.ShutdownState == ShutdownState.None)
{
state.ShutdownState = ShutdownState.Finished;
shouldShutdownComplete = true;
}
}

if (shouldReadComplete)
Expand All @@ -641,7 +697,12 @@ private static uint HandleEventShutdownComplete(State state)

if (shouldShutdownWriteComplete)
{
state.ShutdownWriteCompletionSource.TrySetResult();
state.ShutdownWriteCompletionSource.SetResult();
}

if (shouldShutdownComplete)
{
state.ShutdownCompletionSource.SetResult();
}

return MsQuicStatusCodes.Success;
Expand All @@ -662,7 +723,8 @@ private static uint HandleEventPeerSendAborted(State state, ref StreamEvent evt)

if (shouldComplete)
{
state.ReceiveResettableCompletionSource.CompleteException(new QuicStreamAbortedException(state.ReadErrorCode));
state.ReceiveResettableCompletionSource.CompleteException(
ExceptionDispatchInfo.SetCurrentStackTrace(new QuicStreamAbortedException(state.ReadErrorCode)));
}

return MsQuicStatusCodes.Success;
Expand Down Expand Up @@ -942,6 +1004,13 @@ private enum ShutdownWriteState
Finished
}

private enum ShutdownState
{
None,
Canceled,
Finished
}

private enum SendState
{
None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ internal abstract class QuicStreamProvider : IDisposable, IAsyncDisposable

internal abstract ValueTask ShutdownWriteCompleted(CancellationToken cancellationToken = default);

internal abstract ValueTask ShutdownCompleted(CancellationToken cancellationToken = default);

internal abstract void Shutdown();

internal abstract void Flush();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Net.Quic.Implementations;
using System.Net.Quic.Implementations.MsQuic.Internal;
using System.Net.Security;
using System.Threading;
using System.Threading.Tasks;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati

public ValueTask ShutdownWriteCompleted(CancellationToken cancellationToken = default) => _provider.ShutdownWriteCompleted(cancellationToken);

public ValueTask ShutdownCompleted(CancellationToken cancellationToken = default) => _provider.ShutdownCompleted(cancellationToken);

public void Shutdown() => _provider.Shutdown();

protected override void Dispose(bool disposing)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Collections.Generic;
using System.Net.Security;
using System.Threading;
using System.Threading.Tasks;

namespace System.Net.Quic.Tests
Expand Down Expand Up @@ -50,18 +51,27 @@ internal async Task RunClientServer(Func<QuicConnection, Task> clientFunction, F
{
using QuicListener listener = CreateQuicListener();

var serverFinished = new ManualResetEventSlim();
var clientFinished = new ManualResetEventSlim();

await new[]
{
Task.Run(async () =>
{
using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await serverFunction(serverConnection);
serverFinished.Set();
clientFinished.Wait();
await serverConnection.CloseAsync(0);
}),
Task.Run(async () =>
{
using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);
await clientConnection.ConnectAsync();
await clientFunction(clientConnection);
clientFinished.Set();
serverFinished.Wait();
await clientConnection.CloseAsync(0);
})
}.WhenAllOrAnyFailed(millisecondsTimeout);
}
Expand Down
Loading

0 comments on commit 01d4dca

Please sign in to comment.