diff --git a/Directory.Build.props b/Directory.Build.props index 0fb0ae78c..0434644c6 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -18,7 +18,7 @@ $(WarningsNotAsErrors);CS1591 true - 8.0 + 9.0 enable diff --git a/src/Grpc.AspNetCore.Server.ClientFactory/ContextPropagationInterceptor.cs b/src/Grpc.AspNetCore.Server.ClientFactory/ContextPropagationInterceptor.cs index becc177a6..891ac27b2 100644 --- a/src/Grpc.AspNetCore.Server.ClientFactory/ContextPropagationInterceptor.cs +++ b/src/Grpc.AspNetCore.Server.ClientFactory/ContextPropagationInterceptor.cs @@ -19,6 +19,7 @@ using System; using System.Diagnostics.CodeAnalysis; using System.Threading; +using System.Threading.Tasks; using Grpc.AspNetCore.Server; using Grpc.Core; using Grpc.Core.Interceptors; @@ -56,12 +57,13 @@ public override AsyncClientStreamingCall AsyncClientStreami else { return new AsyncClientStreamingCall( - call.RequestStream, - call.ResponseAsync, - call.ResponseHeadersAsync, - call.GetStatus, - call.GetTrailers, - () => { call.Dispose(); cts.Dispose(); }); + requestStream: call.RequestStream, + responseAsync: call.ResponseAsync, + responseHeadersAsync: ClientStreamingCallbacks.GetResponseHeadersAsync, + getStatusFunc: ClientStreamingCallbacks.GetStatus, + getTrailersFunc: ClientStreamingCallbacks.GetTrailers, + disposeAction: ClientStreamingCallbacks.Dispose, + CreateContextState(call, cts)); } } @@ -75,12 +77,13 @@ public override AsyncDuplexStreamingCall AsyncDuplexStreami else { return new AsyncDuplexStreamingCall( - call.RequestStream, - call.ResponseStream, - call.ResponseHeadersAsync, - call.GetStatus, - call.GetTrailers, - () => { call.Dispose(); cts.Dispose(); }); + requestStream: call.RequestStream, + responseStream: call.ResponseStream, + responseHeadersAsync: DuplexStreamingCallbacks.GetResponseHeadersAsync, + getStatusFunc: DuplexStreamingCallbacks.GetStatus, + getTrailersFunc: DuplexStreamingCallbacks.GetTrailers, + disposeAction: DuplexStreamingCallbacks.Dispose, + CreateContextState(call, cts)); } } @@ -94,11 +97,12 @@ public override AsyncServerStreamingCall AsyncServerStreamingCall( - call.ResponseStream, - call.ResponseHeadersAsync, - call.GetStatus, - call.GetTrailers, - () => { call.Dispose(); cts.Dispose(); }); + responseStream: call.ResponseStream, + responseHeadersAsync: ServerStreamingCallbacks.GetResponseHeadersAsync, + getStatusFunc: ServerStreamingCallbacks.GetStatus, + getTrailersFunc: ServerStreamingCallbacks.GetTrailers, + disposeAction: ServerStreamingCallbacks.Dispose, + CreateContextState(call, cts)); } } @@ -112,11 +116,12 @@ public override AsyncUnaryCall AsyncUnaryCall(TR else { return new AsyncUnaryCall( - call.ResponseAsync, - call.ResponseHeadersAsync, - call.GetStatus, - call.GetTrailers, - () => { call.Dispose(); cts.Dispose(); }); + responseAsync: call.ResponseAsync, + responseHeadersAsync: UnaryCallbacks.GetResponseHeadersAsync, + getStatusFunc: UnaryCallbacks.GetStatus, + getTrailersFunc: UnaryCallbacks.GetTrailers, + disposeAction: UnaryCallbacks.Dispose, + CreateContextState(call, cts)); } } @@ -192,6 +197,27 @@ private bool TryGetServerCallContext([NotNullWhen(true)]out ServerCallContext? s return true; } + private ContextState CreateContextState(TCall call, CancellationTokenSource cancellationTokenSource) where TCall : IDisposable => + new ContextState(call, cancellationTokenSource); + + private class ContextState : IDisposable where TCall : IDisposable + { + public ContextState(TCall call, CancellationTokenSource cancellationTokenSource) + { + Call = call; + CancellationTokenSource = cancellationTokenSource; + } + + public TCall Call { get; } + public CancellationTokenSource CancellationTokenSource { get; } + + public void Dispose() + { + Call.Dispose(); + CancellationTokenSource.Dispose(); + } + } + private static class Log { private static readonly Action _propagateServerCallContextFailure = @@ -202,5 +228,44 @@ public static void PropagateServerCallContextFailure(ILogger logger, string erro _propagateServerCallContextFailure(logger, errorMessage, null); } } + + // Store static callbacks so delegates are allocated once + private static class UnaryCallbacks + where TResponse : class + { + internal static readonly Func> GetResponseHeadersAsync = state => ((ContextState>)state).Call.ResponseHeadersAsync; + internal static readonly Func GetStatus = state => ((ContextState>)state).Call.GetStatus(); + internal static readonly Func GetTrailers = state => ((ContextState>)state).Call.GetTrailers(); + internal static readonly Action Dispose = state => ((ContextState>)state).Dispose(); + } + + private static class ServerStreamingCallbacks + where TResponse : class + { + internal static readonly Func> GetResponseHeadersAsync = state => ((ContextState>)state).Call.ResponseHeadersAsync; + internal static readonly Func GetStatus = state => ((ContextState>)state).Call.GetStatus(); + internal static readonly Func GetTrailers = state => ((ContextState>)state).Call.GetTrailers(); + internal static readonly Action Dispose = state => ((ContextState>)state).Dispose(); + } + + private static class DuplexStreamingCallbacks + where TRequest : class + where TResponse : class + { + internal static readonly Func> GetResponseHeadersAsync = state => ((ContextState>)state).Call.ResponseHeadersAsync; + internal static readonly Func GetStatus = state => ((ContextState>)state).Call.GetStatus(); + internal static readonly Func GetTrailers = state => ((ContextState>)state).Call.GetTrailers(); + internal static readonly Action Dispose = state => ((ContextState>)state).Dispose(); + } + + private static class ClientStreamingCallbacks + where TRequest : class + where TResponse : class + { + internal static readonly Func> GetResponseHeadersAsync = state => ((ContextState>)state).Call.ResponseHeadersAsync; + internal static readonly Func GetStatus = state => ((ContextState>)state).Call.GetStatus(); + internal static readonly Func GetTrailers = state => ((ContextState>)state).Call.GetTrailers(); + internal static readonly Action Dispose = state => ((ContextState>)state).Dispose(); + } } } diff --git a/src/Grpc.AspNetCore.Server/Internal/DefaultGrpcServiceActivator.cs b/src/Grpc.AspNetCore.Server/Internal/DefaultGrpcServiceActivator.cs index af4dc2847..103c404cc 100644 --- a/src/Grpc.AspNetCore.Server/Internal/DefaultGrpcServiceActivator.cs +++ b/src/Grpc.AspNetCore.Server/Internal/DefaultGrpcServiceActivator.cs @@ -32,7 +32,7 @@ internal sealed class DefaultGrpcServiceActivator< #if NET5_0 internal const DynamicallyAccessedMemberTypes ServiceAccessibility = DynamicallyAccessedMemberTypes.PublicConstructors; #endif - private static readonly Lazy _objectFactory = new Lazy(() => ActivatorUtilities.CreateFactory(typeof(TGrpcService), Type.EmptyTypes)); + private static readonly Lazy _objectFactory = new Lazy(static () => ActivatorUtilities.CreateFactory(typeof(TGrpcService), Type.EmptyTypes)); public GrpcActivatorHandle Create(IServiceProvider serviceProvider) { diff --git a/src/Grpc.Net.ClientFactory/Internal/DefaultClientActivator.cs b/src/Grpc.Net.ClientFactory/Internal/DefaultClientActivator.cs index 78f51621b..17aa35b48 100644 --- a/src/Grpc.Net.ClientFactory/Internal/DefaultClientActivator.cs +++ b/src/Grpc.Net.ClientFactory/Internal/DefaultClientActivator.cs @@ -26,7 +26,7 @@ namespace Grpc.Net.ClientFactory.Internal // Should be registered as a singleton, so it that it can act as a cache for the Activator. internal class DefaultClientActivator where TClient : class { - private readonly static Func _createActivator = () => ActivatorUtilities.CreateFactory(typeof(TClient), new Type[] { typeof(CallInvoker), }); + private readonly static Func _createActivator = static () => ActivatorUtilities.CreateFactory(typeof(TClient), new Type[] { typeof(CallInvoker), }); private readonly IServiceProvider _services; private ObjectFactory? _activator; diff --git a/test/Grpc.AspNetCore.Server.ClientFactory.Tests/DefaultGrpcClientFactoryTests.cs b/test/Grpc.AspNetCore.Server.ClientFactory.Tests/DefaultGrpcClientFactoryTests.cs index e15d6ffcb..07c875861 100644 --- a/test/Grpc.AspNetCore.Server.ClientFactory.Tests/DefaultGrpcClientFactoryTests.cs +++ b/test/Grpc.AspNetCore.Server.ClientFactory.Tests/DefaultGrpcClientFactoryTests.cs @@ -18,6 +18,7 @@ using System; using System.Linq; +using System.Net; using System.Net.Http; using System.Threading; using System.Threading.Tasks; @@ -95,6 +96,88 @@ public async Task CreateClient_ServerCallContextHasValues_PropogatedDeadlineAndC Assert.AreEqual(cancellationToken, options.CancellationToken); } + [TestCase(Canceller.Context)] + [TestCase(Canceller.User)] + public async Task CreateClient_ServerCallContextAndUserCancellationToken_PropogatedDeadlineAndCancellation(Canceller canceller) + { + // Arrange + var baseAddress = new Uri("http://localhost"); + var deadline = DateTime.UtcNow.AddDays(1); + var contextCts = new CancellationTokenSource(); + var userCts = new CancellationTokenSource(); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + CallOptions options = default; + + var handler = TestHttpMessageHandler.Create(async (r, token) => + { + token.Register(() => tcs.SetCanceled()); + + await tcs.Task; + + var streamContent = await ClientTestHelpers.CreateResponseContent(new HelloReply()).DefaultTimeout(); + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + + var services = new ServiceCollection(); + services.AddOptions(); + services.AddSingleton(CreateHttpContextAccessorWithServerCallContext(deadline, contextCts.Token)); + services + .AddGrpcClient(o => + { + o.Address = baseAddress; + }) + .EnableCallContextPropagation() + .AddInterceptor(() => new CallbackInterceptor(o => options = o)) + .AddHttpMessageHandler(() => handler); + + var serviceProvider = services.BuildServiceProvider(validateScopes: true); + + var clientFactory = CreateGrpcClientFactory(serviceProvider); + var client = clientFactory.CreateClient(nameof(Greeter.GreeterClient)); + + // Act + using var call = client.SayHelloAsync(new HelloRequest(), cancellationToken: userCts.Token); + var responseTask = call.ResponseAsync; + + // Assert + Assert.AreEqual(deadline, options.Deadline); + + // CancellationToken passed to call is a linked cancellation token. + // It's created from the context and user tokens. + Assert.AreNotEqual(contextCts.Token, options.CancellationToken); + Assert.AreNotEqual(userCts.Token, options.CancellationToken); + Assert.AreNotEqual(CancellationToken.None, options.CancellationToken); + + Assert.IsFalse(responseTask.IsCompleted); + + // Either CTS should cancel call. + switch (canceller) + { + case Canceller.Context: + contextCts.Cancel(); + break; + case Canceller.User: + userCts.Cancel(); + break; + } + + var ex = await ExceptionAssert.ThrowsAsync(() => responseTask).DefaultTimeout(); + Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode); + ex = await ExceptionAssert.ThrowsAsync(() => call.ResponseHeadersAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode); + + Assert.AreEqual(StatusCode.Cancelled, call.GetStatus().StatusCode); + Assert.Throws(() => call.GetTrailers()); + } + + public enum Canceller + { + None, + Context, + User + } + [Test] public async Task CreateClient_NoHttpContext_ThrowError() {