From d10799ffb9803b9b480735ea90a1b2aa1822d87d Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Tue, 17 Dec 2024 03:38:21 +0100 Subject: [PATCH] Fix race condition when cancelling pending HTTP connection attempts --- .../SocketsHttpHandler/HttpConnectionPool.cs | 28 +++++++-- .../SocketsHttpHandlerTest.Cancellation.cs | 62 +++++++++++++++++++ 2 files changed, 85 insertions(+), 5 deletions(-) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs index 17ed807eeefec4..a4221c9f039dc3 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs @@ -492,8 +492,7 @@ private async Task AddHttp11ConnectionAsync(RequestQueue.QueueIt HttpConnection? connection = null; Exception? connectionException = null; - CancellationTokenSource cts = GetConnectTimeoutCancellationTokenSource(); - waiter.ConnectionCancellationTokenSource = cts; + CancellationTokenSource cts = GetConnectTimeoutCancellationTokenSource(waiter); try { connection = await CreateHttp11ConnectionAsync(queueItem.Request, true, cts.Token).ConfigureAwait(false); @@ -691,8 +690,7 @@ private async Task AddHttp2ConnectionAsync(RequestQueue.QueueI Exception? connectionException = null; HttpConnectionWaiter waiter = queueItem.Waiter; - CancellationTokenSource cts = GetConnectTimeoutCancellationTokenSource(); - waiter.ConnectionCancellationTokenSource = cts; + CancellationTokenSource cts = GetConnectTimeoutCancellationTokenSource(waiter); try { (Stream stream, TransportContext? transportContext, IPEndPoint? remoteEndPoint) = await ConnectAsync(queueItem.Request, true, cts.Token).ConfigureAwait(false); @@ -1520,7 +1518,27 @@ public ValueTask SendAsync(HttpRequestMessage request, bool return SendWithProxyAuthAsync(request, async, doRequestAuth, cancellationToken); } - private CancellationTokenSource GetConnectTimeoutCancellationTokenSource() => new CancellationTokenSource(Settings._connectTimeout); + private CancellationTokenSource GetConnectTimeoutCancellationTokenSource(HttpConnectionWaiter waiter) + where T : HttpConnectionBase? + { + var cts = new CancellationTokenSource(Settings._connectTimeout); + + lock (waiter) + { + waiter.ConnectionCancellationTokenSource = cts; + + // The initiating request for this connection attempt may complete concurrently at any time. + // If it completed before we've set the CTS, CancelIfNecessary would no-op. + // Check it again now that we're holding the lock and ensure we always set a timeout. + if (waiter.Task.IsCompleted) + { + CancelIfNecessary(waiter, requestCancelled: waiter.Task.IsCanceled); + waiter.ConnectionCancellationTokenSource = null; + } + } + + return cts; + } private async ValueTask<(Stream, TransportContext?, IPEndPoint?)> ConnectAsync(HttpRequestMessage request, bool async, CancellationToken cancellationToken) { diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Cancellation.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Cancellation.cs index 76d7086c37c174..c7ee4c6b781d64 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Cancellation.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Cancellation.cs @@ -393,6 +393,68 @@ public void PendingConnectionTimeout_HighValue_PendingConnectionIsNotCancelled(i }, UseVersion.ToString(), timeout.ToString()).Dispose(); } + [OuterLoop("We wait for PendingConnectionTimeout which defaults to 5 seconds.")] + [Fact] + public async Task PendingConnectionTimeout_SignalsAllConnectionAttempts() + { + if (UseVersion == HttpVersion.Version30) + { + // HTTP3 does not support ConnectCallback + return; + } + + int pendingConnectionAttempts = 0; + bool connectionAttemptTimedOut = false; + + using var handler = new SocketsHttpHandler + { + ConnectCallback = async (context, cancellation) => + { + Interlocked.Increment(ref pendingConnectionAttempts); + try + { + await Assert.ThrowsAsync(() => Task.Delay(-1, cancellation)).WaitAsync(TestHelper.PassingTestTimeout); + cancellation.ThrowIfCancellationRequested(); + throw new UnreachableException(); + } + catch (TimeoutException) + { + connectionAttemptTimedOut = true; + throw; + } + finally + { + Interlocked.Decrement(ref pendingConnectionAttempts); + } + } + }; + + using HttpClient client = CreateHttpClient(handler); + client.Timeout = TimeSpan.FromSeconds(2); + + // Many of these requests should trigger new connection attempts, and all of those should eventually be cleaned up. + await Parallel.ForAsync(0, 100, async (_, _) => + { + await Assert.ThrowsAnyAsync(() => client.GetAsync("https://dummy")); + }); + + Stopwatch stopwatch = Stopwatch.StartNew(); + + while (Volatile.Read(ref pendingConnectionAttempts) > 0) + { + Assert.False(connectionAttemptTimedOut); + + if (stopwatch.Elapsed > 2 * TestHelper.PassingTestTimeout) + { + Assert.Fail("Connection attempts took too long to get cleaned up"); + } + + await Task.Delay(100); + } + + Assert.False(connectionAttemptTimedOut); + } + private sealed class SetTcsContent : StreamContent { private readonly TaskCompletionSource _tcs;