Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correctly check socket on stream creation #2215

Merged
merged 4 commits into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -253,53 +253,7 @@ private void OnCheckSocketConnection(object? state)
{
CompatibilityHelpers.Assert(socketAddress != null);

try
{
SocketConnectivitySubchannelTransportLog.CheckingSocket(_logger, _subchannel.Id, socketAddress);

// Poll socket to check if it can be read from. Unfortunatly this requires reading pending data.
// The server might send data, e.g. HTTP/2 SETTINGS frame, so we need to read and cache it.
//
// Available data needs to be read now because the only way to determine whether the connection is closed is to
// get the results of polling after available data is received.
bool hasReadData;
do
{
closeSocket = IsSocketInBadState(socket, socketAddress);
var available = socket.Available;
if (available > 0)
{
hasReadData = true;
var serverDataAvailable = CalculateInitialSocketDataLength(_initialSocketData) + available;
if (serverDataAvailable > MaximumInitialSocketDataSize)
{
// Data sent to the client before a connection is started shouldn't be large.
// Put a maximum limit on the buffer size to prevent an unexpected scenario from consuming too much memory.
throw new InvalidOperationException($"The server sent {serverDataAvailable} bytes to the client before a connection was established. Maximum allowed data exceeded.");
}

SocketConnectivitySubchannelTransportLog.SocketReceivingAvailable(_logger, _subchannel.Id, socketAddress, available);

// Data is already available so this won't block.
var buffer = new byte[available];
var readCount = socket.Receive(buffer);

_initialSocketData ??= new List<ReadOnlyMemory<byte>>();
_initialSocketData.Add(buffer.AsMemory(0, readCount));
}
else
{
hasReadData = false;
}
}
while (hasReadData);
}
catch (Exception ex)
{
closeSocket = true;
checkException = ex;
SocketConnectivitySubchannelTransportLog.ErrorCheckingSocket(_logger, _subchannel.Id, socketAddress, ex);
}
closeSocket = ShouldCloseSocket(socket, socketAddress, ref _initialSocketData, out checkException);
}
}

Expand Down Expand Up @@ -383,7 +337,7 @@ public async ValueTask<Stream> GetStreamAsync(BalancerAddress address, Cancellat
SocketConnectivitySubchannelTransportLog.ClosingSocketFromIdleTimeoutOnCreateStream(_logger, _subchannel.Id, address, _socketIdleTimeout);
closeSocket = true;
}
else if (IsSocketInBadState(socket, address))
else if (ShouldCloseSocket(socket, address, ref socketData, out _))
{
SocketConnectivitySubchannelTransportLog.ClosingUnusableSocketOnCreateStream(_logger, _subchannel.Id, address);
closeSocket = true;
Expand Down Expand Up @@ -419,7 +373,75 @@ public async ValueTask<Stream> GetStreamAsync(BalancerAddress address, Cancellat
return stream;
}

private bool IsSocketInBadState(Socket socket, BalancerAddress address)
/// <summary>
/// Checks whether the socket is healthy. May read available data into the passed in buffer.
/// Returns true if the socket should be closed.
/// </summary>
private bool ShouldCloseSocket(Socket socket, BalancerAddress socketAddress, ref List<ReadOnlyMemory<byte>>? socketData, out Exception? checkException)
{
checkException = null;

try
{
SocketConnectivitySubchannelTransportLog.CheckingSocket(_logger, _subchannel.Id, socketAddress);

// Poll socket to check if it can be read from. Unfortunately this requires reading pending data.
// The server might send data, e.g. HTTP/2 SETTINGS frame, so we need to read and cache it.
//
// Available data needs to be read now because the only way to determine whether the connection is
// closed is to get the results of polling after available data is received.
// For example, the server may have sent an HTTP/2 SETTINGS or GOAWAY frame.
// We need to cache whatever we read so it isn't dropped.
do
{
if (PollSocket(socket, socketAddress))
{
// Polling socket reported an unhealthy state.
return true;
}

var available = socket.Available;
if (available > 0)
{
var serverDataAvailable = CalculateInitialSocketDataLength(socketData) + available;
if (serverDataAvailable > MaximumInitialSocketDataSize)
{
// Data sent to the client before a connection is started shouldn't be large.
// Put a maximum limit on the buffer size to prevent an unexpected scenario from consuming too much memory.
throw new InvalidOperationException($"The server sent {serverDataAvailable} bytes to the client before a connection was established. Maximum allowed data exceeded.");
}

SocketConnectivitySubchannelTransportLog.SocketReceivingAvailable(_logger, _subchannel.Id, socketAddress, available);

// Data is already available so this won't block.
var buffer = new byte[available];
var readCount = socket.Receive(buffer);

socketData ??= new List<ReadOnlyMemory<byte>>();
socketData.Add(buffer.AsMemory(0, readCount));
}
else
{
// There is no more available data to read and the socket is healthy.
return false;
}
}
while (true);
}
catch (Exception ex)
{
checkException = ex;
SocketConnectivitySubchannelTransportLog.ErrorCheckingSocket(_logger, _subchannel.Id, socketAddress, ex);
return true;
}
}

/// <summary>
/// Poll the socket to check for health and available data.
/// Shouldn't be used by itself as data needs to be consumed to accurately report the socket health.
/// <see cref="ShouldCloseSocket"/> handles consuming data and getting the socket health.
/// </summary>
private bool PollSocket(Socket socket, BalancerAddress address)
{
// From https://github.com/dotnet/runtime/blob/3195fbbd82fdb7f132d6698591ba6489ad6dd8cf/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs#L158-L168
try
Expand Down
25 changes: 18 additions & 7 deletions test/FunctionalTests/Balancer/BalancerHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,12 @@ public static EndpointContext<TRequest, TResponse> CreateGrpcEndpoint<TRequest,
HttpProtocols? protocols = null,
bool? isHttps = null,
X509Certificate2? certificate = null,
ILoggerFactory? loggerFactory = null)
ILoggerFactory? loggerFactory = null,
Action<KestrelServerOptions>? configureServer = null)
where TRequest : class, IMessage, new()
where TResponse : class, IMessage, new()
{
var server = CreateServer(port, protocols, isHttps, certificate, loggerFactory);
var server = CreateServer(port, protocols, isHttps, certificate, loggerFactory, configureServer);
var method = server.DynamicGrpc.AddUnaryMethod(callHandler, methodName);
var url = server.GetUrl(isHttps.GetValueOrDefault(false) ? TestServerEndpointName.Http2WithTls : TestServerEndpointName.Http2);

Expand Down Expand Up @@ -88,7 +89,13 @@ public void Dispose()
}
}

public static GrpcTestFixture<Startup> CreateServer(int port, HttpProtocols? protocols = null, bool? isHttps = null, X509Certificate2? certificate = null, ILoggerFactory? loggerFactory = null)
public static GrpcTestFixture<Startup> CreateServer(
int port,
HttpProtocols? protocols = null,
bool? isHttps = null,
X509Certificate2? certificate = null,
ILoggerFactory? loggerFactory = null,
Action<KestrelServerOptions>? configureServer = null)
{
var endpointName = isHttps.GetValueOrDefault(false) ? TestServerEndpointName.Http2WithTls : TestServerEndpointName.Http2;

Expand All @@ -102,6 +109,8 @@ public static GrpcTestFixture<Startup> CreateServer(int port, HttpProtocols? pro
},
(options, urls) =>
{
configureServer?.Invoke(options);

urls[endpointName] = isHttps.GetValueOrDefault(false)
? $"https://127.0.0.1:{port}"
: $"http://127.0.0.1:{port}";
Expand Down Expand Up @@ -136,13 +145,14 @@ public static Task<GrpcChannel> CreateChannel(
RetryPolicy? retryPolicy = null,
Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect = null,
TimeSpan? connectTimeout = null,
TimeSpan? connectionIdleTimeout = null)
TimeSpan? connectionIdleTimeout = null,
TimeSpan? socketPingInterval = null)
{
var resolver = new TestResolver();
var e = endpoints.Select(i => new BalancerAddress(i.Host, i.Port)).ToList();
resolver.UpdateAddresses(e);

return CreateChannel(loggerFactory, loadBalancingConfig, resolver, httpMessageHandler, connect, retryPolicy, socketConnect, connectTimeout, connectionIdleTimeout);
return CreateChannel(loggerFactory, loadBalancingConfig, resolver, httpMessageHandler, connect, retryPolicy, socketConnect, connectTimeout, connectionIdleTimeout, socketPingInterval);
}

public static async Task<GrpcChannel> CreateChannel(
Expand All @@ -154,12 +164,13 @@ public static async Task<GrpcChannel> CreateChannel(
RetryPolicy? retryPolicy = null,
Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect = null,
TimeSpan? connectTimeout = null,
TimeSpan? connectionIdleTimeout = null)
TimeSpan? connectionIdleTimeout = null,
TimeSpan? socketPingInterval = null)
{
var services = new ServiceCollection();
services.AddSingleton<ResolverFactory>(new TestResolverFactory(resolver));
services.AddSingleton<IRandomGenerator>(new TestRandomGenerator());
services.AddSingleton<ISubchannelTransportFactory>(new TestSubchannelTransportFactory(TimeSpan.FromSeconds(0.5), connectTimeout, connectionIdleTimeout ?? TimeSpan.FromMinutes(1), socketConnect));
services.AddSingleton<ISubchannelTransportFactory>(new TestSubchannelTransportFactory(socketPingInterval ?? TimeSpan.FromSeconds(0.5), connectTimeout, connectionIdleTimeout ?? TimeSpan.FromMinutes(1), socketConnect));
services.AddSingleton<LoadBalancerFactory>(new LeastUsedBalancerFactory());

var serviceConfig = new ServiceConfig();
Expand Down
51 changes: 50 additions & 1 deletion test/FunctionalTests/Balancer/ConnectionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
}

// Arrange
using var endpoint = BalancerHelpers.CreateGrpcEndpoint<HelloRequest, HelloReply>(50051, UnaryMethod, nameof(UnaryMethod));
using var endpoint = BalancerHelpers.CreateGrpcEndpoint<HelloRequest, HelloReply>(50051, UnaryMethod, nameof(UnaryMethod), loggerFactory: LoggerFactory);

var connectionIdleTimeout = TimeSpan.FromSeconds(1);
var channel = await BalancerHelpers.CreateChannel(
Expand All @@ -180,6 +180,55 @@ Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
AssertHasLog(LogLevel.Trace, "ConnectingOnCreateStream", "Subchannel id '1' doesn't have a connected socket available. Connecting new stream socket for 127.0.0.1:50051.");
}

[Test]
public async Task Active_UnaryCall_ServerCloseOnKeepAlive_SocketRecreatedOnRequest()
{
// Ignore errors
SetExpectedErrorsFilter(writeContext =>
{
return true;
});

Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
{
return Task.FromResult(new HelloReply { Message = request.Name });
}

// In this test the client connects to the server, and the server then closes it after keep-alive is triggered.
JamesNK marked this conversation as resolved.
Show resolved Hide resolved
// The client then starts a gRPC call to the server. The client should discard the closed socket and create a new one.

// Arrange
using var endpoint = BalancerHelpers.CreateGrpcEndpoint<HelloRequest, HelloReply>(
50051,
JamesNK marked this conversation as resolved.
Show resolved Hide resolved
UnaryMethod,
nameof(UnaryMethod),
loggerFactory: LoggerFactory,
configureServer: o => o.Limits.KeepAliveTimeout = TimeSpan.FromSeconds(1));

// Don't timeout the socket or ping it from the client.
var channel = await BalancerHelpers.CreateChannel(
LoggerFactory,
new RoundRobinConfig(),
new[] { endpoint.Address },
connectionIdleTimeout: TimeSpan.FromMinutes(30),
socketPingInterval: TimeSpan.FromMinutes(30)).DefaultTimeout();

Logger.LogInformation("Connecting channel.");
await channel.ConnectAsync();

// Fails when this test is run with debugging. Kestrel doesn't trigger keepalive timeout if debugging is enabled.
await TestHelpers.AssertIsTrueRetryAsync(() =>
{
return Logs.Any(l => l.LoggerName.StartsWith("Microsoft.AspNetCore.Server.Kestrel") && l.EventId.Name == "ConnectionStop");
}, "Wait for server to close connection.");

var client = TestClientFactory.Create(channel, endpoint.Method);
var response = await client.UnaryCall(new HelloRequest { Name = "Test!" }).ResponseAsync.DefaultTimeout();

// Assert
Assert.AreEqual("Test!", response.Message);
}

[Test]
public async Task Active_UnaryCall_MultipleStreams_UnavailableAddress_FallbackToWorkingAddress()
{
Expand Down