Skip to content

Commit

Permalink
Don't capture async locals in resolver (#2426)
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesNK authored May 2, 2024
1 parent 63914f2 commit c80f459
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 9 deletions.
33 changes: 26 additions & 7 deletions src/Grpc.Net.Client/Balancer/PollingResolver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,33 @@ public sealed override void Refresh()

if (_resolveTask.IsCompleted)
{
// Run ResolveAsync in a background task.
// This is done to prevent synchronous block inside ResolveAsync from blocking future Refresh calls.
_resolveTask = Task.Run(() => ResolveNowAsync(_cts.Token), _cts.Token);
_resolveTask.ContinueWith(static (t, state) =>
// Don't capture the current ExecutionContext and its AsyncLocals onto the connect
var restoreFlow = false;
try
{
if (!ExecutionContext.IsFlowSuppressed())
{
ExecutionContext.SuppressFlow();
restoreFlow = true;
}

// Run ResolveAsync in a background task.
// This is done to prevent synchronous block inside ResolveAsync from blocking future Refresh calls.
_resolveTask = Task.Run(() => ResolveNowAsync(_cts.Token), _cts.Token);
_resolveTask.ContinueWith(static (t, state) =>
{
var pollingResolver = (PollingResolver)state!;
Log.ResolveTaskCompleted(pollingResolver._logger, pollingResolver.GetType());
}, this);
}
finally
{
var pollingResolver = (PollingResolver)state!;
Log.ResolveTaskCompleted(pollingResolver._logger, pollingResolver.GetType());
}, this);
// Restore the current ExecutionContext
if (restoreFlow)
{
ExecutionContext.RestoreFlow();
}
}
}
else
{
Expand Down
2 changes: 1 addition & 1 deletion src/Grpc.Net.Client/Balancer/Subchannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ public void RequestConnection()
}

// Don't capture the current ExecutionContext and its AsyncLocals onto the connect
bool restoreFlow = false;
var restoreFlow = false;
if (!ExecutionContext.IsFlowSuppressed())
{
ExecutionContext.SuppressFlow();
Expand Down
2 changes: 1 addition & 1 deletion src/Shared/NonCapturingTimer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public static Timer Create(TimerCallback callback, object? state, TimeSpan dueTi
ArgumentNullThrowHelper.ThrowIfNull(callback);

// Don't capture the current ExecutionContext and its AsyncLocals onto the timer
bool restoreFlow = false;
var restoreFlow = false;
try
{
if (!ExecutionContext.IsFlowSuppressed())
Expand Down
53 changes: 53 additions & 0 deletions test/Grpc.Net.Client.Tests/Balancer/ResolverTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,59 @@ protected override Task ResolveAsync(CancellationToken cancellationToken)
}
}

[Test]
public async Task Refresh_AsyncLocal_NotCaptured()
{
// Arrange
var services = new ServiceCollection();
services.AddNUnitLogger();
var loggerFactory = services.BuildServiceProvider().GetRequiredService<ILoggerFactory>();

var asyncLocal = new AsyncLocal<object>();
asyncLocal.Value = new object();

var callbackAsyncLocalValues = new List<object>();

var resolver = new CallbackPollingResolver(loggerFactory, new TestBackoffPolicyFactory(TimeSpan.FromMilliseconds(100)), (listener) =>
{
callbackAsyncLocalValues.Add(asyncLocal.Value);
if (callbackAsyncLocalValues.Count >= 2)
{
listener(ResolverResult.ForResult(new List<BalancerAddress>()));
}

return Task.CompletedTask;
});

var tcs = new TaskCompletionSource<ResolverResult>(TaskCreationOptions.RunContinuationsAsynchronously);
resolver.Start(result => tcs.TrySetResult(result));

// Act
resolver.Refresh();

// Assert
await tcs.Task.DefaultTimeout();

Assert.AreEqual(2, callbackAsyncLocalValues.Count);
Assert.IsNull(callbackAsyncLocalValues[0]);
Assert.IsNull(callbackAsyncLocalValues[1]);
}

private class CallbackPollingResolver : PollingResolver
{
private readonly Func<Action<ResolverResult>, Task> _callback;

public CallbackPollingResolver(ILoggerFactory loggerFactory, IBackoffPolicyFactory backoffPolicyFactory, Func<Action<ResolverResult>, Task> callback) : base(loggerFactory, backoffPolicyFactory)
{
_callback = callback;
}

protected override Task ResolveAsync(CancellationToken cancellationToken)
{
return _callback(Listener);
}
}

[Test]
public async Task Resolver_ResolveNameFromServices_Success()
{
Expand Down

0 comments on commit c80f459

Please sign in to comment.