diff --git a/src/Grpc.Net.Client/Balancer/PollingResolver.cs b/src/Grpc.Net.Client/Balancer/PollingResolver.cs index 0d4ced710..bbec7b675 100644 --- a/src/Grpc.Net.Client/Balancer/PollingResolver.cs +++ b/src/Grpc.Net.Client/Balancer/PollingResolver.cs @@ -135,14 +135,33 @@ public sealed override void Refresh() if (_resolveTask.IsCompleted) { - // Run ResolveAsync in a background task. - // This is done to prevent synchronous block inside ResolveAsync from blocking future Refresh calls. - _resolveTask = Task.Run(() => ResolveNowAsync(_cts.Token), _cts.Token); - _resolveTask.ContinueWith(static (t, state) => + // Don't capture the current ExecutionContext and its AsyncLocals onto the connect + var restoreFlow = false; + try + { + if (!ExecutionContext.IsFlowSuppressed()) + { + ExecutionContext.SuppressFlow(); + restoreFlow = true; + } + + // Run ResolveAsync in a background task. + // This is done to prevent synchronous block inside ResolveAsync from blocking future Refresh calls. + _resolveTask = Task.Run(() => ResolveNowAsync(_cts.Token), _cts.Token); + _resolveTask.ContinueWith(static (t, state) => + { + var pollingResolver = (PollingResolver)state!; + Log.ResolveTaskCompleted(pollingResolver._logger, pollingResolver.GetType()); + }, this); + } + finally { - var pollingResolver = (PollingResolver)state!; - Log.ResolveTaskCompleted(pollingResolver._logger, pollingResolver.GetType()); - }, this); + // Restore the current ExecutionContext + if (restoreFlow) + { + ExecutionContext.RestoreFlow(); + } + } } else { diff --git a/src/Grpc.Net.Client/Balancer/Subchannel.cs b/src/Grpc.Net.Client/Balancer/Subchannel.cs index 1bbc22ad3..0f163de9d 100644 --- a/src/Grpc.Net.Client/Balancer/Subchannel.cs +++ b/src/Grpc.Net.Client/Balancer/Subchannel.cs @@ -257,7 +257,7 @@ public void RequestConnection() } // Don't capture the current ExecutionContext and its AsyncLocals onto the connect - bool restoreFlow = false; + var restoreFlow = false; if (!ExecutionContext.IsFlowSuppressed()) { ExecutionContext.SuppressFlow(); diff --git a/src/Shared/NonCapturingTimer.cs b/src/Shared/NonCapturingTimer.cs index 674333969..e957c20ba 100644 --- a/src/Shared/NonCapturingTimer.cs +++ b/src/Shared/NonCapturingTimer.cs @@ -13,7 +13,7 @@ public static Timer Create(TimerCallback callback, object? state, TimeSpan dueTi ArgumentNullThrowHelper.ThrowIfNull(callback); // Don't capture the current ExecutionContext and its AsyncLocals onto the timer - bool restoreFlow = false; + var restoreFlow = false; try { if (!ExecutionContext.IsFlowSuppressed()) diff --git a/test/Grpc.Net.Client.Tests/Balancer/ResolverTests.cs b/test/Grpc.Net.Client.Tests/Balancer/ResolverTests.cs index bb6e3d1eb..03af800f3 100644 --- a/test/Grpc.Net.Client.Tests/Balancer/ResolverTests.cs +++ b/test/Grpc.Net.Client.Tests/Balancer/ResolverTests.cs @@ -111,6 +111,59 @@ protected override Task ResolveAsync(CancellationToken cancellationToken) } } + [Test] + public async Task Refresh_AsyncLocal_NotCaptured() + { + // Arrange + var services = new ServiceCollection(); + services.AddNUnitLogger(); + var loggerFactory = services.BuildServiceProvider().GetRequiredService(); + + var asyncLocal = new AsyncLocal(); + asyncLocal.Value = new object(); + + var callbackAsyncLocalValues = new List(); + + var resolver = new CallbackPollingResolver(loggerFactory, new TestBackoffPolicyFactory(TimeSpan.FromMilliseconds(100)), (listener) => + { + callbackAsyncLocalValues.Add(asyncLocal.Value); + if (callbackAsyncLocalValues.Count >= 2) + { + listener(ResolverResult.ForResult(new List())); + } + + return Task.CompletedTask; + }); + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + resolver.Start(result => tcs.TrySetResult(result)); + + // Act + resolver.Refresh(); + + // Assert + await tcs.Task.DefaultTimeout(); + + Assert.AreEqual(2, callbackAsyncLocalValues.Count); + Assert.IsNull(callbackAsyncLocalValues[0]); + Assert.IsNull(callbackAsyncLocalValues[1]); + } + + private class CallbackPollingResolver : PollingResolver + { + private readonly Func, Task> _callback; + + public CallbackPollingResolver(ILoggerFactory loggerFactory, IBackoffPolicyFactory backoffPolicyFactory, Func, Task> callback) : base(loggerFactory, backoffPolicyFactory) + { + _callback = callback; + } + + protected override Task ResolveAsync(CancellationToken cancellationToken) + { + return _callback(Listener); + } + } + [Test] public async Task Resolver_ResolveNameFromServices_Success() {