Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update context propagation to use call static delegates #1216

Merged
merged 4 commits into from
Feb 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Directory.Build.props
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
<WarningsNotAsErrors>$(WarningsNotAsErrors);CS1591</WarningsNotAsErrors>

<EmbedUntrackedSources>true</EmbedUntrackedSources>
<LangVersion>8.0</LangVersion>
<LangVersion>9.0</LangVersion>
<Nullable>enable</Nullable>
</PropertyGroup>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
using System;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using System.Threading.Tasks;
using Grpc.AspNetCore.Server;
using Grpc.Core;
using Grpc.Core.Interceptors;
Expand Down Expand Up @@ -56,12 +57,13 @@ public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreami
else
{
return new AsyncClientStreamingCall<TRequest, TResponse>(
call.RequestStream,
call.ResponseAsync,
call.ResponseHeadersAsync,
call.GetStatus,
call.GetTrailers,
() => { call.Dispose(); cts.Dispose(); });
requestStream: call.RequestStream,
responseAsync: call.ResponseAsync,
responseHeadersAsync: ClientStreamingCallbacks<TRequest, TResponse>.GetResponseHeadersAsync,
getStatusFunc: ClientStreamingCallbacks<TRequest, TResponse>.GetStatus,
getTrailersFunc: ClientStreamingCallbacks<TRequest, TResponse>.GetTrailers,
disposeAction: ClientStreamingCallbacks<TRequest, TResponse>.Dispose,
CreateContextState(call, cts));
}
}

Expand All @@ -75,12 +77,13 @@ public override AsyncDuplexStreamingCall<TRequest, TResponse> AsyncDuplexStreami
else
{
return new AsyncDuplexStreamingCall<TRequest, TResponse>(
call.RequestStream,
call.ResponseStream,
call.ResponseHeadersAsync,
call.GetStatus,
call.GetTrailers,
() => { call.Dispose(); cts.Dispose(); });
requestStream: call.RequestStream,
responseStream: call.ResponseStream,
responseHeadersAsync: DuplexStreamingCallbacks<TRequest, TResponse>.GetResponseHeadersAsync,
getStatusFunc: DuplexStreamingCallbacks<TRequest, TResponse>.GetStatus,
getTrailersFunc: DuplexStreamingCallbacks<TRequest, TResponse>.GetTrailers,
disposeAction: DuplexStreamingCallbacks<TRequest, TResponse>.Dispose,
CreateContextState(call, cts));
}
}

Expand All @@ -94,11 +97,12 @@ public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRe
else
{
return new AsyncServerStreamingCall<TResponse>(
call.ResponseStream,
call.ResponseHeadersAsync,
call.GetStatus,
call.GetTrailers,
() => { call.Dispose(); cts.Dispose(); });
responseStream: call.ResponseStream,
responseHeadersAsync: ServerStreamingCallbacks<TResponse>.GetResponseHeadersAsync,
getStatusFunc: ServerStreamingCallbacks<TResponse>.GetStatus,
getTrailersFunc: ServerStreamingCallbacks<TResponse>.GetTrailers,
disposeAction: ServerStreamingCallbacks<TResponse>.Dispose,
CreateContextState(call, cts));
}
}

Expand All @@ -112,11 +116,12 @@ public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>(TR
else
{
return new AsyncUnaryCall<TResponse>(
call.ResponseAsync,
call.ResponseHeadersAsync,
call.GetStatus,
call.GetTrailers,
() => { call.Dispose(); cts.Dispose(); });
responseAsync: call.ResponseAsync,
responseHeadersAsync: UnaryCallbacks<TResponse>.GetResponseHeadersAsync,
getStatusFunc: UnaryCallbacks<TResponse>.GetStatus,
getTrailersFunc: UnaryCallbacks<TResponse>.GetTrailers,
disposeAction: UnaryCallbacks<TResponse>.Dispose,
CreateContextState(call, cts));
}
}

Expand Down Expand Up @@ -192,6 +197,27 @@ private bool TryGetServerCallContext([NotNullWhen(true)]out ServerCallContext? s
return true;
}

private ContextState<TCall> CreateContextState<TCall>(TCall call, CancellationTokenSource cancellationTokenSource) where TCall : IDisposable =>
new ContextState<TCall>(call, cancellationTokenSource);

private class ContextState<TCall> : IDisposable where TCall : IDisposable
{
public ContextState(TCall call, CancellationTokenSource cancellationTokenSource)
{
Call = call;
CancellationTokenSource = cancellationTokenSource;
}

public TCall Call { get; }
public CancellationTokenSource CancellationTokenSource { get; }

public void Dispose()
{
Call.Dispose();
CancellationTokenSource.Dispose();
}
}

private static class Log
{
private static readonly Action<ILogger, string, Exception?> _propagateServerCallContextFailure =
Expand All @@ -202,5 +228,44 @@ public static void PropagateServerCallContextFailure(ILogger logger, string erro
_propagateServerCallContextFailure(logger, errorMessage, null);
}
}

// Store static callbacks so delegates are allocated once
private static class UnaryCallbacks<TResponse>
where TResponse : class
{
internal static readonly Func<object, Task<Metadata>> GetResponseHeadersAsync = state => ((ContextState<AsyncUnaryCall<TResponse>>)state).Call.ResponseHeadersAsync;
internal static readonly Func<object, Status> GetStatus = state => ((ContextState<AsyncUnaryCall<TResponse>>)state).Call.GetStatus();
internal static readonly Func<object, Metadata> GetTrailers = state => ((ContextState<AsyncUnaryCall<TResponse>>)state).Call.GetTrailers();
internal static readonly Action<object> Dispose = state => ((ContextState<AsyncUnaryCall<TResponse>>)state).Dispose();
}

private static class ServerStreamingCallbacks<TResponse>
where TResponse : class
{
internal static readonly Func<object, Task<Metadata>> GetResponseHeadersAsync = state => ((ContextState<AsyncServerStreamingCall<TResponse>>)state).Call.ResponseHeadersAsync;
internal static readonly Func<object, Status> GetStatus = state => ((ContextState<AsyncServerStreamingCall<TResponse>>)state).Call.GetStatus();
internal static readonly Func<object, Metadata> GetTrailers = state => ((ContextState<AsyncServerStreamingCall<TResponse>>)state).Call.GetTrailers();
internal static readonly Action<object> Dispose = state => ((ContextState<AsyncServerStreamingCall<TResponse>>)state).Dispose();
}

private static class DuplexStreamingCallbacks<TRequest, TResponse>
where TRequest : class
where TResponse : class
{
internal static readonly Func<object, Task<Metadata>> GetResponseHeadersAsync = state => ((ContextState<AsyncDuplexStreamingCall<TRequest, TResponse>>)state).Call.ResponseHeadersAsync;
internal static readonly Func<object, Status> GetStatus = state => ((ContextState<AsyncDuplexStreamingCall<TRequest, TResponse>>)state).Call.GetStatus();
internal static readonly Func<object, Metadata> GetTrailers = state => ((ContextState<AsyncDuplexStreamingCall<TRequest, TResponse>>)state).Call.GetTrailers();
internal static readonly Action<object> Dispose = state => ((ContextState<AsyncDuplexStreamingCall<TRequest, TResponse>>)state).Dispose();
}

private static class ClientStreamingCallbacks<TRequest, TResponse>
where TRequest : class
where TResponse : class
{
internal static readonly Func<object, Task<Metadata>> GetResponseHeadersAsync = state => ((ContextState<AsyncClientStreamingCall<TRequest, TResponse>>)state).Call.ResponseHeadersAsync;
internal static readonly Func<object, Status> GetStatus = state => ((ContextState<AsyncClientStreamingCall<TRequest, TResponse>>)state).Call.GetStatus();
internal static readonly Func<object, Metadata> GetTrailers = state => ((ContextState<AsyncClientStreamingCall<TRequest, TResponse>>)state).Call.GetTrailers();
internal static readonly Action<object> Dispose = state => ((ContextState<AsyncClientStreamingCall<TRequest, TResponse>>)state).Dispose();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ internal sealed class DefaultGrpcServiceActivator<
#if NET5_0
internal const DynamicallyAccessedMemberTypes ServiceAccessibility = DynamicallyAccessedMemberTypes.PublicConstructors;
#endif
private static readonly Lazy<ObjectFactory> _objectFactory = new Lazy<ObjectFactory>(() => ActivatorUtilities.CreateFactory(typeof(TGrpcService), Type.EmptyTypes));
private static readonly Lazy<ObjectFactory> _objectFactory = new Lazy<ObjectFactory>(static () => ActivatorUtilities.CreateFactory(typeof(TGrpcService), Type.EmptyTypes));

public GrpcActivatorHandle<TGrpcService> Create(IServiceProvider serviceProvider)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace Grpc.Net.ClientFactory.Internal
// Should be registered as a singleton, so it that it can act as a cache for the Activator.
internal class DefaultClientActivator<TClient> where TClient : class
{
private readonly static Func<ObjectFactory> _createActivator = () => ActivatorUtilities.CreateFactory(typeof(TClient), new Type[] { typeof(CallInvoker), });
private readonly static Func<ObjectFactory> _createActivator = static () => ActivatorUtilities.CreateFactory(typeof(TClient), new Type[] { typeof(CallInvoker), });

private readonly IServiceProvider _services;
private ObjectFactory? _activator;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

using System;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -95,6 +96,88 @@ public async Task CreateClient_ServerCallContextHasValues_PropogatedDeadlineAndC
Assert.AreEqual(cancellationToken, options.CancellationToken);
}

[TestCase(Canceller.Context)]
[TestCase(Canceller.User)]
public async Task CreateClient_ServerCallContextAndUserCancellationToken_PropogatedDeadlineAndCancellation(Canceller canceller)
{
// Arrange
var baseAddress = new Uri("http://localhost");
var deadline = DateTime.UtcNow.AddDays(1);
var contextCts = new CancellationTokenSource();
var userCts = new CancellationTokenSource();
var tcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);

CallOptions options = default;

var handler = TestHttpMessageHandler.Create(async (r, token) =>
{
token.Register(() => tcs.SetCanceled());

await tcs.Task;

var streamContent = await ClientTestHelpers.CreateResponseContent(new HelloReply()).DefaultTimeout();
return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent);
});

var services = new ServiceCollection();
services.AddOptions();
services.AddSingleton(CreateHttpContextAccessorWithServerCallContext(deadline, contextCts.Token));
services
.AddGrpcClient<Greeter.GreeterClient>(o =>
{
o.Address = baseAddress;
})
.EnableCallContextPropagation()
.AddInterceptor(() => new CallbackInterceptor(o => options = o))
.AddHttpMessageHandler(() => handler);

var serviceProvider = services.BuildServiceProvider(validateScopes: true);

var clientFactory = CreateGrpcClientFactory(serviceProvider);
var client = clientFactory.CreateClient<Greeter.GreeterClient>(nameof(Greeter.GreeterClient));

// Act
using var call = client.SayHelloAsync(new HelloRequest(), cancellationToken: userCts.Token);
var responseTask = call.ResponseAsync;

// Assert
Assert.AreEqual(deadline, options.Deadline);

// CancellationToken passed to call is a linked cancellation token.
// It's created from the context and user tokens.
Assert.AreNotEqual(contextCts.Token, options.CancellationToken);
Assert.AreNotEqual(userCts.Token, options.CancellationToken);
Assert.AreNotEqual(CancellationToken.None, options.CancellationToken);

Assert.IsFalse(responseTask.IsCompleted);

// Either CTS should cancel call.
switch (canceller)
{
case Canceller.Context:
contextCts.Cancel();
break;
case Canceller.User:
userCts.Cancel();
break;
}

var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => responseTask).DefaultTimeout();
Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode);
ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => call.ResponseHeadersAsync).DefaultTimeout();
Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode);

Assert.AreEqual(StatusCode.Cancelled, call.GetStatus().StatusCode);
Assert.Throws<InvalidOperationException>(() => call.GetTrailers());
}

public enum Canceller
{
None,
Context,
User
}

[Test]
public async Task CreateClient_NoHttpContext_ThrowError()
{
Expand Down