Skip to content

Commit

Permalink
Fix response stream errors not updating call status (#1198)
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesNK authored Feb 9, 2021
1 parent 4a8a120 commit 9c3385a
Show file tree
Hide file tree
Showing 6 changed files with 280 additions and 25 deletions.
14 changes: 9 additions & 5 deletions src/Grpc.Net.Client/Internal/GrpcCall.cs
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,16 @@ private void FinishResponseAndCleanUp(Status status)
/// Used by response stream reader to report it is finished.
/// </summary>
/// <param name="status">The completed response status code.</param>
public void ResponseStreamEnded(Status status)
/// <param name="finishedGracefully">true when the end of the response stream was read, otherwise false.</param>
public void ResponseStreamEnded(Status status, bool finishedGracefully)
{
// Set response finished immediately rather than set it in logic resumed
// from the callTcs to avoid race condition.
// e.g. response stream finished and then immediately call GetTrailers().
ResponseFinished = true;
if (finishedGracefully)
{
// Set response finished immediately rather than set it in logic resumed
// from the callTcs to avoid race condition.
// e.g. response stream finished and then immediately call GetTrailers().
ResponseFinished = true;
}

_callTcs.TrySetResult(status);
}
Expand Down
36 changes: 30 additions & 6 deletions src/Grpc.Net.Client/Internal/HttpContentClientStreamReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ private async Task<bool> MoveNextCore(CancellationToken cancellationToken)
// No more content in response so report status to call.
// The call will handle finishing the response.
var status = GrpcProtocolHelpers.GetResponseStatus(_httpResponse, _call.Channel.OperatingSystem.IsBrowser);
_call.ResponseStreamEnded(status);
_call.ResponseStreamEnded(status, finishedGracefully: true);
if (status.StatusCode != StatusCode.OK)
{
throw _call.CreateFailureStatusException(status);
Expand All @@ -180,7 +180,7 @@ private async Task<bool> MoveNextCore(CancellationToken cancellationToken)
GrpcEventSource.Log.MessageReceived();
return true;
}
catch (OperationCanceledException) when (!_call.Channel.ThrowOperationCanceledOnCancellation)
catch (OperationCanceledException ex)
{
if (_call.ResponseFinished)
{
Expand All @@ -192,13 +192,37 @@ private async Task<bool> MoveNextCore(CancellationToken cancellationToken)
return false;
}
}
else
{
_call.ResponseStreamEnded(new Status(StatusCode.Cancelled, ex.Message, ex), finishedGracefully: false);
}

throw _call.CreateCanceledStatusException();
if (!_call.Channel.ThrowOperationCanceledOnCancellation)
{
throw _call.CreateCanceledStatusException();
}
else
{
throw;
}
}
catch (Exception ex) when (_call.ResolveException("Error reading next message.", ex, out _, out var resolvedException))
catch (Exception ex)
{
// Throw RpcException from MoveNext. Consistent with Grpc.Core.
throw resolvedException;
var newException = _call.ResolveException("Error reading next message.", ex, out var status, out var resolvedException);
if (!_call.ResponseFinished)
{
_call.ResponseStreamEnded(status.Value, finishedGracefully: false);
}

if (newException)
{
// Throw RpcException from MoveNext. Consistent with Grpc.Core.
throw resolvedException;
}
else
{
throw;
}
}
finally
{
Expand Down
131 changes: 131 additions & 0 deletions test/Grpc.Net.Client.Tests/AsyncServerStreamingCallTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,137 @@ await streamContent.AddDataAndWait(await ClientTestHelpers.GetResponseDataAsync(
Assert.IsFalse(await moveNextTask3.DefaultTimeout());
}

[Test]
public async Task AsyncServerStreamingCall_MessagesStreamedThenError_ErrorStatus()
{
// Arrange
var streamContent = new SyncPointMemoryStream();

var httpClient = ClientTestHelpers.CreateTestClient(request =>
{
return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.OK, new StreamContent(streamContent)));
});
var invoker = HttpClientCallInvokerFactory.Create(httpClient);

// Act
var call = invoker.AsyncServerStreamingCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest());

var responseStream = call.ResponseStream;

// Assert
Assert.IsNull(responseStream.Current);

var moveNextTask1 = responseStream.MoveNext(CancellationToken.None);
Assert.IsFalse(moveNextTask1.IsCompleted);

await streamContent.AddDataAndWait(await ClientTestHelpers.GetResponseDataAsync(new HelloReply
{
Message = "Hello world 1"
}).DefaultTimeout()).DefaultTimeout();

Assert.IsTrue(await moveNextTask1.DefaultTimeout());
Assert.IsNotNull(responseStream.Current);
Assert.AreEqual("Hello world 1", responseStream.Current.Message);

var moveNextTask2 = responseStream.MoveNext(CancellationToken.None);
Assert.IsFalse(moveNextTask2.IsCompleted);

await streamContent.AddExceptionAndWait(new Exception("Exception!")).DefaultTimeout();

var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => moveNextTask2).DefaultTimeout();
Assert.AreEqual(StatusCode.Internal, ex.StatusCode);
Assert.AreEqual(StatusCode.Internal, call.GetStatus().StatusCode);
Assert.AreEqual("Error reading next message. Exception: Exception!", call.GetStatus().Detail);
}

[Test]
public async Task AsyncServerStreamingCall_MessagesStreamedThenCancellation_ErrorStatus()
{
// Arrange
var streamContent = new SyncPointMemoryStream();

var httpClient = ClientTestHelpers.CreateTestClient(request =>
{
return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.OK, new StreamContent(streamContent)));
});
var invoker = HttpClientCallInvokerFactory.Create(httpClient);

// Act
var call = invoker.AsyncServerStreamingCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest());

var responseStream = call.ResponseStream;

// Assert
Assert.IsNull(responseStream.Current);

var moveNextTask1 = responseStream.MoveNext(CancellationToken.None);
Assert.IsFalse(moveNextTask1.IsCompleted);

await streamContent.AddDataAndWait(await ClientTestHelpers.GetResponseDataAsync(new HelloReply
{
Message = "Hello world 1"
}).DefaultTimeout()).DefaultTimeout();

Assert.IsTrue(await moveNextTask1.DefaultTimeout());
Assert.IsNotNull(responseStream.Current);
Assert.AreEqual("Hello world 1", responseStream.Current.Message);

var cts = new CancellationTokenSource();

var moveNextTask2 = responseStream.MoveNext(cts.Token);
Assert.IsFalse(moveNextTask2.IsCompleted);

cts.Cancel();

var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => moveNextTask2).DefaultTimeout();
Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode);
Assert.AreEqual(StatusCode.Cancelled, call.GetStatus().StatusCode);
Assert.AreEqual("Call canceled by the client.", call.GetStatus().Detail);
}

[Test]
public async Task AsyncServerStreamingCall_MessagesStreamedThenDispose_ErrorStatus()
{
// Arrange
var streamContent = new SyncPointMemoryStream();

var httpClient = ClientTestHelpers.CreateTestClient(request =>
{
return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.OK, new StreamContent(streamContent)));
});
var invoker = HttpClientCallInvokerFactory.Create(httpClient);

// Act
var call = invoker.AsyncServerStreamingCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest());

var responseStream = call.ResponseStream;

// Assert
Assert.IsNull(responseStream.Current);

var moveNextTask1 = responseStream.MoveNext(CancellationToken.None);
Assert.IsFalse(moveNextTask1.IsCompleted);

await streamContent.AddDataAndWait(await ClientTestHelpers.GetResponseDataAsync(new HelloReply
{
Message = "Hello world 1"
}).DefaultTimeout()).DefaultTimeout();

Assert.IsTrue(await moveNextTask1.DefaultTimeout());
Assert.IsNotNull(responseStream.Current);
Assert.AreEqual("Hello world 1", responseStream.Current.Message);

var moveNextTask2 = responseStream.MoveNext(CancellationToken.None);
Assert.IsFalse(moveNextTask2.IsCompleted);

call.Dispose();

var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => moveNextTask2).DefaultTimeout();
Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode);
Assert.AreEqual(StatusCode.Cancelled, call.GetStatus().StatusCode);
Assert.AreEqual("gRPC call disposed.", call.GetStatus().Detail);
}

[Test]
public async Task ClientStreamReader_WriteWithInvalidHttpStatus_ErrorThrown()
{
Expand Down
25 changes: 21 additions & 4 deletions test/Grpc.Net.Client.Tests/AsyncUnaryCallTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.Http;
Expand Down Expand Up @@ -125,7 +126,7 @@ public async Task AsyncUnaryCall_Success_RequestContentSent()
}

[Test]
public async Task AsyncUnaryCall_NonOkStatusTrailer_ThrowRpcError()
public async Task AsyncUnaryCall_NonOkStatusTrailer_AccessResponse_ReturnHeaders()
{
// Arrange
var httpClient = ClientTestHelpers.CreateTestClient(request =>
Expand All @@ -142,16 +143,32 @@ public async Task AsyncUnaryCall_NonOkStatusTrailer_ThrowRpcError()
Assert.AreEqual(StatusCode.Unimplemented, ex.StatusCode);
}

[Test]
public async Task AsyncUnaryCall_NonOkStatusTrailer_AccessHeaders_ThrowRpcError()
{
// Arrange
var httpClient = ClientTestHelpers.CreateTestClient(request =>
{
var response = ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unimplemented, customHeaders: new Dictionary<string, string> { ["custom"] = "true" });
return Task.FromResult(response);
});
var invoker = HttpClientCallInvokerFactory.Create(httpClient);

// Act
var headers = await invoker.AsyncUnaryCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest()).ResponseHeadersAsync.DefaultTimeout();

// Assert
Assert.AreEqual("true", headers.GetValue("custom"));
}

[Test]
public async Task AsyncUnaryCall_SuccessTrailersOnly_ThrowNoMessageError()
{
// Arrange
HttpResponseMessage? responseMessage = null;
var httpClient = ClientTestHelpers.CreateTestClient(request =>
{
responseMessage = ResponseUtils.CreateResponse(HttpStatusCode.OK, new ByteArrayContent(Array.Empty<byte>()), grpcStatusCode: null);
responseMessage.Headers.Add(GrpcProtocolConstants.StatusTrailer, StatusCode.OK.ToString("D"));
responseMessage.Headers.Add(GrpcProtocolConstants.MessageTrailer, "Detail!");
responseMessage = ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.OK, customHeaders: new Dictionary<string, string> { [GrpcProtocolConstants.MessageTrailer] = "Detail!" });
return Task.FromResult(responseMessage);
});
var invoker = HttpClientCallInvokerFactory.Create(httpClient);
Expand Down
61 changes: 59 additions & 2 deletions test/Shared/ResponseUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

using System;
using System.Buffers.Binary;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Net;
Expand Down Expand Up @@ -48,7 +50,10 @@ public static HttpResponseMessage CreateResponse(
HttpContent payload,
StatusCode? grpcStatusCode = StatusCode.OK,
string? grpcEncoding = null,
Version? version = null)
Version? version = null,
string? retryPushbackHeader = null,
IDictionary<string, string>? customHeaders = null,
IDictionary<string, string>? customTrailers = null)
{
payload.Headers.ContentType = GrpcContentTypeHeaderValue;

Expand All @@ -59,12 +64,64 @@ public static HttpResponseMessage CreateResponse(
};

message.Headers.Add(MessageEncodingHeader, grpcEncoding ?? IdentityGrpcEncoding);
if (retryPushbackHeader != null)
{
message.Headers.Add("grpc-retry-pushback-ms", retryPushbackHeader);
}

if (customHeaders != null)
{
foreach (var customHeader in customHeaders)
{
message.Headers.Add(customHeader.Key, customHeader.Value);
}
}

if (grpcStatusCode != null)
{
message.TrailingHeaders.Add(StatusTrailer, grpcStatusCode.Value.ToString("D"));
}

if (customTrailers != null)
{
foreach (var customTrailer in customTrailers)
{
message.TrailingHeaders.Add(customTrailer.Key, customTrailer.Value);
}
}

return message;
}

public static HttpResponseMessage CreateHeadersOnlyResponse(
HttpStatusCode statusCode,
StatusCode grpcStatusCode,
string? grpcEncoding = null,
Version? version = null,
string? retryPushbackHeader = null,
IDictionary<string, string>? customHeaders = null)
{
var message = new HttpResponseMessage(statusCode)
{
Version = version ?? ProtocolVersion
};

message.Headers.Add(MessageEncodingHeader, grpcEncoding ?? IdentityGrpcEncoding);
if (retryPushbackHeader != null)
{
message.Headers.Add("grpc-retry-pushback-ms", retryPushbackHeader);
}

if (customHeaders != null)
{
foreach (var customHeader in customHeaders)
{
message.Headers.Add(customHeader.Key, customHeader.Value);
}
}

message.Headers.Add(StatusTrailer, grpcStatusCode.ToString("D"));

return message;
}

Expand All @@ -91,4 +148,4 @@ private static void EncodeMessageLength(int messageLength, Span<byte> destinatio
BinaryPrimitives.WriteUInt32BigEndian(destination, (uint)messageLength);
}
}
}
}
Loading

0 comments on commit 9c3385a

Please sign in to comment.