Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stub: detect invalid states on server side (eg zero responses for unary) #3068

Merged
40 changes: 29 additions & 11 deletions core/src/main/java/io/grpc/internal/ServerCallImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,19 @@
import io.grpc.InternalDecompressorRegistry;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.MethodType;
import io.grpc.ServerCall;
import io.grpc.Status;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;

final class ServerCallImpl<ReqT, RespT> extends ServerCall<ReqT, RespT> {

@VisibleForTesting
static String TOO_MANY_RESPONSES = "Too many responses";
@VisibleForTesting
static String MISSING_RESPONSE = "Completed without a response";

private final ServerStream stream;
private final MethodDescriptor<ReqT, RespT> method;
private final Context.CancellableContext context;
Expand All @@ -54,6 +59,7 @@ final class ServerCallImpl<ReqT, RespT> extends ServerCall<ReqT, RespT> {
private boolean sendHeadersCalled;
private boolean closeCalled;
private Compressor compressor;
private boolean messageSent;

ServerCallImpl(ServerStream stream, MethodDescriptor<ReqT, RespT> method,
Metadata inboundHeaders, Context.CancellableContext context,
Expand Down Expand Up @@ -115,6 +121,13 @@ public void sendHeaders(Metadata headers) {
public void sendMessage(RespT message) {
checkState(sendHeadersCalled, "sendHeaders has not been called");
checkState(!closeCalled, "call is closed");

if (method.getType().serverSendsOneMessage() && messageSent) {
internalClose(Status.INTERNAL.withDescription(TOO_MANY_RESPONSES));
return;
}

messageSent = true;
try {
InputStream resp = method.streamResponse(message);
stream.writeMessage(resp);
Expand Down Expand Up @@ -151,6 +164,12 @@ public boolean isReady() {
public void close(Status status, Metadata trailers) {
checkState(!closeCalled, "call already closed");
closeCalled = true;

if (status.isOk() && method.getType().serverSendsOneMessage() && !messageSent) {
internalClose(Status.INTERNAL.withDescription(MISSING_RESPONSE));
return;
}

stream.close(status, trailers);
}

Expand Down Expand Up @@ -178,6 +197,15 @@ public MethodDescriptor<ReqT, RespT> getMethodDescriptor() {
return method;
}

/**
* Close the {@link ServerStream} because an internal error occurred. Allow the application to
* run until completion, but silently ignore interactions with the {@link ServerStream} from now
* on.
*/
private void internalClose(Status internalError) {
stream.close(internalError, new Metadata());
}

/**
* All of these callbacks are assumed to called on an application thread, and the caller is
* responsible for handling thrown exceptions.
Expand All @@ -187,7 +215,6 @@ static final class ServerStreamListenerImpl<ReqT> implements ServerStreamListene
private final ServerCallImpl<ReqT, ?> call;
private final ServerCall.Listener<ReqT> listener;
private final Context.CancellableContext context;
private boolean messageReceived;

public ServerStreamListenerImpl(
ServerCallImpl<ReqT, ?> call, ServerCall.Listener<ReqT> listener,
Expand Down Expand Up @@ -216,15 +243,6 @@ public void messageRead(final InputStream message) {
if (call.cancelled) {
return;
}
// Special case for unary calls.
if (messageReceived && call.method.getType() == MethodType.UNARY) {
call.stream.close(Status.INTERNAL.withDescription(
"More than one request messages for unary call or server streaming call"),
new Metadata());
return;
}
messageReceived = true;

listener.onMessage(call.method.parseRequest(message));
} catch (Throwable e) {
t = e;
Expand Down
4 changes: 4 additions & 0 deletions core/src/main/java/io/grpc/internal/ServerStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ public interface ServerStream extends Stream {
* {@link io.grpc.Status.Code#OK} implies normal termination of the
* stream. Any other value implies abnormal termination.
*
* <p>Attempts to read from or write to the stream after closing
* should be ignored by implementations, and should not throw
* exceptions.
*
* @param status details of the closure
* @param trailers an additional block of metadata to pass to the client on stream closure.
*/
Expand Down
159 changes: 131 additions & 28 deletions core/src/test/java/io/grpc/internal/ServerCallImplTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.isA;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

Expand All @@ -48,35 +51,41 @@
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;

@RunWith(JUnit4.class)
public class ServerCallImplTest {
@Rule public final ExpectedException thrown = ExpectedException.none();
@Mock private ServerStream stream;
@Mock private ServerCall.Listener<Long> callListener;
@Captor private ArgumentCaptor<Status> statusCaptor;

private ServerCallImpl<Long, Long> call;
private Context.CancellableContext context;

private final MethodDescriptor<Long, Long> method = MethodDescriptor.<Long, Long>newBuilder()
.setType(MethodType.UNARY)
.setFullMethodName("/service/method")
.setRequestMarshaller(new LongMarshaller())
.setResponseMarshaller(new LongMarshaller())
.build();
private static final MethodDescriptor<Long, Long> UNARY_METHOD =
MethodDescriptor.<Long, Long>newBuilder()
.setType(MethodType.UNARY)
.setFullMethodName("/service/method")
.setRequestMarshaller(new LongMarshaller())
.setResponseMarshaller(new LongMarshaller())
.build();

private static final MethodDescriptor<Long, Long> CLIENT_STREAMING_METHOD =
MethodDescriptor.<Long, Long>newBuilder()
.setType(MethodType.UNARY)
.setFullMethodName("/service/method")
.setRequestMarshaller(new LongMarshaller())
.setResponseMarshaller(new LongMarshaller())
.build();

private final Metadata requestHeaders = new Metadata();

@Before
public void setUp() {
MockitoAnnotations.initMocks(this);
context = Context.ROOT.withCancellation();
call = new ServerCallImpl<Long, Long>(stream, method, requestHeaders, context,
call = new ServerCallImpl<Long, Long>(stream, UNARY_METHOD, requestHeaders, context,
DecompressorRegistry.getDefaultInstance(), CompressorRegistry.getDefaultInstance());
}

Expand Down Expand Up @@ -158,6 +167,114 @@ public void sendMessage_closesOnFailure() {
verify(stream).close(isA(Status.class), isA(Metadata.class));
}

@Test
public void sendMessage_serverSendsOne_closeOnSecondCall_unary() {
sendMessage_serverSendsOne_closeOnSecondCall(UNARY_METHOD);
}

@Test
public void sendMessage_serverSendsOne_closeOnSecondCall_clientStreaming() {
sendMessage_serverSendsOne_closeOnSecondCall(CLIENT_STREAMING_METHOD);
}

private void sendMessage_serverSendsOne_closeOnSecondCall(
MethodDescriptor<Long, Long> method) {
ServerCallImpl<Long, Long> serverCall = new ServerCallImpl<Long, Long>(
stream,
method,
requestHeaders,
context,
DecompressorRegistry.getDefaultInstance(),
CompressorRegistry.getDefaultInstance());
serverCall.sendHeaders(new Metadata());
serverCall.sendMessage(1L);
verify(stream, times(1)).writeMessage(any(InputStream.class));
verify(stream, never()).close(any(Status.class), any(Metadata.class));

// trying to send a second message causes gRPC to close the underlying stream
serverCall.sendMessage(1L);
verify(stream, times(1)).writeMessage(any(InputStream.class));
ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(Status.class);
ArgumentCaptor<Metadata> metadataCaptor = ArgumentCaptor.forClass(Metadata.class);
verify(stream, times(1)).close(statusCaptor.capture(), metadataCaptor.capture());
assertEquals(Status.Code.INTERNAL, statusCaptor.getValue().getCode());
assertEquals(ServerCallImpl.TOO_MANY_RESPONSES, statusCaptor.getValue().getDescription());
assertTrue(metadataCaptor.getValue().keys().isEmpty());
}

@Test
public void sendMessage_serverSendsOne_closeOnSecondCall_appRunToCompletion_unary() {
sendMessage_serverSendsOne_closeOnSecondCall_appRunToCompletion(UNARY_METHOD);
}

@Test
public void sendMessage_serverSendsOne_closeOnSecondCall_appRunToCompletion_clientStreaming() {
sendMessage_serverSendsOne_closeOnSecondCall_appRunToCompletion(CLIENT_STREAMING_METHOD);
}

private void sendMessage_serverSendsOne_closeOnSecondCall_appRunToCompletion(
MethodDescriptor<Long, Long> method) {
ServerCallImpl<Long, Long> serverCall = new ServerCallImpl<Long, Long>(
stream,
method,
requestHeaders,
context,
DecompressorRegistry.getDefaultInstance(),
CompressorRegistry.getDefaultInstance());
serverCall.sendHeaders(new Metadata());
serverCall.sendMessage(1L);
serverCall.sendMessage(1L);
verify(stream, times(1)).writeMessage(any(InputStream.class));
verify(stream, times(1)).close(any(Status.class), any(Metadata.class));

// App runs to completion but everything is ignored
serverCall.sendMessage(1L);
serverCall.close(Status.OK, new Metadata());
try {
serverCall.close(Status.OK, new Metadata());
fail("calling a second time should still cause an error");
} catch (IllegalStateException expected) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of the try-catch, use thrown.expect(IllegalStateException.class) just before calling close the second time?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

// noop
}
}

@Test
public void serverSendsOne_okFailsOnMissingResponse_unary() {
serverSendsOne_okFailsOnMissingResponse(UNARY_METHOD);
}

@Test
public void serverSendsOne_okFailsOnMissingResponse_clientStreaming() {
serverSendsOne_okFailsOnMissingResponse(CLIENT_STREAMING_METHOD);
}

private void serverSendsOne_okFailsOnMissingResponse(
MethodDescriptor<Long, Long> method) {
ServerCallImpl<Long, Long> serverCall = new ServerCallImpl<Long, Long>(
stream,
method,
requestHeaders,
context,
DecompressorRegistry.getDefaultInstance(),
CompressorRegistry.getDefaultInstance());
serverCall.close(Status.OK, new Metadata());
ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(Status.class);
ArgumentCaptor<Metadata> metadataCaptor = ArgumentCaptor.forClass(Metadata.class);
verify(stream, times(1)).close(statusCaptor.capture(), metadataCaptor.capture());
assertEquals(Status.Code.INTERNAL, statusCaptor.getValue().getCode());
assertEquals(ServerCallImpl.MISSING_RESPONSE, statusCaptor.getValue().getDescription());
assertTrue(metadataCaptor.getValue().keys().isEmpty());
}

@Test
public void serverSendsOne_canErrorWithoutResponse() {
final String description = "test description";
final Status status = Status.RESOURCE_EXHAUSTED.withDescription(description);
final Metadata metadata = new Metadata();
call.close(status, metadata);
verify(stream, times(1)).close(same(status), same(metadata));
}

@Test
public void isReady() {
when(stream.isReady()).thenReturn(true);
Expand Down Expand Up @@ -260,34 +377,20 @@ public void streamListener_onReady_onlyOnce() {
public void streamListener_messageRead() {
ServerStreamListenerImpl<Long> streamListener =
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, context);
streamListener.messageRead(method.streamRequest(1234L));

verify(callListener).onMessage(1234L);
}

@Test
public void streamListener_messageRead_unaryFailsOnMultiple() {
ServerStreamListenerImpl<Long> streamListener =
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, context);
streamListener.messageRead(method.streamRequest(1234L));
streamListener.messageRead(method.streamRequest(1234L));
streamListener.messageRead(UNARY_METHOD.streamRequest(1234L));

// Makes sure this was only called once.
verify(callListener).onMessage(1234L);

verify(stream).close(statusCaptor.capture(), Mockito.isA(Metadata.class));
assertEquals(Status.Code.INTERNAL, statusCaptor.getValue().getCode());
}

@Test
public void streamListener_messageRead_onlyOnce() {
ServerStreamListenerImpl<Long> streamListener =
new ServerCallImpl.ServerStreamListenerImpl<Long>(call, callListener, context);
streamListener.messageRead(method.streamRequest(1234L));
streamListener.messageRead(UNARY_METHOD.streamRequest(1234L));
// canceling the call should short circuit future halfClosed() calls.
streamListener.closed(Status.CANCELLED);

streamListener.messageRead(method.streamRequest(1234L));
streamListener.messageRead(UNARY_METHOD.streamRequest(1234L));

verify(callListener).onMessage(1234L);
}
Expand All @@ -300,7 +403,7 @@ public void streamListener_unexpectedRuntimeException() {
.when(callListener)
.onMessage(any(Long.class));

InputStream inputStream = method.streamRequest(1234L);
InputStream inputStream = UNARY_METHOD.streamRequest(1234L);

thrown.expect(RuntimeException.class);
thrown.expectMessage("unexpected exception");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,52 +300,6 @@ public void onCompleted() {
verify(testHelper, never()).onRpcError(any(Throwable.class));
}

/**
* Example for testing async client-streaming.
*/
@Test
public void recordRoute_wrongResponse() throws Exception {
client.setRandom(noRandomness);
Point point1 = Point.newBuilder().setLatitude(1).setLongitude(1).build();
final Feature requestFeature1 =
Feature.newBuilder().setLocation(point1).build();
final List<Feature> features = Arrays.asList(requestFeature1);

// implement the fake service
RouteGuideImplBase recordRouteImpl =
new RouteGuideImplBase() {
@Override
public StreamObserver<Point> recordRoute(StreamObserver<RouteSummary> responseObserver) {
RouteSummary response = RouteSummary.getDefaultInstance();
// sending more than one responses is not right for client-streaming call.
responseObserver.onNext(response);
responseObserver.onNext(response);
responseObserver.onCompleted();

return new StreamObserver<Point>() {
@Override
public void onNext(Point value) {
}

@Override
public void onError(Throwable t) {
}

@Override
public void onCompleted() {
}
};
}
};
serviceRegistry.addService(recordRouteImpl);

client.recordRoute(features, 4);

ArgumentCaptor<Throwable> errorCaptor = ArgumentCaptor.forClass(Throwable.class);
verify(testHelper).onRpcError(errorCaptor.capture());
assertEquals(Status.Code.CANCELLED, Status.fromThrowable(errorCaptor.getValue()).getCode());
}

/**
* Example for testing async client-streaming.
*/
Expand Down
Loading