diff --git a/core/src/main/java/io/grpc/internal/ServerCallImpl.java b/core/src/main/java/io/grpc/internal/ServerCallImpl.java index 77e74795046..2fad5f46e57 100644 --- a/core/src/main/java/io/grpc/internal/ServerCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ServerCallImpl.java @@ -34,7 +34,6 @@ import io.grpc.InternalDecompressorRegistry; import io.grpc.Metadata; import io.grpc.MethodDescriptor; -import io.grpc.MethodDescriptor.MethodType; import io.grpc.ServerCall; import io.grpc.Status; import java.io.IOException; @@ -42,6 +41,12 @@ import java.util.List; final class ServerCallImpl extends ServerCall { + + @VisibleForTesting + static String TOO_MANY_RESPONSES = "Too many responses"; + @VisibleForTesting + static String MISSING_RESPONSE = "Completed without a response"; + private final ServerStream stream; private final MethodDescriptor method; private final Context.CancellableContext context; @@ -54,6 +59,7 @@ final class ServerCallImpl extends ServerCall { private boolean sendHeadersCalled; private boolean closeCalled; private Compressor compressor; + private boolean messageSent; ServerCallImpl(ServerStream stream, MethodDescriptor method, Metadata inboundHeaders, Context.CancellableContext context, @@ -115,6 +121,13 @@ public void sendHeaders(Metadata headers) { public void sendMessage(RespT message) { checkState(sendHeadersCalled, "sendHeaders has not been called"); checkState(!closeCalled, "call is closed"); + + if (method.getType().serverSendsOneMessage() && messageSent) { + internalClose(Status.INTERNAL.withDescription(TOO_MANY_RESPONSES)); + return; + } + + messageSent = true; try { InputStream resp = method.streamResponse(message); stream.writeMessage(resp); @@ -151,6 +164,12 @@ public boolean isReady() { public void close(Status status, Metadata trailers) { checkState(!closeCalled, "call already closed"); closeCalled = true; + + if (status.isOk() && method.getType().serverSendsOneMessage() && !messageSent) { + internalClose(Status.INTERNAL.withDescription(MISSING_RESPONSE)); + return; + } + stream.close(status, trailers); } @@ -178,6 +197,15 @@ public MethodDescriptor getMethodDescriptor() { return method; } + /** + * Close the {@link ServerStream} because an internal error occurred. Allow the application to + * run until completion, but silently ignore interactions with the {@link ServerStream} from now + * on. + */ + private void internalClose(Status internalError) { + stream.close(internalError, new Metadata()); + } + /** * All of these callbacks are assumed to called on an application thread, and the caller is * responsible for handling thrown exceptions. @@ -187,7 +215,6 @@ static final class ServerStreamListenerImpl implements ServerStreamListene private final ServerCallImpl call; private final ServerCall.Listener listener; private final Context.CancellableContext context; - private boolean messageReceived; public ServerStreamListenerImpl( ServerCallImpl call, ServerCall.Listener listener, @@ -216,15 +243,6 @@ public void messageRead(final InputStream message) { if (call.cancelled) { return; } - // Special case for unary calls. - if (messageReceived && call.method.getType() == MethodType.UNARY) { - call.stream.close(Status.INTERNAL.withDescription( - "More than one request messages for unary call or server streaming call"), - new Metadata()); - return; - } - messageReceived = true; - listener.onMessage(call.method.parseRequest(message)); } catch (Throwable e) { t = e; diff --git a/core/src/main/java/io/grpc/internal/ServerStream.java b/core/src/main/java/io/grpc/internal/ServerStream.java index ee3956d500b..5a794976074 100644 --- a/core/src/main/java/io/grpc/internal/ServerStream.java +++ b/core/src/main/java/io/grpc/internal/ServerStream.java @@ -42,6 +42,10 @@ public interface ServerStream extends Stream { * {@link io.grpc.Status.Code#OK} implies normal termination of the * stream. Any other value implies abnormal termination. * + *

Attempts to read from or write to the stream after closing + * should be ignored by implementations, and should not throw + * exceptions. + * * @param status details of the closure * @param trailers an additional block of metadata to pass to the client on stream closure. */ diff --git a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java index bedbbb35631..c298730f6e2 100644 --- a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java @@ -23,7 +23,10 @@ import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Matchers.isA; +import static org.mockito.Matchers.same; import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -48,9 +51,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; -import org.mockito.Captor; import org.mockito.Mock; -import org.mockito.Mockito; import org.mockito.MockitoAnnotations; @RunWith(JUnit4.class) @@ -58,17 +59,25 @@ public class ServerCallImplTest { @Rule public final ExpectedException thrown = ExpectedException.none(); @Mock private ServerStream stream; @Mock private ServerCall.Listener callListener; - @Captor private ArgumentCaptor statusCaptor; private ServerCallImpl call; private Context.CancellableContext context; - private final MethodDescriptor method = MethodDescriptor.newBuilder() - .setType(MethodType.UNARY) - .setFullMethodName("/service/method") - .setRequestMarshaller(new LongMarshaller()) - .setResponseMarshaller(new LongMarshaller()) - .build(); + private static final MethodDescriptor UNARY_METHOD = + MethodDescriptor.newBuilder() + .setType(MethodType.UNARY) + .setFullMethodName("/service/method") + .setRequestMarshaller(new LongMarshaller()) + .setResponseMarshaller(new LongMarshaller()) + .build(); + + private static final MethodDescriptor CLIENT_STREAMING_METHOD = + MethodDescriptor.newBuilder() + .setType(MethodType.UNARY) + .setFullMethodName("/service/method") + .setRequestMarshaller(new LongMarshaller()) + .setResponseMarshaller(new LongMarshaller()) + .build(); private final Metadata requestHeaders = new Metadata(); @@ -76,7 +85,7 @@ public class ServerCallImplTest { public void setUp() { MockitoAnnotations.initMocks(this); context = Context.ROOT.withCancellation(); - call = new ServerCallImpl(stream, method, requestHeaders, context, + call = new ServerCallImpl(stream, UNARY_METHOD, requestHeaders, context, DecompressorRegistry.getDefaultInstance(), CompressorRegistry.getDefaultInstance()); } @@ -158,6 +167,114 @@ public void sendMessage_closesOnFailure() { verify(stream).close(isA(Status.class), isA(Metadata.class)); } + @Test + public void sendMessage_serverSendsOne_closeOnSecondCall_unary() { + sendMessage_serverSendsOne_closeOnSecondCall(UNARY_METHOD); + } + + @Test + public void sendMessage_serverSendsOne_closeOnSecondCall_clientStreaming() { + sendMessage_serverSendsOne_closeOnSecondCall(CLIENT_STREAMING_METHOD); + } + + private void sendMessage_serverSendsOne_closeOnSecondCall( + MethodDescriptor method) { + ServerCallImpl serverCall = new ServerCallImpl( + stream, + method, + requestHeaders, + context, + DecompressorRegistry.getDefaultInstance(), + CompressorRegistry.getDefaultInstance()); + serverCall.sendHeaders(new Metadata()); + serverCall.sendMessage(1L); + verify(stream, times(1)).writeMessage(any(InputStream.class)); + verify(stream, never()).close(any(Status.class), any(Metadata.class)); + + // trying to send a second message causes gRPC to close the underlying stream + serverCall.sendMessage(1L); + verify(stream, times(1)).writeMessage(any(InputStream.class)); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + ArgumentCaptor metadataCaptor = ArgumentCaptor.forClass(Metadata.class); + verify(stream, times(1)).close(statusCaptor.capture(), metadataCaptor.capture()); + assertEquals(Status.Code.INTERNAL, statusCaptor.getValue().getCode()); + assertEquals(ServerCallImpl.TOO_MANY_RESPONSES, statusCaptor.getValue().getDescription()); + assertTrue(metadataCaptor.getValue().keys().isEmpty()); + } + + @Test + public void sendMessage_serverSendsOne_closeOnSecondCall_appRunToCompletion_unary() { + sendMessage_serverSendsOne_closeOnSecondCall_appRunToCompletion(UNARY_METHOD); + } + + @Test + public void sendMessage_serverSendsOne_closeOnSecondCall_appRunToCompletion_clientStreaming() { + sendMessage_serverSendsOne_closeOnSecondCall_appRunToCompletion(CLIENT_STREAMING_METHOD); + } + + private void sendMessage_serverSendsOne_closeOnSecondCall_appRunToCompletion( + MethodDescriptor method) { + ServerCallImpl serverCall = new ServerCallImpl( + stream, + method, + requestHeaders, + context, + DecompressorRegistry.getDefaultInstance(), + CompressorRegistry.getDefaultInstance()); + serverCall.sendHeaders(new Metadata()); + serverCall.sendMessage(1L); + serverCall.sendMessage(1L); + verify(stream, times(1)).writeMessage(any(InputStream.class)); + verify(stream, times(1)).close(any(Status.class), any(Metadata.class)); + + // App runs to completion but everything is ignored + serverCall.sendMessage(1L); + serverCall.close(Status.OK, new Metadata()); + try { + serverCall.close(Status.OK, new Metadata()); + fail("calling a second time should still cause an error"); + } catch (IllegalStateException expected) { + // noop + } + } + + @Test + public void serverSendsOne_okFailsOnMissingResponse_unary() { + serverSendsOne_okFailsOnMissingResponse(UNARY_METHOD); + } + + @Test + public void serverSendsOne_okFailsOnMissingResponse_clientStreaming() { + serverSendsOne_okFailsOnMissingResponse(CLIENT_STREAMING_METHOD); + } + + private void serverSendsOne_okFailsOnMissingResponse( + MethodDescriptor method) { + ServerCallImpl serverCall = new ServerCallImpl( + stream, + method, + requestHeaders, + context, + DecompressorRegistry.getDefaultInstance(), + CompressorRegistry.getDefaultInstance()); + serverCall.close(Status.OK, new Metadata()); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + ArgumentCaptor metadataCaptor = ArgumentCaptor.forClass(Metadata.class); + verify(stream, times(1)).close(statusCaptor.capture(), metadataCaptor.capture()); + assertEquals(Status.Code.INTERNAL, statusCaptor.getValue().getCode()); + assertEquals(ServerCallImpl.MISSING_RESPONSE, statusCaptor.getValue().getDescription()); + assertTrue(metadataCaptor.getValue().keys().isEmpty()); + } + + @Test + public void serverSendsOne_canErrorWithoutResponse() { + final String description = "test description"; + final Status status = Status.RESOURCE_EXHAUSTED.withDescription(description); + final Metadata metadata = new Metadata(); + call.close(status, metadata); + verify(stream, times(1)).close(same(status), same(metadata)); + } + @Test public void isReady() { when(stream.isReady()).thenReturn(true); @@ -260,34 +377,20 @@ public void streamListener_onReady_onlyOnce() { public void streamListener_messageRead() { ServerStreamListenerImpl streamListener = new ServerCallImpl.ServerStreamListenerImpl(call, callListener, context); - streamListener.messageRead(method.streamRequest(1234L)); - - verify(callListener).onMessage(1234L); - } - - @Test - public void streamListener_messageRead_unaryFailsOnMultiple() { - ServerStreamListenerImpl streamListener = - new ServerCallImpl.ServerStreamListenerImpl(call, callListener, context); - streamListener.messageRead(method.streamRequest(1234L)); - streamListener.messageRead(method.streamRequest(1234L)); + streamListener.messageRead(UNARY_METHOD.streamRequest(1234L)); - // Makes sure this was only called once. verify(callListener).onMessage(1234L); - - verify(stream).close(statusCaptor.capture(), Mockito.isA(Metadata.class)); - assertEquals(Status.Code.INTERNAL, statusCaptor.getValue().getCode()); } @Test public void streamListener_messageRead_onlyOnce() { ServerStreamListenerImpl streamListener = new ServerCallImpl.ServerStreamListenerImpl(call, callListener, context); - streamListener.messageRead(method.streamRequest(1234L)); + streamListener.messageRead(UNARY_METHOD.streamRequest(1234L)); // canceling the call should short circuit future halfClosed() calls. streamListener.closed(Status.CANCELLED); - streamListener.messageRead(method.streamRequest(1234L)); + streamListener.messageRead(UNARY_METHOD.streamRequest(1234L)); verify(callListener).onMessage(1234L); } @@ -300,7 +403,7 @@ public void streamListener_unexpectedRuntimeException() { .when(callListener) .onMessage(any(Long.class)); - InputStream inputStream = method.streamRequest(1234L); + InputStream inputStream = UNARY_METHOD.streamRequest(1234L); thrown.expect(RuntimeException.class); thrown.expectMessage("unexpected exception"); diff --git a/examples/src/test/java/io/grpc/examples/routeguide/RouteGuideClientTest.java b/examples/src/test/java/io/grpc/examples/routeguide/RouteGuideClientTest.java index d1fcb8d4fb8..03bb88f0762 100644 --- a/examples/src/test/java/io/grpc/examples/routeguide/RouteGuideClientTest.java +++ b/examples/src/test/java/io/grpc/examples/routeguide/RouteGuideClientTest.java @@ -300,52 +300,6 @@ public void onCompleted() { verify(testHelper, never()).onRpcError(any(Throwable.class)); } - /** - * Example for testing async client-streaming. - */ - @Test - public void recordRoute_wrongResponse() throws Exception { - client.setRandom(noRandomness); - Point point1 = Point.newBuilder().setLatitude(1).setLongitude(1).build(); - final Feature requestFeature1 = - Feature.newBuilder().setLocation(point1).build(); - final List features = Arrays.asList(requestFeature1); - - // implement the fake service - RouteGuideImplBase recordRouteImpl = - new RouteGuideImplBase() { - @Override - public StreamObserver recordRoute(StreamObserver responseObserver) { - RouteSummary response = RouteSummary.getDefaultInstance(); - // sending more than one responses is not right for client-streaming call. - responseObserver.onNext(response); - responseObserver.onNext(response); - responseObserver.onCompleted(); - - return new StreamObserver() { - @Override - public void onNext(Point value) { - } - - @Override - public void onError(Throwable t) { - } - - @Override - public void onCompleted() { - } - }; - } - }; - serviceRegistry.addService(recordRouteImpl); - - client.recordRoute(features, 4); - - ArgumentCaptor errorCaptor = ArgumentCaptor.forClass(Throwable.class); - verify(testHelper).onRpcError(errorCaptor.capture()); - assertEquals(Status.Code.CANCELLED, Status.fromThrowable(errorCaptor.getValue()).getCode()); - } - /** * Example for testing async client-streaming. */ diff --git a/stub/src/main/java/io/grpc/stub/ServerCalls.java b/stub/src/main/java/io/grpc/stub/ServerCalls.java index 38b4ee8bde0..f9a84085ce3 100644 --- a/stub/src/main/java/io/grpc/stub/ServerCalls.java +++ b/stub/src/main/java/io/grpc/stub/ServerCalls.java @@ -18,6 +18,7 @@ import static com.google.common.base.Preconditions.checkNotNull; +import com.google.common.base.Preconditions; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.ServerCall; @@ -30,6 +31,9 @@ */ public final class ServerCalls { + static String TOO_MANY_REQUESTS = "Too many requests"; + static String MISSING_REQUEST = "Half-closed without a request"; + private ServerCalls() { } @@ -112,6 +116,9 @@ private static ServerCallHandler asyncUnaryRequestCal public ServerCall.Listener startCall( final ServerCall call, Metadata headers) { + Preconditions.checkArgument( + call.getMethodDescriptor().getType().clientSendsOneMessage(), + "asyncUnaryRequestCall is only for clientSendsOneMessage methods"); final ServerCallStreamObserverImpl responseObserver = new ServerCallStreamObserverImpl(call); // We expect only 1 request, but we ask for 2 requests here so that if a misbehaving client @@ -119,9 +126,19 @@ public ServerCall.Listener startCall( // inbound flow control has no effect on unary calls. call.request(2); return new EmptyServerCallListener() { + boolean canInvoke = true; ReqT request; @Override public void onMessage(ReqT request) { + if (this.request != null) { + // Safe to close the call, because the application has not yet been invoked + call.close( + Status.INTERNAL.withDescription(TOO_MANY_REQUESTS), + new Metadata()); + canInvoke = false; + return; + } + // We delay calling method.invoke() until onHalfClose() to make sure the client // half-closes. this.request = request; @@ -129,17 +146,23 @@ public void onMessage(ReqT request) { @Override public void onHalfClose() { - if (request != null) { - method.invoke(request, responseObserver); - responseObserver.freeze(); - if (call.isReady()) { - // Since we are calling invoke in halfClose we have missed the onReady - // event from the transport so recover it here. - onReady(); - } - } else { - call.close(Status.INTERNAL.withDescription("Half-closed without a request"), + if (!canInvoke) { + return; + } + if (request == null) { + // Safe to close the call, because the application has not yet been invoked + call.close( + Status.INTERNAL.withDescription(MISSING_REQUEST), new Metadata()); + return; + } + + method.invoke(request, responseObserver); + responseObserver.freeze(); + if (call.isReady()) { + // Since we are calling invoke in halfClose we have missed the onReady + // event from the transport so recover it here. + onReady(); } } diff --git a/stub/src/test/java/io/grpc/stub/ServerCallsTest.java b/stub/src/test/java/io/grpc/stub/ServerCallsTest.java index 6284924ea74..cbcc6d9cb43 100644 --- a/stub/src/test/java/io/grpc/stub/ServerCallsTest.java +++ b/stub/src/test/java/io/grpc/stub/ServerCallsTest.java @@ -62,17 +62,24 @@ public class ServerCallsTest { static final MethodDescriptor STREAMING_METHOD = MethodDescriptor.newBuilder() .setType(MethodDescriptor.MethodType.BIDI_STREAMING) - .setFullMethodName("some/method") + .setFullMethodName("some/bidi_streaming") .setRequestMarshaller(new IntegerMarshaller()) .setResponseMarshaller(new IntegerMarshaller()) .build(); - static final MethodDescriptor UNARY_METHOD = STREAMING_METHOD.toBuilder() - .setType(MethodDescriptor.MethodType.UNARY) - .setFullMethodName("some/unarymethod") - .build(); + static final MethodDescriptor SERVER_STREAMING_METHOD = + STREAMING_METHOD.toBuilder() + .setType(MethodDescriptor.MethodType.SERVER_STREAMING) + .setFullMethodName("some/client_streaming") + .build(); + + static final MethodDescriptor UNARY_METHOD = + STREAMING_METHOD.toBuilder() + .setType(MethodDescriptor.MethodType.UNARY) + .setFullMethodName("some/unary") + .build(); - private final ServerCallRecorder serverCall = new ServerCallRecorder(); + private final ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); @Test public void runtimeStreamObserverIsServerCallStreamObserver() throws Exception { @@ -284,6 +291,85 @@ public void run() { assertEquals(2, onReadyCalled.get()); } + @Test + public void clientSendsOne_errorMissingRequest_unary() { + ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); + ServerCallHandler callHandler = + ServerCalls.asyncUnaryCall( + new ServerCalls.UnaryMethod() { + @Override + public void invoke(Integer req, StreamObserver responseObserver) { + fail("should not be reached"); + } + }); + ServerCall.Listener listener = callHandler.startCall(serverCall, new Metadata()); + listener.onHalfClose(); + assertThat(serverCall.responses).isEmpty(); + assertEquals(Status.Code.INTERNAL, serverCall.status.getCode()); + assertEquals(ServerCalls.MISSING_REQUEST, serverCall.status.getDescription()); + } + + @Test + public void clientSendsOne_errorMissingRequest_serverStreaming() { + ServerCallRecorder serverCall = new ServerCallRecorder(SERVER_STREAMING_METHOD); + ServerCallHandler callHandler = + ServerCalls.asyncServerStreamingCall( + new ServerCalls.ServerStreamingMethod() { + @Override + public void invoke(Integer req, StreamObserver responseObserver) { + fail("should not be reached"); + } + }); + ServerCall.Listener listener = callHandler.startCall(serverCall, new Metadata()); + listener.onHalfClose(); + assertThat(serverCall.responses).isEmpty(); + assertEquals(Status.Code.INTERNAL, serverCall.status.getCode()); + assertEquals(ServerCalls.MISSING_REQUEST, serverCall.status.getDescription()); + + } + + @Test + public void clientSendsOne_errorTooManyRequests_unary() { + ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); + ServerCallHandler callHandler = + ServerCalls.asyncUnaryCall( + new ServerCalls.UnaryMethod() { + @Override + public void invoke(Integer req, StreamObserver responseObserver) { + fail("should not be reached"); + } + }); + ServerCall.Listener listener = callHandler.startCall(serverCall, new Metadata()); + listener.onMessage(1); + listener.onMessage(1); + assertThat(serverCall.responses).isEmpty(); + assertEquals(Status.Code.INTERNAL, serverCall.status.getCode()); + assertEquals(ServerCalls.TOO_MANY_REQUESTS, serverCall.status.getDescription()); + // ensure onHalfClose does not invoke + listener.onHalfClose(); + } + + @Test + public void clientSendsOne_errorTooManyRequests_serverStreaming() { + ServerCallRecorder serverCall = new ServerCallRecorder(SERVER_STREAMING_METHOD); + ServerCallHandler callHandler = + ServerCalls.asyncServerStreamingCall( + new ServerCalls.ServerStreamingMethod() { + @Override + public void invoke(Integer req, StreamObserver responseObserver) { + fail("should not be reached"); + } + }); + ServerCall.Listener listener = callHandler.startCall(serverCall, new Metadata()); + listener.onMessage(1); + listener.onMessage(1); + assertThat(serverCall.responses).isEmpty(); + assertEquals(Status.Code.INTERNAL, serverCall.status.getCode()); + assertEquals(ServerCalls.TOO_MANY_REQUESTS, serverCall.status.getDescription()); + // ensure onHalfClose does not invoke + listener.onHalfClose(); + } + @Test public void inprocessTransportManualFlow() throws Exception { final Semaphore semaphore = new Semaphore(1); @@ -374,15 +460,19 @@ public Integer parse(InputStream stream) { } private static class ServerCallRecorder extends ServerCall { - private List requestCalls = new ArrayList(); + private final MethodDescriptor methodDescriptor; + private final List requestCalls = new ArrayList(); + private final List responses = new ArrayList(); private Metadata headers; - private Integer message; private Metadata trailers; private Status status; private boolean isCancelled; - private MethodDescriptor methodDescriptor; private boolean isReady; + public ServerCallRecorder(MethodDescriptor methodDescriptor) { + this.methodDescriptor = methodDescriptor; + } + @Override public void request(int numMessages) { requestCalls.add(numMessages); @@ -395,7 +485,7 @@ public void sendHeaders(Metadata headers) { @Override public void sendMessage(Integer message) { - this.message = message; + this.responses.add(message); } @Override