Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

http.sys accept loop - mitigate against break due to possible conflicting IO callbacks #54368

Merged
merged 4 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/Servers/HttpSys/src/AsyncAcceptContext.Log.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.Extensions.Logging;

namespace Microsoft.AspNetCore.Server.HttpSys;

internal partial class AsyncAcceptContext
{
private static partial class Log
{
[LoggerMessage(LoggerEventIds.AcceptSetResultFailed, LogLevel.Error, "Error attempting to set 'accept' outcome", EventName = "AcceptSetResultFailed")]
public static partial void AcceptSetResultFailed(ILogger logger, Exception exception);

// note on "critical": these represent an unexpected IO callback state that needs investigation; see https://github.com/dotnet/aspnetcore/pull/54368/

[LoggerMessage(LoggerEventIds.AcceptSetExpectationMismatch, LogLevel.Critical, "Mismatch setting callback expectation - {Value}", EventName = "AcceptSetExpectationMismatch")]
public static partial void AcceptSetExpectationMismatch(ILogger logger, int value);

[LoggerMessage(LoggerEventIds.AcceptCancelExpectationMismatch, LogLevel.Critical, "Mismatch canceling accept state - {Value}", EventName = "AcceptCancelExpectationMismatch")]
public static partial void AcceptCancelExpectationMismatch(ILogger logger, int value);

[LoggerMessage(LoggerEventIds.AcceptObserveExpectationMismatch, LogLevel.Critical, "Mismatch observing {Kind} accept callback - {Value}", EventName = "AcceptObserveExpectationMismatch")]
public static partial void AcceptObserveExpectationMismatch(ILogger logger, string kind, int value);
}
}
148 changes: 110 additions & 38 deletions src/Servers/HttpSys/src/AsyncAcceptContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,23 @@

using System.Diagnostics;
using System.Threading.Tasks.Sources;
using Microsoft.Extensions.Logging;

namespace Microsoft.AspNetCore.Server.HttpSys;

internal sealed unsafe class AsyncAcceptContext : IValueTaskSource<RequestContext>, IDisposable
internal sealed unsafe partial class AsyncAcceptContext : IValueTaskSource<RequestContext>, IDisposable
{
private static readonly IOCompletionCallback IOCallback = IOWaitCallback;
private readonly PreAllocatedOverlapped _preallocatedOverlapped;
private readonly IRequestContextFactory _requestContextFactory;
private readonly ILogger _logger;
private int _expectedCompletionCount;

private NativeOverlapped* _overlapped;

private readonly bool _logExpectationFailures = AppContext.TryGetSwitch(
"Microsoft.AspNetCore.Server.HttpSys.LogAcceptExpectationFailure", out var enabled) && enabled;

// mutable struct; do not make this readonly
private ManualResetValueTaskSourceCore<RequestContext> _mrvts = new()
{
Expand All @@ -23,11 +29,12 @@ internal sealed unsafe class AsyncAcceptContext : IValueTaskSource<RequestContex

private RequestContext? _requestContext;

internal AsyncAcceptContext(HttpSysListener server, IRequestContextFactory requestContextFactory)
internal AsyncAcceptContext(HttpSysListener server, IRequestContextFactory requestContextFactory, ILogger logger)
{
Server = server;
_requestContextFactory = requestContextFactory;
_preallocatedOverlapped = new(IOCallback, state: this, pinData: null);
_logger = logger;
}

internal HttpSysListener Server { get; }
Expand All @@ -50,15 +57,16 @@ internal ValueTask<RequestContext> AcceptAsync()
return new ValueTask<RequestContext>(this, _mrvts.Version);
}

private void IOCompleted(uint errorCode, uint numBytes)
private void IOCompleted(uint errorCode, uint numBytes, bool managed)
{
try
{
ObserveCompletion(managed); // expectation tracking
if (errorCode != ErrorCodes.ERROR_SUCCESS &&
errorCode != ErrorCodes.ERROR_MORE_DATA)
{
_mrvts.SetException(new HttpSysException((int)errorCode));
return;
// (keep all the error handling in one place)
throw new HttpSysException((int)errorCode);
}

Debug.Assert(_requestContext != null);
Expand All @@ -70,7 +78,14 @@ private void IOCompleted(uint errorCode, uint numBytes)
// we want to reuse the acceptContext object for future accepts.
_requestContext = null;

_mrvts.SetResult(requestContext);
try
{
_mrvts.SetResult(requestContext);
}
catch (Exception ex)
{
Log.AcceptSetResultFailed(_logger, ex);
}
}
else
{
Expand All @@ -83,22 +98,69 @@ private void IOCompleted(uint errorCode, uint numBytes)
if (statusCode != ErrorCodes.ERROR_SUCCESS &&
statusCode != ErrorCodes.ERROR_IO_PENDING)
{
// someother bad error, possible(?) return values are:
// some other bad error, possible(?) return values are:
// ERROR_INVALID_HANDLE, ERROR_INSUFFICIENT_BUFFER, ERROR_OPERATION_ABORTED
_mrvts.SetException(new HttpSysException((int)statusCode));
// (keep all the error handling in one place)
throw new HttpSysException((int)statusCode);
}
}
}
catch (Exception exception)
{
_mrvts.SetException(exception);
try
{
_mrvts.SetException(exception);
}
catch (Exception ex)
{
Log.AcceptSetResultFailed(_logger, ex);
}
}
}

private static unsafe void IOWaitCallback(uint errorCode, uint numBytes, NativeOverlapped* nativeOverlapped)
{
var acceptContext = (AsyncAcceptContext)ThreadPoolBoundHandle.GetNativeOverlappedState(nativeOverlapped)!;
acceptContext.IOCompleted(errorCode, numBytes);
acceptContext.IOCompleted(errorCode, numBytes, false);
}

private void SetExpectCompletion() // we anticipate a completion *might* occur
{
// note this is intentionally a "reset and check" rather than Increment, so that we don't spam
// the logs forever if a glitch occurs
var value = Interlocked.Exchange(ref _expectedCompletionCount, 1); // should have been 0
if (value != 0)
{
if (_logExpectationFailures)
{
Log.AcceptSetExpectationMismatch(_logger, value);
}
Debug.Assert(false, nameof(SetExpectCompletion)); // fail hard in debug
}
}
private void CancelExpectCompletion() // due to error-code etc, we no longer anticipate a completion
{
var value = Interlocked.Decrement(ref _expectedCompletionCount); // should have been 1, so now 0
if (value != 0)
{
if (_logExpectationFailures)
{
Log.AcceptCancelExpectationMismatch(_logger, value);
}
Debug.Assert(false, nameof(CancelExpectCompletion)); // fail hard in debug
}
}
private void ObserveCompletion(bool managed) // a completion was invoked
{
var value = Interlocked.Decrement(ref _expectedCompletionCount); // should have been 1, so now 0
if (value != 0)
{
if (_logExpectationFailures)
{
Log.AcceptObserveExpectationMismatch(_logger, managed ? "managed" : "unmanaged", value);
}
Debug.Assert(false, nameof(ObserveCompletion)); // fail hard in debug
}
}

private uint QueueBeginGetContext()
Expand All @@ -111,6 +173,7 @@ private uint QueueBeginGetContext()

retry = false;
uint bytesTransferred = 0;
SetExpectCompletion(); // track this *before*, because of timing vs IOCP (could even be effectively synchronous)
statusCode = HttpApi.HttpReceiveHttpRequest(
Server.RequestQueue.Handle,
_requestContext.RequestId,
Expand All @@ -122,35 +185,44 @@ private uint QueueBeginGetContext()
&bytesTransferred,
_overlapped);

if ((statusCode == ErrorCodes.ERROR_CONNECTION_INVALID
|| statusCode == ErrorCodes.ERROR_INVALID_PARAMETER)
&& _requestContext.RequestId != 0)
{
// ERROR_CONNECTION_INVALID:
// The client reset the connection between the time we got the MORE_DATA error and when we called HttpReceiveHttpRequest
// with the new buffer. We can clear the request id and move on to the next request.
//
// ERROR_INVALID_PARAMETER: Historical check from HttpListener.
// https://referencesource.microsoft.com/#System/net/System/Net/_ListenerAsyncResult.cs,137
// we might get this if somebody stole our RequestId,
// set RequestId to 0 and start all over again with the buffer we just allocated
// BUGBUG: how can someone steal our request ID? seems really bad and in need of fix.
_requestContext.RequestId = 0;
retry = true;
}
else if (statusCode == ErrorCodes.ERROR_MORE_DATA)
{
// the buffer was not big enough to fit the headers, we need
// to read the RequestId returned, allocate a new buffer of the required size
// (uint)backingBuffer.Length - AlignmentPadding
AllocateNativeRequest(bytesTransferred);
retry = true;
}
else if (statusCode == ErrorCodes.ERROR_SUCCESS
&& HttpSysListener.SkipIOCPCallbackOnSuccess)
switch (statusCode)
{
// IO operation completed synchronously - callback won't be called to signal completion.
IOCompleted(statusCode, bytesTransferred);
case (ErrorCodes.ERROR_CONNECTION_INVALID or ErrorCodes.ERROR_INVALID_PARAMETER) when _requestContext.RequestId != 0:
// ERROR_CONNECTION_INVALID:
// The client reset the connection between the time we got the MORE_DATA error and when we called HttpReceiveHttpRequest
// with the new buffer. We can clear the request id and move on to the next request.
//
// ERROR_INVALID_PARAMETER: Historical check from HttpListener.
// https://referencesource.microsoft.com/#System/net/System/Net/_ListenerAsyncResult.cs,137
// we might get this if somebody stole our RequestId,
// set RequestId to 0 and start all over again with the buffer we just allocated
// BUGBUG: how can someone steal our request ID? seems really bad and in need of fix.
CancelExpectCompletion();
_requestContext.RequestId = 0;
retry = true;
break;
case ErrorCodes.ERROR_MORE_DATA:
// the buffer was not big enough to fit the headers, we need
// to read the RequestId returned, allocate a new buffer of the required size
// (uint)backingBuffer.Length - AlignmentPadding
CancelExpectCompletion(); // we'll "expect" again when we retry
AllocateNativeRequest(bytesTransferred);
retry = true;
break;
case ErrorCodes.ERROR_SUCCESS:
if (HttpSysListener.SkipIOCPCallbackOnSuccess)
{
// IO operation completed synchronously - callback won't be called to signal completion.
IOCompleted(statusCode, bytesTransferred, true); // marks completion
}
// else: callback fired by IOCP (at some point), which marks completion
break;
case ErrorCodes.ERROR_IO_PENDING:
break; // no change to state - callback will occur at some point
default:
// fault code, not expecting an IOCP callback
CancelExpectCompletion();
break;
}
}
while (retry);
Expand Down
4 changes: 4 additions & 0 deletions src/Servers/HttpSys/src/LoggerEventIds.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,8 @@ internal static class LoggerEventIds
public const int RequestValidationFailed = 47;
public const int CreateDisconnectTokenError = 48;
public const int RequestAborted = 49;
public const int AcceptSetResultFailed = 50;
public const int AcceptSetExpectationMismatch = 51;
public const int AcceptCancelExpectationMismatch = 52;
public const int AcceptObserveExpectationMismatch = 53;
}
2 changes: 1 addition & 1 deletion src/Servers/HttpSys/src/MessagePump.cs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ private void ProcessRequestsWorker()
Debug.Assert(RequestContextFactory != null);

// Allocate and accept context per loop and reuse it for all accepts
var acceptContext = new AsyncAcceptContext(Listener, RequestContextFactory);
var acceptContext = new AsyncAcceptContext(Listener, RequestContextFactory, _logger);

var loop = new AcceptLoop(acceptContext, this);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ internal static HttpSysListener CreateServerOnExistingQueue(string requestQueueN
internal static async Task<RequestContext> AcceptAsync(this HttpSysListener server, TimeSpan timeout)
{
var factory = new TestRequestContextFactory(server);
using var acceptContext = new AsyncAcceptContext(server, factory);
using var acceptContext = new AsyncAcceptContext(server, factory, server.Logger);

async Task<RequestContext> AcceptAsync()
{
Expand Down
Loading