Skip to content

Commit

Permalink
Handle concurrent reads and concurrent writes on MsQuicStream (#67329)
Browse files Browse the repository at this point in the history
* Throw on concurrent operations

* Remove redundant test

* Move _state.SendState access under a lock

* Fix data race on ReceiveResettableCompletionSource

* Remove unnecessary override

* Move task completion outside of the lock

* Fix failing tests

* Update comments

* Add missing case

* fixup! Add missing case
  • Loading branch information
rzikm authored Apr 27, 2022
1 parent f31e0c2 commit a077f27
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 41 deletions.
3 changes: 3 additions & 0 deletions src/libraries/System.Net.Quic/src/Resources/Strings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@
<data name="net_quic_empty_cipher_suite" xml:space="preserve">
<value>CipherSuitePolicy must specify at least one cipher supported by QUIC.</value>
</data>
<data name="net_io_invalidnestedcall" xml:space="preserve">
<value> This method may not be called when another {0} operation is pending.</value>
</data>
<!-- Referenced in shared IPEndPointExtensions.cs-->
<data name="net_InvalidAddressFamily" xml:space="preserve">
<value>The AddressFamily {0} is not valid for the {1} end point, use {2} instead.</value>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,11 @@ private CancellationTokenRegistration SetupWriteStartState(bool emptyBuffer, Can
throw GetConnectionAbortedException(_state);
}

if (_state.SendState == SendState.Pending || _state.SendState == SendState.Finished)
{
throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, "write"));
}

// Change the state in the same lock where we check for final states to prevent coming back from Aborted/ConnectionClosed.
Debug.Assert(_state.SendState != SendState.Pending);
_state.SendState = emptyBuffer ? SendState.Finished : SendState.Pending;
Expand Down Expand Up @@ -420,7 +425,7 @@ private void CleanupWriteFailedState()
}
}

internal override ValueTask<int> ReadAsync(Memory<byte> destination, CancellationToken cancellationToken = default)
internal override async ValueTask<int> ReadAsync(Memory<byte> destination, CancellationToken cancellationToken = default)
{
ThrowIfDisposed();

Expand All @@ -446,9 +451,9 @@ internal override ValueTask<int> ReadAsync(Memory<byte> destination, Cancellatio
abortError = _state.ReadErrorCode;

// Failure scenario: pre-canceled token. Transition: Any non-final -> Aborted
// PendingRead state indicates there is another concurrent read operation in flight
// PendingRead or PendingReadFinished state indicates there is another concurrent read operation in flight
// which is forbidden, so it is handled separately
if (initialReadState != ReadState.PendingRead && cancellationToken.IsCancellationRequested)
if (initialReadState != ReadState.PendingRead && initialReadState != ReadState.PendingReadFinished && cancellationToken.IsCancellationRequested)
{
initialReadState = ReadState.Aborted;
CleanupReadStateAndCheckPending(_state, ReadState.Aborted);
Expand All @@ -458,7 +463,7 @@ internal override ValueTask<int> ReadAsync(Memory<byte> destination, Cancellatio
// Success scenario: EOS already reached, completing synchronously. No transition (final state)
if (initialReadState == ReadState.ReadsCompleted)
{
return new ValueTask<int>(0);
return 0;
}

// Success scenario: no data available yet, will return a task to wait on. Transition None->PendingRead
Expand Down Expand Up @@ -492,8 +497,6 @@ internal override ValueTask<int> ReadAsync(Memory<byte> destination, Cancellatio
{
_state.ReceiveCancellationRegistration = default;
}

return _state.ReceiveResettableCompletionSource.GetValueTask();
}

// Success scenario: data already available, completing synchronously.
Expand All @@ -517,6 +520,23 @@ internal override ValueTask<int> ReadAsync(Memory<byte> destination, Cancellatio
}
}

if (initialReadState == ReadState.None)
{
// wait for the incoming data to finish the read.
bytesRead = await _state.ReceiveResettableCompletionSource.GetValueTask().ConfigureAwait(false);

// Reset the read state
lock (_state)
{
if (_state.ReadState == ReadState.PendingReadFinished)
{
_state.ReadState = ReadState.None;
}
}

return bytesRead;
}

// methods below need to be called outside of the lock
if (bytesRead > -1)
{
Expand All @@ -527,7 +547,7 @@ internal override ValueTask<int> ReadAsync(Memory<byte> destination, Cancellatio
EnableReceive();
}

return new ValueTask<int>(bytesRead);
return bytesRead;
}

// All success scenarios returned at this point. Failure scenarios below:
Expand All @@ -537,7 +557,8 @@ internal override ValueTask<int> ReadAsync(Memory<byte> destination, Cancellatio
switch (initialReadState)
{
case ReadState.PendingRead:
ex = new InvalidOperationException("Only one read is supported at a time.");
case ReadState.PendingReadFinished:
ex = new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, "read"));
break;
case ReadState.Aborted:
ex = preCanceled ? new OperationCanceledException(cancellationToken) :
Expand All @@ -550,7 +571,7 @@ internal override ValueTask<int> ReadAsync(Memory<byte> destination, Cancellatio
break;
}

return ValueTask.FromException<int>(ExceptionDispatchInfo.SetCurrentStackTrace(ex!));
throw ex;
}

/// <returns>The number of bytes copied.</returns>
Expand Down Expand Up @@ -925,7 +946,7 @@ private static unsafe uint NativeCallbackHandler(
return HandleEventStartComplete(state, ref *streamEvent);
// Received data on the stream
case QUIC_STREAM_EVENT_TYPE.RECEIVE:
return HandleEventRecv(state, ref *streamEvent);
return HandleEventReceive(state, ref *streamEvent);
// Send has completed.
// Contains a canceled bool to indicate if the send was canceled.
case QUIC_STREAM_EVENT_TYPE.SEND_COMPLETE:
Expand Down Expand Up @@ -964,7 +985,7 @@ private static unsafe uint NativeCallbackHandler(
}
}

private static unsafe uint HandleEventRecv(State state, ref StreamEvent evt)
private static unsafe uint HandleEventReceive(State state, ref StreamEvent evt)
{
ref StreamEventDataReceive receiveEvent = ref evt.Data.Receive;

Expand All @@ -980,8 +1001,12 @@ private static unsafe uint HandleEventRecv(State state, ref StreamEvent evt)
{
switch (state.ReadState)
{
// ReadAsync() hasn't been called yet.
case ReadState.None:
// ReadAsync() hasn't been called yet. Stash the buffer so the next ReadAsync call completes synchronously.
// A pending read has just been finished, and this is a second event in a row (before reading thread
// managed to clear the state)
case ReadState.PendingReadFinished:
// Stash the buffer so the next ReadAsync call completes synchronously.

// We are overwriting state.ReceiveQuicBuffers here even if we only partially consumed them
// and it is intended, because unconsumed data will arrive again from the point we've stopped.
Expand Down Expand Up @@ -1034,7 +1059,8 @@ private static unsafe uint HandleEventRecv(State state, ref StreamEvent evt)
state.ReceiveCancellationRegistration.Unregister();
shouldComplete = true;
state.Stream = null;
state.ReadState = ReadState.None;
state.ReadState = ReadState.PendingReadFinished;
// state.ReadState will be set to None later once the ReceiveResettableCompletionSource is awaited.

readLength = CopyMsQuicBuffersToUserBuffer(new ReadOnlySpan<QuicBuffer>(receiveEvent.Buffers, (int)receiveEvent.BufferCount), state.ReceiveUserBuffer.Span);

Expand All @@ -1049,18 +1075,18 @@ private static unsafe uint HandleEventRecv(State state, ref StreamEvent evt)
break;

default:
Debug.Assert(state.ReadState is ReadState.Aborted or ReadState.ConnectionClosed, $"Unexpected {nameof(ReadState)} '{state.ReadState}' in {nameof(HandleEventRecv)}.");
Debug.Assert(state.ReadState is ReadState.Aborted or ReadState.ConnectionClosed, $"Unexpected {nameof(ReadState)} '{state.ReadState}' in {nameof(HandleEventReceive)}.");

// There was a race between a user aborting the read stream and the callback being ran.
// This will eat any received data.
return MsQuicStatusCodes.Success;
}
}

// We're completing a pending read.
if (shouldComplete)
{
state.ReceiveResettableCompletionSource.Complete(readLength);
// _state.ReadState will be reset to None on the reading thread.
}

// Returning Success when the entire buffer hasn't been consumed will cause MsQuic to disable further receive events until EnableReceive() is called.
Expand Down Expand Up @@ -1634,10 +1660,12 @@ private static bool CleanupReadStateAndCheckPending(State state, ReadState final
// IndividualReadComplete(+FIN) --(user calls ReadAsync() & consumes only partial data)-> None
// IndividualReadComplete(+FIN) --(user calls ReadAsync() & consumes full data)-> ReadsCompleted
//
// PendingRead --(data arrives in event RECV & completes user's ReadAsync())-> None
// PendingRead --(data arrives in event RECV with FIN flag & completes user's ReadAsync() with only partial data)-> None
// PendingRead --(data arrives in event RECV & completes user's ReadAsync())-> PendingReadFinished
// PendingRead --(data arrives in event RECV with FIN flag & completes user's ReadAsync() with only partial data)-> PendingReadFinished
// PendingRead --(data arrives in event RECV with FIN flag & completes user's ReadAsync() with full data)-> ReadsCompleted
//
// PendingReadFinished --(reading thread awaits ReceiveResettableCompletionSource)-> None
//
// Any non-final state --(event PEER_SEND_SHUTDOWN or SHUTDOWN_COMPLETED with ConnectionClosed=false)-> ReadsCompleted
// Any non-final state --(event PEER_SEND_ABORT)-> Aborted
// Any non-final state --(user calls AbortRead())-> Aborted
Expand All @@ -1662,6 +1690,11 @@ private enum ReadState
/// </summary>
PendingRead,

/// <summary>
/// Read was completed from the MsQuic callback.
/// </summary>
PendingReadFinished,

// following states are terminal:

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -531,30 +531,6 @@ public async Task ReadOutstanding_ReadAborted_Throws()
}
}

[Fact]
public async Task Read_ConcurrentReads_Throws()
{
using SemaphoreSlim sem = new SemaphoreSlim(0);

await RunBidirectionalClientServer(
async clientStream =>
{
await sem.WaitAsync();
},
async serverStream =>
{
ValueTask<int> readTask = serverStream.ReadAsync(new byte[1]);
Assert.False(readTask.IsCompleted);
await Assert.ThrowsAsync<InvalidOperationException>(async () => await serverStream.ReadAsync(new byte[1]));
sem.Release();
int res = await readTask;
Assert.Equal(0, res);
});
}

[Fact]
public async Task WriteAbortedWithoutWriting_ReadThrows()
{
Expand Down

0 comments on commit a077f27

Please sign in to comment.