Skip to content

Commit

Permalink
[WinHTTP] Make concurrent IO check thread safe (#111750)
Browse files Browse the repository at this point in the history
* Check for concurrent IO is now thread safe.

* Add test

* Use compare exchange, fix cancellation token registration

* Typo

* Fixed unit tests.

* Test exception narrow down

* Narrow down the exception

* Don't dispose handle outside of lock
  • Loading branch information
ManickaP authored Jan 27, 2025
1 parent 2e42b6e commit c47a53d
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -941,7 +941,14 @@ private async Task StartRequestAsync(WinHttpRequestState state)
// will have the side-effect of WinHTTP cancelling any pending I/O and accelerating its callbacks
// on the handle and thus releasing the awaiting tasks in the loop below. This helps to provide
// a more timely, cooperative, cancellation pattern.
using (state.CancellationToken.Register(s => ((WinHttpRequestState)s!).RequestHandle!.Dispose(), state))
using (state.CancellationToken.Register(static s =>
{
var state = (WinHttpRequestState)s!;
lock (state.Lock)
{
state.RequestHandle?.Dispose();
}
}, state))
{
do
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ public WinHttpTransportContext TransportContext

public RendezvousAwaitable<int> LifecycleAwaitable { get; set; } = new RendezvousAwaitable<int>();
public TaskCompletionSource<bool>? TcsInternalWriteDataToRequestStream { get; set; }
public bool AsyncReadInProgress { get; set; }
public volatile int AsyncReadInProgress;

// WinHttpResponseStream state.
public long? ExpectedBytesToRead { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public static HttpResponseMessage CreateResponseMessage(

// Create response stream and wrap it in a StreamContent object.
var responseStream = new WinHttpResponseStream(requestHandle, state, response);
state.RequestHandle = null; // ownership successfully transferred to WinHttpResponseStram.
state.RequestHandle = null; // ownership successfully transferred to WinHttpResponseStream.
Stream decompressedStream = responseStream;

if (manuallyProcessedDecompressionMethods != DecompressionMethods.None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,6 @@ public override Task CopyToAsync(Stream destination, int bufferSize, Cancellatio
// Validate arguments as would base CopyToAsync
StreamHelpers.ValidateCopyToArgs(this, destination, bufferSize);

// Check that there are no other pending read operations
if (_state.AsyncReadInProgress)
{
throw new InvalidOperationException(SR.net_http_no_concurrent_io_allowed);
}

// Early check for cancellation
if (cancellationToken.IsCancellationRequested)
{
Expand All @@ -112,11 +106,15 @@ public override Task CopyToAsync(Stream destination, int bufferSize, Cancellatio

private async Task CopyToAsyncCore(Stream destination, byte[] buffer, CancellationToken cancellationToken)
{
_state.PinReceiveBuffer(buffer);
CancellationTokenRegistration ctr = cancellationToken.Register(s => ((WinHttpResponseStream)s!).CancelPendingResponseStreamReadOperation(), this);
_state.AsyncReadInProgress = true;
// Check that there are no other pending read operations
if (Interlocked.CompareExchange(ref _state.AsyncReadInProgress, 1, 0) == 1)
{
throw new InvalidOperationException(SR.net_http_no_concurrent_io_allowed);
}
try
{
using var ctr = cancellationToken.Register(s => ((WinHttpResponseStream)s!).CancelPendingResponseStreamReadOperation(), this);
_state.PinReceiveBuffer(buffer);
// Loop until there's no more data to be read
while (true)
{
Expand Down Expand Up @@ -163,8 +161,7 @@ private async Task CopyToAsyncCore(Stream destination, byte[] buffer, Cancellati
}
finally
{
_state.AsyncReadInProgress = false;
ctr.Dispose();
_state.AsyncReadInProgress = 0;
ArrayPool<byte>.Shared.Return(buffer);
}

Expand Down Expand Up @@ -201,11 +198,6 @@ public override Task<int> ReadAsync(byte[] buffer, int offset, int count, Cancel

CheckDisposed();

if (_state.AsyncReadInProgress)
{
throw new InvalidOperationException(SR.net_http_no_concurrent_io_allowed);
}

return ReadAsyncCore(buffer, offset, count, token);
}

Expand All @@ -221,12 +213,15 @@ private async Task<int> ReadAsyncCore(byte[] buffer, int offset, int count, Canc
{
return 0;
}

_state.PinReceiveBuffer(buffer);
var ctr = token.Register(s => ((WinHttpResponseStream)s!).CancelPendingResponseStreamReadOperation(), this);
_state.AsyncReadInProgress = true;
// Check that there are no other pending read operations
if (Interlocked.CompareExchange(ref _state.AsyncReadInProgress, 1, 0) == 1)
{
throw new InvalidOperationException(SR.net_http_no_concurrent_io_allowed);
}
try
{
using var ctr = token.Register(s => ((WinHttpResponseStream)s!).CancelPendingResponseStreamReadOperation(), this);
_state.PinReceiveBuffer(buffer);
lock (_state.Lock)
{
Debug.Assert(!_requestHandle.IsInvalid);
Expand Down Expand Up @@ -262,8 +257,7 @@ private async Task<int> ReadAsyncCore(byte[] buffer, int offset, int count, Canc
}
finally
{
_state.AsyncReadInProgress = false;
ctr.Dispose();
_state.AsyncReadInProgress = 0;
}
}

Expand Down Expand Up @@ -357,7 +351,7 @@ private void CancelPendingResponseStreamReadOperation()
{
lock (_state.Lock)
{
if (_state.AsyncReadInProgress)
if (_state.AsyncReadInProgress == 1)
{
if (NetEventSource.Log.IsEnabled()) NetEventSource.Info("before dispose");
_requestHandle?.Dispose(); // null check necessary to handle race condition between stream disposal and cancellation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,29 @@ public void SendAsync_SimpleGet_Success()
}
}

[OuterLoop("Uses external server.")]
[Fact]
public async Task GetAsync_ConcurrentRead_ThrowsInvalidOperationException()
{
using var client = new HttpClient(new WinHttpHandler());
using var response = await client.GetAsync("https://httpbin.org/stream-bytes/4096", HttpCompletionOption.ResponseHeadersRead);
using var stream = await response.Content.ReadAsStreamAsync();
var tasks = new Task[1_000];
for (int i = 0; i < tasks.Length; ++i)
{
tasks[i] = Task.Run(async () =>
{
try
{
await stream.ReadAsync(new byte[5]);
}
catch (InvalidOperationException ioe) when (ioe.Message.Contains("concurrent I/O")) // Expected exception for concurrent IO
{ }
});
}
await Task.WhenAll(tasks);
}

[OuterLoop]
[Theory]
[InlineData(CookieUsePolicy.UseInternalCookieStoreOnly, "cookieName1", "cookieValue1")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ public void ReadAsync_PriorReadInProgress_ThrowsInvalidOperationException()
TestControl.WinHttpReadData.Pause();
Task t1 = stream.ReadAsync(new byte[1], 0, 1);

Assert.Throws<InvalidOperationException>(() => { Task t2 = stream.ReadAsync(new byte[1], 0, 1); });
Assert.ThrowsAsync<InvalidOperationException>(() => stream.ReadAsync(new byte[1], 0, 1));

TestControl.WinHttpReadData.Resume();
t1.Wait();
Expand Down

0 comments on commit c47a53d

Please sign in to comment.