Skip to content

Commit

Permalink
Change some RequestStream.WriteAsync errors to match Grpc.Core (#1199)
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesNK authored Feb 9, 2021
1 parent a3e6744 commit 4a8a120
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 36 deletions.
30 changes: 14 additions & 16 deletions src/Grpc.Net.Client/Internal/HttpContentClientStreamWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ internal class HttpContentClientStreamWriter<TRequest, TResponse> : IClientStrea
private readonly ILogger _logger;
private readonly object _writeLock;
private Task? _writeTask;
private bool _completeCalled;

public TaskCompletionSource<Stream> WriteStreamTcs { get; }
public TaskCompletionSource<bool> CompleteTcs { get; }
Expand Down Expand Up @@ -73,6 +74,7 @@ public Task CompleteAsync()

// Notify that the client stream is complete
CompleteTcs.TrySetResult(true);
_completeCalled = true;
}
}

Expand All @@ -92,33 +94,27 @@ public Task WriteAsync(TRequest message)
{
using (_call.StartScope())
{
// CompleteAsync has already been called
// Use explicit flag here. This error takes precedence over others.
if (_completeCalled)
{
return CreateErrorTask("Request stream has already been completed.");
}

// Call has already completed
if (_call.CallTask.IsCompletedSuccessfully)
{
var status = _call.CallTask.Result;
if (_call.CancellationToken.IsCancellationRequested &&
_call.Channel.ThrowOperationCanceledOnCancellation &&
(status.StatusCode == StatusCode.Cancelled || status.StatusCode == StatusCode.DeadlineExceeded))
{
if (!_call.Channel.ThrowOperationCanceledOnCancellation)
{
return Task.FromException(_call.CreateCanceledStatusException());
}
else
{
return Task.FromCanceled(_call.CancellationToken);
}
return Task.FromCanceled(_call.CancellationToken);
}

return CreateErrorTask("Can't write the message because the call is complete.");
return Task.FromException(_call.CreateCanceledStatusException());
}

// CompleteAsync has already been called
// Use IsCompleted here because that will track success and cancellation
if (CompleteTcs.Task.IsCompleted)
{
return CreateErrorTask("Can't write the message because the client stream writer is complete.");
}

// Pending writes need to be awaited first
if (IsWriteInProgressUnsynchronized)
{
Expand All @@ -144,6 +140,8 @@ public void Dispose()
{
}

public GrpcCall<TRequest, TResponse> Call => _call;

private async Task WriteAsyncCore(TRequest message)
{
try
Expand Down
8 changes: 4 additions & 4 deletions test/FunctionalTests/Client/StreamingTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -655,8 +655,8 @@ async Task<DataMessage> ClientStreamingWithTrailers(IAsyncStreamReader<DataMessa
var ex = await ExceptionAssert.ThrowsAsync<InvalidOperationException>(() => readTask!).DefaultTimeout();
Assert.AreEqual("Can't read messages after the request is complete.", ex.Message);

var clientException = await ExceptionAssert.ThrowsAsync<InvalidOperationException>(() => call.RequestStream.WriteAsync(new DataMessage())).DefaultTimeout();
Assert.AreEqual("Can't write the message because the call is complete.", clientException.Message);
var clientException = await ExceptionAssert.ThrowsAsync<RpcException>(() => call.RequestStream.WriteAsync(new DataMessage())).DefaultTimeout();
Assert.AreEqual(StatusCode.OK, clientException.StatusCode);
}

[TestCase(true)]
Expand Down Expand Up @@ -731,8 +731,8 @@ async Task<DataMessage> ClientStreamingWithTrailers(IAsyncStreamReader<DataMessa
// Ensure the server abort reaches the client
await Task.Delay(100);

var clientException = await ExceptionAssert.ThrowsAsync<InvalidOperationException>(() => call.RequestStream.WriteAsync(new DataMessage())).DefaultTimeout();
Assert.AreEqual("Can't write the message because the call is complete.", clientException.Message);
var clientException = await ExceptionAssert.ThrowsAsync<RpcException>(() => call.RequestStream.WriteAsync(new DataMessage())).DefaultTimeout();
Assert.AreEqual(StatusCode.Unavailable, clientException.StatusCode);
}

[Test]
Expand Down
54 changes: 38 additions & 16 deletions test/Grpc.Net.Client.Tests/AsyncClientStreamingCallTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -188,21 +188,33 @@ public async Task ClientStreamWriter_CompleteWhilePendingWrite_ErrorThrown()
public async Task ClientStreamWriter_WriteWhileComplete_ErrorThrown()
{
// Arrange
var streamContent = new SyncPointMemoryStream();
var httpClient = ClientTestHelpers.CreateTestClient(request =>
{
var streamContent = new StreamContent(new SyncPointMemoryStream());
return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent));
return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.OK, new StreamContent(streamContent)));
});
var invoker = HttpClientCallInvokerFactory.Create(httpClient);

// Act
var call = invoker.AsyncClientStreamingCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions());
await call.RequestStream.CompleteAsync().DefaultTimeout();
var resultTask = call.ResponseAsync;

// Assert
var ex = await ExceptionAssert.ThrowsAsync<InvalidOperationException>(() => call.RequestStream.WriteAsync(new HelloRequest { Name = "1" })).DefaultTimeout();
var writeException1 = await ExceptionAssert.ThrowsAsync<InvalidOperationException>(() => call.RequestStream.WriteAsync(new HelloRequest { Name = "1" })).DefaultTimeout();
Assert.AreEqual("Request stream has already been completed.", writeException1.Message);

await streamContent.AddDataAndWait(await ClientTestHelpers.GetResponseDataAsync(new HelloReply
{
Message = "Hello world 1"
}).DefaultTimeout()).DefaultTimeout();
await streamContent.AddDataAndWait(new byte[0]);

Assert.AreEqual("Can't write the message because the client stream writer is complete.", ex.Message);
var result = await resultTask.DefaultTimeout();
Assert.AreEqual("Hello world 1", result.Message);

var writeException2 = await ExceptionAssert.ThrowsAsync<InvalidOperationException>(() => call.RequestStream.WriteAsync(new HelloRequest { Name = "2" })).DefaultTimeout();
Assert.AreEqual("Request stream has already been completed.", writeException2.Message);
}

[Test]
Expand All @@ -218,32 +230,41 @@ public async Task ClientStreamWriter_WriteWithInvalidHttpStatus_ErrorThrown()

// Act
var call = invoker.AsyncClientStreamingCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions());
var writeException = await ExceptionAssert.ThrowsAsync<RpcException>(() => call.RequestStream.WriteAsync(new HelloRequest { Name = "1" })).DefaultTimeout();
var resultException = await ExceptionAssert.ThrowsAsync<RpcException>(() => call.ResponseAsync).DefaultTimeout();

// Assert
var ex = await ExceptionAssert.ThrowsAsync<InvalidOperationException>(() => call.RequestStream.WriteAsync(new HelloRequest { Name = "1" })).DefaultTimeout();
Assert.AreEqual("Bad gRPC response. HTTP status code: 404", writeException.Status.Detail);
Assert.AreEqual(StatusCode.Unimplemented, writeException.StatusCode);

Assert.AreEqual("Can't write the message because the call is complete.", ex.Message);
Assert.AreEqual("Bad gRPC response. HTTP status code: 404", resultException.Status.Detail);
Assert.AreEqual(StatusCode.Unimplemented, resultException.StatusCode);
}

[Test]
public async Task ClientStreamWriter_WriteAfterResponseHasFinished_ErrorThrown()
{
// Arrange
var httpClient = ClientTestHelpers.CreateTestClient(request =>
var httpClient = ClientTestHelpers.CreateTestClient(async request =>
{
return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.OK));
var reply = new HelloReply { Message = "Hello world" };
var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout();

return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent);
});
var invoker = HttpClientCallInvokerFactory.Create(httpClient);

// Act
var call = invoker.AsyncClientStreamingCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions());

var ex = await ExceptionAssert.ThrowsAsync<InvalidOperationException>(() => call.RequestStream.WriteAsync(new HelloRequest())).DefaultTimeout();
var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => call.RequestStream.WriteAsync(new HelloRequest())).DefaultTimeout();
var result = await call.ResponseAsync.DefaultTimeout();

// Assert
Assert.AreEqual("Can't write the message because the call is complete.", ex.Message);
Assert.AreEqual(StatusCode.Internal, call.GetStatus().StatusCode);
Assert.AreEqual("Failed to deserialize response message.", call.GetStatus().Detail);
Assert.AreEqual(StatusCode.OK, ex.StatusCode);
Assert.AreEqual(StatusCode.OK, call.GetStatus().StatusCode);
Assert.AreEqual(null, call.GetStatus().Detail);

Assert.AreEqual("Hello world", result.Message);
}

[Test]
Expand Down Expand Up @@ -297,11 +318,12 @@ public async Task ClientStreamWriter_CallThrowsException_WriteAsyncThrowsError()

// Act
var call = invoker.AsyncClientStreamingCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions());

var ex = await ExceptionAssert.ThrowsAsync<InvalidOperationException>(() => call.RequestStream.WriteAsync(new HelloRequest())).DefaultTimeout();
var writeException = await ExceptionAssert.ThrowsAsync<RpcException>(() => call.RequestStream.WriteAsync(new HelloRequest())).DefaultTimeout();
var resultException = await ExceptionAssert.ThrowsAsync<RpcException>(() => call.ResponseAsync).DefaultTimeout();

// Assert
Assert.AreEqual("Can't write the message because the call is complete.", ex.Message);
Assert.AreEqual("Error starting gRPC call. InvalidOperationException: Error!", writeException.Status.Detail);
Assert.AreEqual("Error starting gRPC call. InvalidOperationException: Error!", resultException.Status.Detail);
Assert.AreEqual(StatusCode.Internal, call.GetStatus().StatusCode);
}
}
Expand Down

0 comments on commit 4a8a120

Please sign in to comment.