Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve extensibility of gRPC invocation #669

Merged
merged 4 commits into from
Dec 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ public class CompressedUnaryClientBenchmark : UnaryClientBenchmarkBase
public CompressedUnaryClientBenchmark()
{
ResponseCompressionAlgorithm = TestCompressionProvider.Name;
CompressionProviders = new Dictionary<string, ICompressionProvider>
CompressionProviders = new List<ICompressionProvider>
{
[TestCompressionProvider.Name] = new TestCompressionProvider()
new TestCompressionProvider()
};
_compressionMetadata = new Metadata
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace Grpc.AspNetCore.Microbenchmarks.Client
{
public class UnaryClientBenchmarkBase
{
protected Dictionary<string, ICompressionProvider>? CompressionProviders { get; set; }
protected List<ICompressionProvider>? CompressionProviders { get; set; }
protected string? ResponseCompressionAlgorithm { get; set; }

private Greeter.GreeterClient? _client;
Expand Down Expand Up @@ -69,7 +69,7 @@ public void GlobalSetup()
var channel = GrpcChannel.ForAddress("http://localhost", new GrpcChannelOptions
{
HttpClient = httpClient,
CompressionProviders = CompressionProviders?.Values?.ToList()
CompressionProviders = CompressionProviders
});

_client = new Greeter.GreeterClient(channel);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ public class CompressedUnaryServerCallHandlerBenchmark : UnaryServerCallHandlerB
public CompressedUnaryServerCallHandlerBenchmark()
{
ResponseCompressionAlgorithm = TestCompressionProvider.Name;
CompressionProviders = new Dictionary<string, ICompressionProvider>
CompressionProviders = new List<ICompressionProvider>
{
[TestCompressionProvider.Name] = new TestCompressionProvider()
new TestCompressionProvider()
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@
using Grpc.AspNetCore.Server;
using Grpc.AspNetCore.Server.Internal;
using Grpc.AspNetCore.Server.Internal.CallHandlers;
using Grpc.AspNetCore.Server.Model;
using Grpc.Core;
using Grpc.Net.Compression;
using Grpc.Shared.Server;
using Grpc.Tests.Shared;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
Expand All @@ -52,7 +54,7 @@ public class UnaryServerCallHandlerBenchmarkBase
private TestPipeReader? _requestPipe;

protected InterceptorCollection? Interceptors { get; set; }
protected Dictionary<string, ICompressionProvider>? CompressionProviders { get; set; }
protected List<ICompressionProvider>? CompressionProviders { get; set; }
protected string? ResponseCompressionAlgorithm { get; set; }

[GlobalSetup]
Expand All @@ -77,15 +79,15 @@ public void GlobalSetup()
var method = new Method<ChatMessage, ChatMessage>(MethodType.Unary, typeof(TestService).FullName, nameof(TestService.SayHello), marshaller, marshaller);
var result = Task.FromResult(message);
_callHandler = new UnaryServerCallHandler<TestService, ChatMessage, ChatMessage>(
method,
(service, request, context) => result,
HttpContextServerCallContextHelper.CreateMethodContext(
compressionProviders: CompressionProviders,
responseCompressionAlgorithm: ResponseCompressionAlgorithm,
interceptors: Interceptors),
NullLoggerFactory.Instance,
new TestGrpcServiceActivator<TestService>(new TestService()),
serviceProvider);
new UnaryServerMethodInvoker<TestService, ChatMessage, ChatMessage>(
(service, request, context) => result,
method,
HttpContextServerCallContextHelper.CreateMethodOptions(
compressionProviders: CompressionProviders,
responseCompressionAlgorithm: ResponseCompressionAlgorithm,
interceptors: Interceptors),
new TestGrpcServiceActivator<TestService>(new TestService())),
NullLoggerFactory.Instance);

_trailers = new HeaderDictionary();

Expand Down
8 changes: 8 additions & 0 deletions src/Grpc.AspNetCore.Server/Grpc.AspNetCore.Server.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@

<ItemGroup>
<Compile Include="..\Shared\DefaultDeserializationContext.cs" Link="Internal\DefaultDeserializationContext.cs" />
<Compile Include="..\Shared\Server\BindMethodFinder.cs" Link="Model\Internal\BindMethodFinder.cs" />
<Compile Include="..\Shared\Server\ClientStreamingServerMethodInvoker.cs" Link="Model\Internal\ClientStreamingServerMethodInvoker.cs" />
<Compile Include="..\Shared\Server\DuplexStreamingServerMethodInvoker.cs" Link="Model\Internal\DuplexStreamingServerMethodInvoker.cs" />
<Compile Include="..\Shared\Server\InterceptorPipelineBuilder.cs" Link="Model\Internal\InterceptorPipelineBuilder.cs" />
<Compile Include="..\Shared\Server\MethodOptions.cs" Link="Model\Internal\MethodOptions.cs" />
<Compile Include="..\Shared\Server\ServerMethodInvokerBase.cs" Link="Model\Internal\ServerMethodInvokerBase.cs" />
<Compile Include="..\Shared\Server\ServerStreamingServerMethodInvoker.cs" Link="Model\Internal\ServerStreamingServerMethodInvoker.cs" />
<Compile Include="..\Shared\Server\UnaryServerMethodInvoker.cs" Link="Model\Internal\UnaryServerMethodInvoker.cs" />
</ItemGroup>

<ItemGroup>
Expand Down
4 changes: 0 additions & 4 deletions src/Grpc.AspNetCore.Server/InterceptorCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,6 @@ namespace Grpc.AspNetCore.Server
/// </summary>
public class InterceptorCollection : Collection<InterceptorRegistration>
{
internal InterceptorCollection()
{
}

/// <summary>
/// Add an interceptor to the end of the pipeline.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@

#endregion

using System;
using System.Threading.Tasks;
using Grpc.AspNetCore.Server.Model;
using Grpc.Core;
using Grpc.Shared.Server;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;

namespace Grpc.AspNetCore.Server.Internal.CallHandlers
Expand All @@ -31,79 +29,23 @@ internal class ClientStreamingServerCallHandler<TService, TRequest, TResponse> :
where TResponse : class
where TService : class
{
private readonly ClientStreamingServerMethod<TService, TRequest, TResponse> _invoker;
private readonly ClientStreamingServerMethod<TRequest, TResponse>? _pipelineInvoker;
private readonly ClientStreamingServerMethodInvoker<TService, TRequest, TResponse> _invoker;

public ClientStreamingServerCallHandler(
Method<TRequest, TResponse> method,
ClientStreamingServerMethod<TService, TRequest, TResponse> invoker,
MethodContext methodContext,
ILoggerFactory loggerFactory,
IGrpcServiceActivator<TService> serviceActivator,
IServiceProvider serviceProvider)
: base(method, methodContext, loggerFactory, serviceActivator, serviceProvider)
ClientStreamingServerMethodInvoker<TService, TRequest, TResponse> invoker,
ILoggerFactory loggerFactory)
: base(invoker, loggerFactory)
{
_invoker = invoker;

if (MethodContext.HasInterceptors)
{
var interceptorPipeline = new InterceptorPipelineBuilder<TRequest, TResponse>(MethodContext.Interceptors, ServiceProvider);
_pipelineInvoker = interceptorPipeline.ClientStreamingPipeline(ResolvedInterceptorInvoker);
}
}

private async Task<TResponse> ResolvedInterceptorInvoker(IAsyncStreamReader<TRequest> resolvedRequestStream, ServerCallContext resolvedContext)
{
GrpcActivatorHandle<TService> serviceHandle = default;
try
{
serviceHandle = ServiceActivator.Create(resolvedContext.GetHttpContext().RequestServices);
return await _invoker(
serviceHandle.Instance,
resolvedRequestStream,
resolvedContext);
}
finally
{
if (serviceHandle.Instance != null)
{
await ServiceActivator.ReleaseAsync(serviceHandle);
}
}
}

protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpContextServerCallContext serverCallContext)
{
// Disable request body data rate for client streaming
DisableMinRequestBodyDataRateAndMaxRequestBodySize(httpContext);

TResponse? response = null;

if (_pipelineInvoker == null)
{
GrpcActivatorHandle<TService> serviceHandle = default;
try
{
serviceHandle = ServiceActivator.Create(httpContext.RequestServices);
response = await _invoker(
serviceHandle.Instance,
new HttpContextStreamReader<TRequest>(serverCallContext, Method.RequestMarshaller.ContextualDeserializer),
serverCallContext);
}
finally
{
if (serviceHandle.Instance != null)
{
await ServiceActivator.ReleaseAsync(serviceHandle);
}
}
}
else
{
response = await _pipelineInvoker(
new HttpContextStreamReader<TRequest>(serverCallContext, Method.RequestMarshaller.ContextualDeserializer),
serverCallContext);
}
var streamReader = new HttpContextStreamReader<TRequest>(serverCallContext, MethodInvoker.Method.RequestMarshaller.ContextualDeserializer);
var response = await _invoker.Invoke(httpContext, serverCallContext, streamReader);

if (response == null)
{
Expand All @@ -112,7 +54,7 @@ protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpC
}

var responseBodyWriter = httpContext.Response.BodyWriter;
await responseBodyWriter.WriteMessageAsync(response, serverCallContext, Method.ResponseMarshaller.ContextualSerializer, canFlush: false);
await responseBodyWriter.WriteMessageAsync(response, serverCallContext, MethodInvoker.Method.ResponseMarshaller.ContextualSerializer, canFlush: false);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
using System.Threading.Tasks;
using Grpc.AspNetCore.Server.Model;
using Grpc.Core;
using Grpc.Shared.Server;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
Expand All @@ -31,80 +32,25 @@ internal class DuplexStreamingServerCallHandler<TService, TRequest, TResponse> :
where TResponse : class
where TService : class
{
private readonly DuplexStreamingServerMethod<TService, TRequest, TResponse> _invoker;
private readonly DuplexStreamingServerMethod<TRequest, TResponse>? _pipelineInvoker;
private readonly DuplexStreamingServerMethodInvoker<TService, TRequest, TResponse> _invoker;

public DuplexStreamingServerCallHandler(
Method<TRequest, TResponse> method,
DuplexStreamingServerMethod<TService, TRequest, TResponse> invoker,
MethodContext methodContext,
ILoggerFactory loggerFactory,
IGrpcServiceActivator<TService> serviceActivator,
IServiceProvider serviceProvider)
: base(method, methodContext, loggerFactory, serviceActivator, serviceProvider)
DuplexStreamingServerMethodInvoker<TService, TRequest, TResponse> invoker,
ILoggerFactory loggerFactory)
: base(invoker, loggerFactory)
{
_invoker = invoker;

if (MethodContext.HasInterceptors)
{
var interceptorPipeline = new InterceptorPipelineBuilder<TRequest, TResponse>(MethodContext.Interceptors, ServiceProvider);
_pipelineInvoker = interceptorPipeline.DuplexStreamingPipeline(ResolvedInterceptorInvoker);
}
}

private async Task ResolvedInterceptorInvoker(IAsyncStreamReader<TRequest> requestStream, IServerStreamWriter<TResponse> responseStream, ServerCallContext resolvedContext)
{
GrpcActivatorHandle<TService> serviceHandle = default;
try
{
serviceHandle = ServiceActivator.Create(resolvedContext.GetHttpContext().RequestServices);
await _invoker(
serviceHandle.Instance,
requestStream,
responseStream,
resolvedContext);
}
finally
{
if (serviceHandle.Instance != null)
{
await ServiceActivator.ReleaseAsync(serviceHandle);
}
}
}

protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpContextServerCallContext serverCallContext)
protected override Task HandleCallAsyncCore(HttpContext httpContext, HttpContextServerCallContext serverCallContext)
{
// Disable request body data rate for client streaming
DisableMinRequestBodyDataRateAndMaxRequestBodySize(httpContext);

if (_pipelineInvoker == null)
{
GrpcActivatorHandle<TService> serviceHandle = default;
try
{
serviceHandle = ServiceActivator.Create(httpContext.RequestServices);
await _invoker(
serviceHandle.Instance,
new HttpContextStreamReader<TRequest>(serverCallContext, Method.RequestMarshaller.ContextualDeserializer),
new HttpContextStreamWriter<TResponse>(serverCallContext, Method.ResponseMarshaller.ContextualSerializer),
serverCallContext);
}
finally
{
if (serviceHandle.Instance != null)
{
await ServiceActivator.ReleaseAsync(serviceHandle);
}
}
}
else
{
await _pipelineInvoker(
new HttpContextStreamReader<TRequest>(serverCallContext, Method.RequestMarshaller.ContextualDeserializer),
new HttpContextStreamWriter<TResponse>(serverCallContext, Method.ResponseMarshaller.ContextualSerializer),
serverCallContext);
}
var streamReader = new HttpContextStreamReader<TRequest>(serverCallContext, MethodInvoker.Method.RequestMarshaller.ContextualDeserializer);
var streamWriter = new HttpContextStreamWriter<TResponse>(serverCallContext, MethodInvoker.Method.ResponseMarshaller.ContextualSerializer);

return _invoker.Invoke(httpContext, serverCallContext, streamReader, streamWriter);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

using System;
using System.Threading.Tasks;
using Grpc.AspNetCore.Server.Model;
using Grpc.Core;
using Grpc.Shared.Server;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core.Features;
Expand All @@ -34,23 +36,14 @@ internal abstract class ServerCallHandlerBase<TService, TRequest, TResponse>
{
private const string LoggerName = "Grpc.AspNetCore.Server.ServerCallHandler";

protected Method<TRequest, TResponse> Method { get; }
protected MethodContext MethodContext { get; }
protected IGrpcServiceActivator<TService> ServiceActivator { get; }
protected IServiceProvider ServiceProvider { get; }
protected ServerMethodInvokerBase<TService, TRequest, TResponse> MethodInvoker { get; }
protected ILogger Logger { get; }

protected ServerCallHandlerBase(
Method<TRequest, TResponse> method,
MethodContext methodContext,
ILoggerFactory loggerFactory,
IGrpcServiceActivator<TService> serviceActivator,
IServiceProvider serviceProvider)
ServerMethodInvokerBase<TService, TRequest, TResponse> methodInvoker,
ILoggerFactory loggerFactory)
{
Method = method;
MethodContext = methodContext;
ServiceActivator = serviceActivator;
ServiceProvider = serviceProvider;
MethodInvoker = methodInvoker;
Logger = loggerFactory.CreateLogger(LoggerName);
}

Expand All @@ -74,7 +67,7 @@ public Task HandleCallAsync(HttpContext httpContext)
return Task.CompletedTask;
}

var serverCallContext = new HttpContextServerCallContext(httpContext, MethodContext, Logger);
var serverCallContext = new HttpContextServerCallContext(httpContext, MethodInvoker.Options, typeof(TRequest), typeof(TResponse), Logger);
httpContext.Features.Set<IServerCallContextFeature>(serverCallContext);

GrpcProtocolHelpers.AddProtocolHeaders(httpContext.Response);
Expand All @@ -91,12 +84,12 @@ public Task HandleCallAsync(HttpContext httpContext)
}
else
{
return AwaitHandleCall(serverCallContext, Method, handleCallTask);
return AwaitHandleCall(serverCallContext, MethodInvoker.Method, handleCallTask);
}
}
catch (Exception ex)
{
return serverCallContext.ProcessHandlerErrorAsync(ex, Method.Name);
return serverCallContext.ProcessHandlerErrorAsync(ex, MethodInvoker.Method.Name);
}

static async Task AwaitHandleCall(HttpContextServerCallContext serverCallContext, Method<TRequest, TResponse> method, Task handleCall)
Expand Down
Loading