Skip to content

Commit

Permalink
Fix memory leak when using call context propagation with cancellation…
Browse files Browse the repository at this point in the history
… token (#2421)
  • Loading branch information
JamesNK authored Apr 29, 2024
1 parent c9c902c commit 2d9df58
Show file tree
Hide file tree
Showing 2 changed files with 204 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
using Grpc.Core.Interceptors;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

namespace Grpc.AspNetCore.ClientFactory;

Expand Down Expand Up @@ -53,14 +52,15 @@ public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreami
}
else
{
var state = CreateContextState(call, cts);
return new AsyncClientStreamingCall<TRequest, TResponse>(
requestStream: call.RequestStream,
responseAsync: call.ResponseAsync,
responseAsync: OnResponseAsync(call.ResponseAsync, state),
responseHeadersAsync: ClientStreamingCallbacks<TRequest, TResponse>.GetResponseHeadersAsync,
getStatusFunc: ClientStreamingCallbacks<TRequest, TResponse>.GetStatus,
getTrailersFunc: ClientStreamingCallbacks<TRequest, TResponse>.GetTrailers,
disposeAction: ClientStreamingCallbacks<TRequest, TResponse>.Dispose,
CreateContextState(call, cts));
state);
}
}

Expand All @@ -73,14 +73,15 @@ public override AsyncDuplexStreamingCall<TRequest, TResponse> AsyncDuplexStreami
}
else
{
var state = CreateContextState(call, cts);
return new AsyncDuplexStreamingCall<TRequest, TResponse>(
requestStream: call.RequestStream,
responseStream: call.ResponseStream,
responseStream: new ResponseStreamWrapper<TResponse>(call.ResponseStream, state),
responseHeadersAsync: DuplexStreamingCallbacks<TRequest, TResponse>.GetResponseHeadersAsync,
getStatusFunc: DuplexStreamingCallbacks<TRequest, TResponse>.GetStatus,
getTrailersFunc: DuplexStreamingCallbacks<TRequest, TResponse>.GetTrailers,
disposeAction: DuplexStreamingCallbacks<TRequest, TResponse>.Dispose,
CreateContextState(call, cts));
state);
}
}

Expand All @@ -93,13 +94,14 @@ public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRe
}
else
{
var state = CreateContextState(call, cts);
return new AsyncServerStreamingCall<TResponse>(
responseStream: call.ResponseStream,
responseStream: new ResponseStreamWrapper<TResponse>(call.ResponseStream, state),
responseHeadersAsync: ServerStreamingCallbacks<TResponse>.GetResponseHeadersAsync,
getStatusFunc: ServerStreamingCallbacks<TResponse>.GetStatus,
getTrailersFunc: ServerStreamingCallbacks<TResponse>.GetTrailers,
disposeAction: ServerStreamingCallbacks<TResponse>.Dispose,
CreateContextState(call, cts));
state);
}
}

Expand All @@ -112,13 +114,14 @@ public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>(TR
}
else
{
var state = CreateContextState(call, cts);
return new AsyncUnaryCall<TResponse>(
responseAsync: call.ResponseAsync,
responseAsync: OnResponseAsync(call.ResponseAsync, state),
responseHeadersAsync: UnaryCallbacks<TResponse>.GetResponseHeadersAsync,
getStatusFunc: UnaryCallbacks<TResponse>.GetStatus,
getTrailersFunc: UnaryCallbacks<TResponse>.GetTrailers,
disposeAction: UnaryCallbacks<TResponse>.Dispose,
CreateContextState(call, cts));
state);
}
}

Expand All @@ -129,6 +132,19 @@ public override TResponse BlockingUnaryCall<TRequest, TResponse>(TRequest reques
return response;
}

// Automatically dispose state after awaiting the response.
private static async Task<TResponse> OnResponseAsync<TResponse>(Task<TResponse> task, IDisposable state)
{
try
{
return await task.ConfigureAwait(false);
}
finally
{
state.Dispose();
}
}

private ClientInterceptorContext<TRequest, TResponse> ConfigureContext<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, out CancellationTokenSource? linkedCts)
where TRequest : class
where TResponse : class
Expand Down Expand Up @@ -197,7 +213,7 @@ private bool TryGetServerCallContext([NotNullWhen(true)] out ServerCallContext?
private ContextState<TCall> CreateContextState<TCall>(TCall call, CancellationTokenSource cancellationTokenSource) where TCall : IDisposable =>
new ContextState<TCall>(call, cancellationTokenSource);

private class ContextState<TCall> : IDisposable where TCall : IDisposable
private sealed class ContextState<TCall> : IDisposable where TCall : IDisposable
{
public ContextState(TCall call, CancellationTokenSource cancellationTokenSource)
{
Expand All @@ -215,6 +231,33 @@ public void Dispose()
}
}

// Automatically dispose state after reading to the end of the stream.
private sealed class ResponseStreamWrapper<TResponse> : IAsyncStreamReader<TResponse>
{
private readonly IAsyncStreamReader<TResponse> _inner;
private readonly IDisposable _state;
private bool _disposed;

public ResponseStreamWrapper(IAsyncStreamReader<TResponse> inner, IDisposable state)
{
_inner = inner;
_state = state;
}

public TResponse Current => _inner.Current;

public async Task<bool> MoveNext(CancellationToken cancellationToken)
{
var result = await _inner.MoveNext(cancellationToken);
if (!result && !_disposed)
{
_state.Dispose();
_disposed = true;
}
return result;
}
}

private static class Log
{
private static readonly Action<ILogger, string, Exception?> _propagateServerCallContextFailure =
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand All @@ -20,6 +20,7 @@
using Greet;
using Grpc.AspNetCore.Server.ClientFactory.Tests.TestObjects;
using Grpc.Core;
using Grpc.Core.Interceptors;
using Grpc.Net.ClientFactory;
using Grpc.Net.ClientFactory.Internal;
using Grpc.Tests.Shared;
Expand Down Expand Up @@ -91,6 +92,155 @@ public async Task CreateClient_ServerCallContextHasValues_PropogatedDeadlineAndC
Assert.AreEqual(cancellationToken, options.CancellationToken);
}

[Test]
public async Task CreateClient_Unary_ServerCallContextHasValues_StateDisposed()
{
// Arrange
var baseAddress = new Uri("http://localhost");
var deadline = DateTime.UtcNow.AddDays(1);
var cancellationToken = new CancellationTokenSource().Token;

var interceptor = new OnDisposedInterceptor();

var services = new ServiceCollection();
services.AddOptions();
services.AddSingleton(CreateHttpContextAccessorWithServerCallContext(deadline: deadline, cancellationToken: cancellationToken));
services
.AddGrpcClient<Greeter.GreeterClient>(o =>
{
o.Address = baseAddress;
})
.EnableCallContextPropagation()
.AddInterceptor(() => interceptor)
.ConfigurePrimaryHttpMessageHandler(() => ClientTestHelpers.CreateTestMessageHandler(new HelloReply()));

var serviceProvider = services.BuildServiceProvider(validateScopes: true);

var clientFactory = CreateGrpcClientFactory(serviceProvider);
var client = clientFactory.CreateClient<Greeter.GreeterClient>(nameof(Greeter.GreeterClient));

// Checking that token register calls don't build up on CTS and create a memory leak.
var cts = new CancellationTokenSource();

// Act
// Send calls in a different method so there is no chance that a stack reference
// to a gRPC call is still alive after calls are complete.
var response = await client.SayHelloAsync(new HelloRequest(), cancellationToken: cts.Token);

// Assert
Assert.IsTrue(interceptor.ContextDisposed);
}

[Test]
public async Task CreateClient_ServerStreaming_ServerCallContextHasValues_StateDisposed()
{
// Arrange
var baseAddress = new Uri("http://localhost");
var deadline = DateTime.UtcNow.AddDays(1);
var cancellationToken = new CancellationTokenSource().Token;

var interceptor = new OnDisposedInterceptor();

var services = new ServiceCollection();
services.AddOptions();
services.AddSingleton(CreateHttpContextAccessorWithServerCallContext(deadline: deadline, cancellationToken: cancellationToken));
services
.AddGrpcClient<Greeter.GreeterClient>(o =>
{
o.Address = baseAddress;
})
.EnableCallContextPropagation()
.AddInterceptor(() => interceptor)
.ConfigurePrimaryHttpMessageHandler(() => ClientTestHelpers.CreateTestMessageHandler(new HelloReply()));

var serviceProvider = services.BuildServiceProvider(validateScopes: true);

var clientFactory = CreateGrpcClientFactory(serviceProvider);
var client = clientFactory.CreateClient<Greeter.GreeterClient>(nameof(Greeter.GreeterClient));

// Checking that token register calls don't build up on CTS and create a memory leak.
var cts = new CancellationTokenSource();

// Act
// Send calls in a different method so there is no chance that a stack reference
// to a gRPC call is still alive after calls are complete.
var call = client.SayHellos(new HelloRequest(), cancellationToken: cts.Token);

Assert.IsTrue(await call.ResponseStream.MoveNext());
Assert.IsFalse(await call.ResponseStream.MoveNext());

// Assert
Assert.IsTrue(interceptor.ContextDisposed);
}

private sealed class OnDisposedInterceptor : Interceptor
{
public bool ContextDisposed { get; private set; }

public override TResponse BlockingUnaryCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, BlockingUnaryCallContinuation<TRequest, TResponse> continuation)
{
return continuation(request, context);
}

public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, AsyncUnaryCallContinuation<TRequest, TResponse> continuation)
{
var call = continuation(request, context);
return new AsyncUnaryCall<TResponse>(call.ResponseAsync,
call.ResponseHeadersAsync,
call.GetStatus,
call.GetTrailers,
() =>
{
call.Dispose();
ContextDisposed = true;
});
}

public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, AsyncServerStreamingCallContinuation<TRequest, TResponse> continuation)
{
var call = continuation(request, context);
return new AsyncServerStreamingCall<TResponse>(call.ResponseStream,
call.ResponseHeadersAsync,
call.GetStatus,
call.GetTrailers,
() =>
{
call.Dispose();
ContextDisposed = true;
});
}

public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncClientStreamingCallContinuation<TRequest, TResponse> continuation)
{
var call = continuation(context);
return new AsyncClientStreamingCall<TRequest, TResponse>(call.RequestStream,
call.ResponseAsync,
call.ResponseHeadersAsync,
call.GetStatus,
call.GetTrailers,
() =>
{
call.Dispose();
ContextDisposed = true;
});
}

public override AsyncDuplexStreamingCall<TRequest, TResponse> AsyncDuplexStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncDuplexStreamingCallContinuation<TRequest, TResponse> continuation)
{
var call = continuation(context);
return new AsyncDuplexStreamingCall<TRequest, TResponse>(call.RequestStream,
call.ResponseStream,
call.ResponseHeadersAsync,
call.GetStatus,
call.GetTrailers,
() =>
{
call.Dispose();
ContextDisposed = true;
});
}
}

[TestCase(Canceller.Context)]
[TestCase(Canceller.User)]
public async Task CreateClient_ServerCallContextAndUserCancellationToken_PropogatedDeadlineAndCancellation(Canceller canceller)
Expand Down

0 comments on commit 2d9df58

Please sign in to comment.