Skip to content

Commit

Permalink
Add semaphore to limit subchannel connect to prevent race conditions (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesNK authored Apr 30, 2024
1 parent 8199f66 commit 63914f2
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 27 deletions.
109 changes: 83 additions & 26 deletions src/Grpc.Net.Client/Balancer/Subchannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public sealed class Subchannel : IDisposable

internal readonly ConnectionManager _manager;
private readonly ILogger _logger;
private readonly SemaphoreSlim _connectSemaphore;

private ISubchannelTransport _transport = default!;
private ConnectContext? _connectContext;
Expand Down Expand Up @@ -89,6 +90,7 @@ internal Subchannel(ConnectionManager manager, IReadOnlyList<BalancerAddress> ad
{
Lock = new object();
_logger = manager.LoggerFactory.CreateLogger(GetType());
_connectSemaphore = new SemaphoreSlim(1);

Id = manager.GetNextId();
_addresses = addresses.ToList();
Expand Down Expand Up @@ -213,7 +215,10 @@ public void UpdateAddresses(IReadOnlyList<BalancerAddress> addresses)

if (requireReconnect)
{
CancelInProgressConnect();
lock (Lock)
{
CancelInProgressConnectUnsynchronized();
}
_transport.Disconnect();
RequestConnection();
}
Expand Down Expand Up @@ -268,43 +273,76 @@ public void RequestConnection()
}
}

private void CancelInProgressConnect()
private void CancelInProgressConnectUnsynchronized()
{
lock (Lock)
{
if (_connectContext != null && !_connectContext.Disposed)
{
SubchannelLog.CancelingConnect(_logger, Id);
Debug.Assert(Monitor.IsEntered(Lock));

// Cancel connect cancellation token.
_connectContext.CancelConnect();
_connectContext.Dispose();
}
if (_connectContext != null && !_connectContext.Disposed)
{
SubchannelLog.CancelingConnect(_logger, Id);

_delayInterruptTcs?.TrySetResult(null);
// Cancel connect cancellation token.
_connectContext.CancelConnect();
_connectContext.Dispose();
}

_delayInterruptTcs?.TrySetResult(null);
}

private ConnectContext GetConnectContext()
private ConnectContext GetConnectContextUnsynchronized()
{
lock (Lock)
{
// There shouldn't be a previous connect in progress, but cancel the CTS to ensure they're no longer running.
CancelInProgressConnect();
Debug.Assert(Monitor.IsEntered(Lock));

var connectContext = _connectContext = new ConnectContext(_transport.ConnectTimeout ?? Timeout.InfiniteTimeSpan);
return connectContext;
}
// There shouldn't be a previous connect in progress, but cancel the CTS to ensure they're no longer running.
CancelInProgressConnectUnsynchronized();

var connectContext = _connectContext = new ConnectContext(_transport.ConnectTimeout ?? Timeout.InfiniteTimeSpan);
return connectContext;
}

private async Task ConnectTransportAsync()
{
var connectContext = GetConnectContext();
ConnectContext connectContext;
Task? waitSemaporeTask = null;
lock (Lock)
{
// Don't start connecting if the subchannel has been shutdown. Transport/semaphore will be disposed if shutdown.
if (_state == ConnectivityState.Shutdown)
{
return;
}

connectContext = GetConnectContextUnsynchronized();

// Use a semaphore to limit one connection attempt at a time. This is done to prevent a race conditional where a canceled connect
// overwrites the status of a successful connect.
//
// Try to get semaphore without waiting. If semaphore is already taken then start a task to wait for it to be released.
// Start this inside a lock to make sure subchannel isn't shutdown before waiting for semaphore.
if (!_connectSemaphore.Wait(0))
{
SubchannelLog.QueuingConnect(_logger, Id);
waitSemaporeTask = _connectSemaphore.WaitAsync(connectContext.CancellationToken);
}
}

var backoffPolicy = _manager.BackoffPolicyFactory.Create();
if (waitSemaporeTask != null)
{
try
{
await waitSemaporeTask.ConfigureAwait(false);
}
catch (OperationCanceledException)
{
// Canceled while waiting for semaphore.
return;
}
}

try
{
var backoffPolicy = _manager.BackoffPolicyFactory.Create();

SubchannelLog.ConnectingTransport(_logger, Id);

for (var attempt = 0; ; attempt++)
Expand Down Expand Up @@ -384,6 +422,13 @@ private async Task ConnectTransportAsync()
// Dispose context because it might have been created with a connect timeout.
// Want to clean up the connect timeout timer.
connectContext.Dispose();

// Subchannel could have been disposed while connect is running.
// If subchannel is shutting down then don't release semaphore to avoid ObjectDisposedException.
if (_state != ConnectivityState.Shutdown)
{
_connectSemaphore.Release();
}
}
}
}
Expand Down Expand Up @@ -482,8 +527,12 @@ public void Dispose()
}
_stateChangedRegistrations.Clear();

CancelInProgressConnect();
_transport.Dispose();
lock (Lock)
{
CancelInProgressConnectUnsynchronized();
_transport.Dispose();
_connectSemaphore.Dispose();
}
}
}

Expand All @@ -505,7 +554,7 @@ internal static class SubchannelLog
LoggerMessage.Define<string, ConnectivityState>(LogLevel.Debug, new EventId(5, "ConnectionRequestedInNonIdleState"), "Subchannel id '{SubchannelId}' connection requested in non-idle state of {State}.");

private static readonly Action<ILogger, string, Exception?> _connectingTransport =
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(6, "ConnectingTransport"), "Subchannel id '{SubchannelId}' connecting to transport.");
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(6, "ConnectingTransport"), "Subchannel id '{SubchannelId}' connecting to transport.");

private static readonly Action<ILogger, string, TimeSpan, Exception?> _startingConnectBackoff =
LoggerMessage.Define<string, TimeSpan>(LogLevel.Trace, new EventId(7, "StartingConnectBackoff"), "Subchannel id '{SubchannelId}' starting connect backoff of {BackoffDuration}.");
Expand All @@ -514,7 +563,7 @@ internal static class SubchannelLog
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(8, "ConnectBackoffInterrupted"), "Subchannel id '{SubchannelId}' connect backoff interrupted.");

private static readonly Action<ILogger, string, Exception?> _connectCanceled =
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(9, "ConnectCanceled"), "Subchannel id '{SubchannelId}' connect canceled.");
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(9, "ConnectCanceled"), "Subchannel id '{SubchannelId}' connect canceled.");

private static readonly Action<ILogger, string, Exception?> _connectError =
LoggerMessage.Define<string>(LogLevel.Error, new EventId(10, "ConnectError"), "Subchannel id '{SubchannelId}' unexpected error while connecting to transport.");
Expand Down Expand Up @@ -546,6 +595,9 @@ internal static class SubchannelLog
private static readonly Action<ILogger, string, string, Exception?> _addressesUpdated =
LoggerMessage.Define<string, string>(LogLevel.Trace, new EventId(19, "AddressesUpdated"), "Subchannel id '{SubchannelId}' updated with addresses: {Addresses}");

private static readonly Action<ILogger, string, Exception?> _queuingConnect =
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(20, "QueuingConnect"), "Subchannel id '{SubchannelId}' queuing connect because a connect is already in progress.");

public static void SubchannelCreated(ILogger logger, string subchannelId, IReadOnlyList<BalancerAddress> addresses)
{
if (logger.IsEnabled(LogLevel.Debug))
Expand Down Expand Up @@ -648,5 +700,10 @@ public static void AddressesUpdated(ILogger logger, string subchannelId, IReadOn
_addressesUpdated(logger, subchannelId, addressesText, null);
}
}

public static void QueuingConnect(ILogger logger, string subchannelId)
{
_queuingConnect(logger, subchannelId, null);
}
}
#endif
5 changes: 4 additions & 1 deletion test/FunctionalTests/Balancer/ConnectionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,11 @@ async Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext conte

var client = TestClientFactory.Create(channel, endpoint.Method);

// Act
var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => client.UnaryCall(new HelloRequest()).ResponseAsync).DefaultTimeout();
Assert.AreEqual("A connection could not be established within the configured ConnectTimeout.", ex.Status.DebugException!.Message);

await ExceptionAssert.ThrowsAsync<OperationCanceledException>(() => connectTcs.Task).DefaultTimeout();
}

[Test]
Expand Down Expand Up @@ -167,7 +170,7 @@ Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
connectionIdleTimeout: connectionIdleTimeout).DefaultTimeout();

Logger.LogInformation("Connecting channel.");
await channel.ConnectAsync();
await channel.ConnectAsync().DefaultTimeout();

// Wait for timeout plus a little extra to avoid issues from imprecise timers.
await Task.Delay(connectionIdleTimeout + TimeSpan.FromMilliseconds(50));
Expand Down
39 changes: 39 additions & 0 deletions test/FunctionalTests/Balancer/PickFirstBalancerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,45 @@ private GrpcChannel CreateGrpcWebChannel(TestServerEndpointName endpointName, Ve
return channel;
}

[Test]
public async Task UnaryCall_CallAfterConnectionTimeout_Success()
{
// Ignore errors
SetExpectedErrorsFilter(writeContext =>
{
return true;
});

string? host = null;
Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
{
host = context.Host;
return Task.FromResult(new HelloReply { Message = request.Name });
}

// Arrange
using var endpoint = BalancerHelpers.CreateGrpcEndpoint<HelloRequest, HelloReply>(50051, UnaryMethod, nameof(UnaryMethod));

var connectCount = 0;
var channel = await BalancerHelpers.CreateChannel(LoggerFactory, new PickFirstConfig(), new[] { endpoint.Address }, connectTimeout: TimeSpan.FromMilliseconds(200), socketConnect:
async (socket, endpoint, cancellationToken) =>
{
if (Interlocked.Increment(ref connectCount) == 1)
{
await Task.Delay(1000, cancellationToken);
}
await socket.ConnectAsync(endpoint, cancellationToken);
}).DefaultTimeout();
var client = TestClientFactory.Create(channel, endpoint.Method);

// Assert
var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => client.UnaryCall(new HelloRequest { Name = "Balancer" }).ResponseAsync).DefaultTimeout();
Assert.AreEqual(StatusCode.Unavailable, ex.StatusCode);
Assert.IsInstanceOf(typeof(TimeoutException), ex.InnerException);

await client.UnaryCall(new HelloRequest { Name = "Balancer" }).ResponseAsync.DefaultTimeout();
}

[Test]
public async Task UnaryCall_CallAfterCancellation_Success()
{
Expand Down

0 comments on commit 63914f2

Please sign in to comment.