Skip to content

Commit

Permalink
netty: Allow deframer errors to close stream with a status code
Browse files Browse the repository at this point in the history
Today, deframer errors cancel the stream without communicating a status code
to the peer. This change causes deframer errors to trigger a best-effort
attempt to send trailers with a status code so that the peer understands
why the stream is being closed.

Fixes #3996
  • Loading branch information
ryanpbrewster authored and ejona86 committed Apr 24, 2024
1 parent 11612b4 commit e036b1b
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 10 deletions.
26 changes: 25 additions & 1 deletion netty/src/main/java/io/grpc/netty/CancelServerStreamCommand.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,23 @@
final class CancelServerStreamCommand extends WriteQueue.AbstractQueuedCommand {
private final NettyServerStream.TransportState stream;
private final Status reason;
private final PeerNotify peerNotify;

CancelServerStreamCommand(NettyServerStream.TransportState stream, Status reason) {
private CancelServerStreamCommand(
NettyServerStream.TransportState stream, Status reason, PeerNotify peerNotify) {
this.stream = Preconditions.checkNotNull(stream, "stream");
this.reason = Preconditions.checkNotNull(reason, "reason");
this.peerNotify = Preconditions.checkNotNull(peerNotify, "peerNotify");
}

static CancelServerStreamCommand withReset(
NettyServerStream.TransportState stream, Status reason) {
return new CancelServerStreamCommand(stream, reason, PeerNotify.RESET);
}

static CancelServerStreamCommand withReason(
NettyServerStream.TransportState stream, Status reason) {
return new CancelServerStreamCommand(stream, reason, PeerNotify.BEST_EFFORT_STATUS);
}

NettyServerStream.TransportState stream() {
Expand All @@ -41,6 +54,10 @@ Status reason() {
return reason;
}

boolean wantsHeaders() {
return peerNotify == PeerNotify.BEST_EFFORT_STATUS;
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand Down Expand Up @@ -68,4 +85,11 @@ public String toString() {
.add("reason", reason)
.toString();
}

private enum PeerNotify {
/** Notify the peer by sending a RST_STREAM with no other information. */
RESET,
/** Notify the peer about the {@link #reason} by sending structured headers, if possible. */
BEST_EFFORT_STATUS,
}
}
32 changes: 30 additions & 2 deletions netty/src/main/java/io/grpc/netty/NettyServerHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -788,9 +788,37 @@ private void cancelStream(ChannelHandlerContext ctx, CancelServerStreamCommand c
PerfMark.linkIn(cmd.getLink());
// Notify the listener if we haven't already.
cmd.stream().transportReportStatus(cmd.reason());
// Terminate the stream.
encoder().writeRstStream(ctx, cmd.stream().id(), Http2Error.CANCEL.code(), promise);

// Now we need to decide how we're going to notify the peer that this stream is closed.
// If possible, it's nice to inform the peer _why_ this stream was cancelled by sending
// a structured headers frame.
if (shouldCloseStreamWithHeaders(cmd, connection())) {
Metadata md = new Metadata();
md.put(InternalStatus.CODE_KEY, cmd.reason());
if (cmd.reason().getDescription() != null) {
md.put(InternalStatus.MESSAGE_KEY, cmd.reason().getDescription());
}
Http2Headers headers = Utils.convertServerHeaders(md);
encoder().writeHeaders(
ctx, cmd.stream().id(), headers, /* padding = */ 0, /* endStream = */ true, promise);
} else {
// Terminate the stream.
encoder().writeRstStream(ctx, cmd.stream().id(), Http2Error.CANCEL.code(), promise);
}
}
}

// Determine whether a CancelServerStreamCommand should try to close the stream with a
// HEADERS or a RST_STREAM frame. The caller has some influence over this (they can
// configure cmd.wantsHeaders()). The state of the stream also has an influence: we
// only try to send HEADERS if the stream exists and hasn't already sent any headers.
private static boolean shouldCloseStreamWithHeaders(
CancelServerStreamCommand cmd, Http2Connection conn) {
if (!cmd.wantsHeaders()) {
return false;
}
Http2Stream stream = conn.stream(cmd.stream().id());
return stream != null && !stream.isHeadersSent();
}

private void gracefulClose(final ChannelHandlerContext ctx, final GracefulServerCloseCommand msg,
Expand Down
6 changes: 3 additions & 3 deletions netty/src/main/java/io/grpc/netty/NettyServerStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ public void writeTrailers(Metadata trailers, boolean headersSent, Status status)
@Override
public void cancel(Status status) {
try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.cancel")) {
writeQueue.enqueue(new CancelServerStreamCommand(transportState(), status), true);
writeQueue.enqueue(CancelServerStreamCommand.withReset(transportState(), status), true);
}
}
}
Expand Down Expand Up @@ -189,7 +189,7 @@ public void deframeFailed(Throwable cause) {
log.log(Level.WARNING, "Exception processing message", cause);
Status status = Status.fromThrowable(cause);
transportReportStatus(status);
handler.getWriteQueue().enqueue(new CancelServerStreamCommand(this, status), true);
handler.getWriteQueue().enqueue(CancelServerStreamCommand.withReason(this, status), true);
}

private void onWriteFrameData(ChannelFuture future, int numMessages, int numBytes) {
Expand Down Expand Up @@ -222,7 +222,7 @@ private void handleWriteFutureFailures(ChannelFuture future) {
*/
protected void http2ProcessingFailed(Status status) {
transportReportStatus(status);
handler.getWriteQueue().enqueue(new CancelServerStreamCommand(this, status), true);
handler.getWriteQueue().enqueue(CancelServerStreamCommand.withReset(this, status), true);
}

void inboundDataReceived(ByteBuf frame, boolean endOfStream) {
Expand Down
34 changes: 33 additions & 1 deletion netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,10 @@
import java.io.InputStream;
import java.nio.channels.ClosedChannelException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
Expand Down Expand Up @@ -469,11 +471,41 @@ public void connectionWindowShouldBeOverridden() throws Exception {
public void cancelShouldSendRstStream() throws Exception {
manualSetUp();
createStream();
enqueue(new CancelServerStreamCommand(stream.transportState(), Status.DEADLINE_EXCEEDED));
enqueue(CancelServerStreamCommand.withReset(stream.transportState(), Status.DEADLINE_EXCEEDED));
verifyWrite().writeRstStream(eq(ctx()), eq(stream.transportState().id()),
eq(Http2Error.CANCEL.code()), any(ChannelPromise.class));
}

@Test
public void cancelWithNotify_shouldSendHeaders() throws Exception {
manualSetUp();
createStream();

enqueue(CancelServerStreamCommand.withReason(
stream.transportState(),
Status.RESOURCE_EXHAUSTED.withDescription("my custom description")
));

ArgumentCaptor<Http2Headers> captor = ArgumentCaptor.forClass(Http2Headers.class);
verifyWrite()
.writeHeaders(
eq(ctx()),
eq(STREAM_ID),
captor.capture(),
eq(0),
eq(true),
any(ChannelPromise.class));

// For arcane reasons, the specific implementation of Http2Headers here doesn't actually support
// methods like `get(...)`, so we have to manually convert it into a map.
Map<String, String> actualHeaders = new HashMap<>();
for (Map.Entry<CharSequence, CharSequence> entry : captor.getValue()) {
actualHeaders.put(entry.getKey().toString(), entry.getValue().toString());
}
assertEquals("8", actualHeaders.get(InternalStatus.CODE_KEY.name()));
assertEquals("my custom description", actualHeaders.get(InternalStatus.MESSAGE_KEY.name()));
}

@Test
public void headersWithInvalidContentTypeShouldFail() throws Exception {
manualSetUp();
Expand Down
29 changes: 26 additions & 3 deletions netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth.assertWithMessage;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static io.grpc.netty.NettyTestUtil.messageFrame;
import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Exception.connectionError;
Expand All @@ -37,6 +36,7 @@
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

import com.google.common.base.Strings;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ListMultimap;
import io.grpc.Attributes;
Expand Down Expand Up @@ -73,6 +73,8 @@
/** Unit tests for {@link NettyServerStream}. */
@RunWith(JUnit4.class)
public class NettyServerStreamTest extends NettyStreamTestBase<NettyServerStream> {
private static final int TEST_MAX_MESSAGE_SIZE = 128;

@Mock
protected ServerStreamListener serverListener;

Expand Down Expand Up @@ -380,18 +382,39 @@ public void emptyFramerShouldSendNoPayload() {
public void cancelStreamShouldSucceed() {
stream().cancel(Status.DEADLINE_EXCEEDED);
verify(writeQueue).enqueue(
new CancelServerStreamCommand(stream().transportState(), Status.DEADLINE_EXCEEDED),
CancelServerStreamCommand.withReset(stream().transportState(), Status.DEADLINE_EXCEEDED),
true);
}

@Test
public void oversizedMessagesResultInResourceExhaustedTrailers() throws Exception {
@SuppressWarnings("InlineMeInliner") // Requires Java 11
String oversizedMsg = Strings.repeat("a", TEST_MAX_MESSAGE_SIZE + 1);
stream.request(1);
stream.transportState().inboundDataReceived(messageFrame(oversizedMsg), false);
assertNull("message should have caused a deframer error", listenerMessageQueue().poll());

ArgumentCaptor<CancelServerStreamCommand> cancelCmdCap =
ArgumentCaptor.forClass(CancelServerStreamCommand.class);
verify(writeQueue).enqueue(cancelCmdCap.capture(), eq(true));

Status status = Status.RESOURCE_EXHAUSTED
.withDescription("gRPC message exceeds maximum size 128: 129");

CancelServerStreamCommand actualCmd = cancelCmdCap.getValue();
assertThat(actualCmd.reason().getCode()).isEqualTo(status.getCode());
assertThat(actualCmd.reason().getDescription()).isEqualTo(status.getDescription());
assertThat(actualCmd.wantsHeaders()).isTrue();
}

@Override
@SuppressWarnings("DirectInvocationOnMock")
protected NettyServerStream createStream() {
when(handler.getWriteQueue()).thenReturn(writeQueue);
StatsTraceContext statsTraceCtx = StatsTraceContext.NOOP;
TransportTracer transportTracer = new TransportTracer();
NettyServerStream.TransportState state = new NettyServerStream.TransportState(
handler, channel.eventLoop(), http2Stream, DEFAULT_MAX_MESSAGE_SIZE, statsTraceCtx,
handler, channel.eventLoop(), http2Stream, TEST_MAX_MESSAGE_SIZE, statsTraceCtx,
transportTracer, "method");
NettyServerStream stream = new NettyServerStream(channel, state, Attributes.EMPTY,
"test-authority", statsTraceCtx);
Expand Down

0 comments on commit e036b1b

Please sign in to comment.