Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

netty: Handle write queue promise failures #11016

Merged
merged 16 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ public void inboundDataReceived(ReadableBuffer frame, boolean endOfStream) {
*/
public final void transportReportStatus(final Status status) {
Preconditions.checkArgument(!status.isOk(), "status must not be OK");
onStreamDeallocated();
if (deframerClosed) {
deframerClosedTask = null;
closeListener(status);
Expand All @@ -300,6 +301,7 @@ public void run() {
* #transportReportStatus}.
*/
public void complete() {
onStreamDeallocated();
if (deframerClosed) {
deframerClosedTask = null;
closeListener(Status.OK);
Expand Down Expand Up @@ -335,7 +337,6 @@ private void closeListener(Status newStatus) {
getTransportTracer().reportStreamClosed(closedStatus.isOk());
}
listenerClosed = true;
onStreamDeallocated();
listener().closed(newStatus);
}
}
Expand Down
6 changes: 6 additions & 0 deletions core/src/main/java/io/grpc/internal/AbstractStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,12 @@ protected final void onStreamDeallocated() {
}
}

protected boolean isStreamDeallocated() {
synchronized (onReadyLock) {
return deallocated;
}
}

/**
* Event handler to be called by the subclass when a number of bytes are being queued for
* sending to the remote endpoint.
Expand Down
39 changes: 26 additions & 13 deletions netty/src/main/java/io/grpc/netty/NettyClientStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -182,20 +182,10 @@
if (numBytes > 0) {
// Add the bytes to outbound flow control.
onSendingBytes(numBytes);
ChannelFutureListener failureListener =
future -> transportState().onWriteFrameData(future, numMessages, numBytes);
writeQueue.enqueue(new SendGrpcFrameCommand(transportState(), bytebuf, endOfStream), flush)
.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
// If the future succeeds when http2stream is null, the stream has been cancelled
// before it began and Netty is purging pending writes from the flow-controller.
if (future.isSuccess() && transportState().http2Stream() != null) {
// Remove the bytes from outbound flow control, optionally notifying
// the client that they can send more bytes.
transportState().onSentBytes(numBytes);
NettyClientStream.this.getTransportTracer().reportMessageSent(numMessages);
}
}
});
.addListener(failureListener);
} else {
// The frame is empty and will not impact outbound flow control. Just send it.
writeQueue.enqueue(
Expand Down Expand Up @@ -306,6 +296,29 @@
handler.getWriteQueue().enqueue(new CancelClientStreamCommand(this, status), true);
}

private void onWriteFrameData(ChannelFuture future, int numMessages, int numBytes) {
// If the future succeeds when http2stream is null, the stream has been cancelled
// before it began and Netty is purging pending writes from the flow-controller.
if (future.isSuccess() && http2Stream() == null) {
return;
}

Check warning on line 304 in netty/src/main/java/io/grpc/netty/NettyClientStream.java

View check run for this annotation

Codecov / codecov/patch

netty/src/main/java/io/grpc/netty/NettyClientStream.java#L304

Added line #L304 was not covered by tests

if (future.isSuccess()) {
// Remove the bytes from outbound flow control, optionally notifying
// the client that they can send more bytes.
onSentBytes(numBytes);
getTransportTracer().reportMessageSent(numMessages);
} else if (!isStreamDeallocated()) {
// Future failed, fail RPC.
// Normally we don't need to do anything here because the cause of a failed future
// while writing DATA frames would be an IO error and the stream is already closed.
// However, we still need handle any unexpected failures raised in Netty.
// Note: isStreamDeallocated() protects from spamming stream resets by scheduling multiple
// CancelClientStreamCommand commands.
http2ProcessingFailed(statusFromFailedFuture(future), true, new Metadata());
}
}

@Override
public void runOnTransportThread(final Runnable r) {
if (eventLoop.inEventLoop()) {
Expand Down
3 changes: 1 addition & 2 deletions netty/src/main/java/io/grpc/netty/NettyServerHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -502,8 +502,7 @@ private void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers
state,
attributes,
authority,
statsTraceCtx,
transportTracer);
statsTraceCtx);
transportListener.streamCreated(stream, method, metadata);
state.onStreamAllocated();
http2Stream.setProperty(streamKey, state);
Expand Down
86 changes: 53 additions & 33 deletions netty/src/main/java/io/grpc/netty/NettyServerStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,19 @@ class NettyServerStream extends AbstractServerStream {
private final WriteQueue writeQueue;
private final Attributes attributes;
private final String authority;
private final TransportTracer transportTracer;
private final int streamId;

public NettyServerStream(
Channel channel,
TransportState state,
Attributes transportAttrs,
String authority,
StatsTraceContext statsTraceCtx,
TransportTracer transportTracer) {
StatsTraceContext statsTraceCtx) {
super(new NettyWritableBufferAllocator(channel.alloc()), statsTraceCtx);
this.state = checkNotNull(state, "transportState");
this.writeQueue = state.handler.getWriteQueue();
this.attributes = checkNotNull(transportAttrs);
this.authority = authority;
this.transportTracer = checkNotNull(transportTracer, "transportTracer");
// Read the id early to avoid reading transportState later.
this.streamId = transportState().id();
}
Expand Down Expand Up @@ -96,48 +93,37 @@ private class Sink implements AbstractServerStream.Sink {
@Override
public void writeHeaders(Metadata headers, boolean flush) {
try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.writeHeaders")) {
writeQueue.enqueue(
SendResponseHeadersCommand.createHeaders(
transportState(),
Utils.convertServerHeaders(headers)),
flush);
Http2Headers http2headers = Utils.convertServerHeaders(headers);
SendResponseHeadersCommand headersCommand =
SendResponseHeadersCommand.createHeaders(transportState(), http2headers);
writeQueue.enqueue(headersCommand, true)
sergiitk marked this conversation as resolved.
Show resolved Hide resolved
.addListener((ChannelFutureListener) transportState()::handleWriteFutureFailures);
}
}

private void writeFrameInternal(WritableBuffer frame, boolean flush, final int numMessages) {
Preconditions.checkArgument(numMessages >= 0);
ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf().touch();
final int numBytes = bytebuf.readableBytes();
// Add the bytes to outbound flow control.
onSendingBytes(numBytes);
writeQueue.enqueue(new SendGrpcFrameCommand(transportState(), bytebuf, false), flush)
.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
// Remove the bytes from outbound flow control, optionally notifying
// the client that they can send more bytes.
transportState().onSentBytes(numBytes);
if (future.isSuccess()) {
transportTracer.reportMessageSent(numMessages);
}
}
});
}

@Override
public void writeFrame(WritableBuffer frame, boolean flush, final int numMessages) {
try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.writeFrame")) {
writeFrameInternal(frame, flush, numMessages);
Preconditions.checkArgument(numMessages >= 0);
ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf().touch();
final int numBytes = bytebuf.readableBytes();
// Add the bytes to outbound flow control.
onSendingBytes(numBytes);
ChannelFutureListener failureListener =
future -> transportState().onWriteFrameData(future, numMessages, numBytes);
writeQueue.enqueue(new SendGrpcFrameCommand(transportState(), bytebuf, false), flush)
.addListener(failureListener);
}
}

@Override
public void writeTrailers(Metadata trailers, boolean headersSent, Status status) {
try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.writeTrailers")) {
Http2Headers http2Trailers = Utils.convertTrailers(trailers, headersSent);
writeQueue.enqueue(
SendResponseHeadersCommand.createTrailers(transportState(), http2Trailers, status),
true);
SendResponseHeadersCommand trailersCommand =
SendResponseHeadersCommand.createTrailers(transportState(), http2Trailers, status);
writeQueue.enqueue(trailersCommand, true)
.addListener((ChannelFutureListener) transportState()::handleWriteFutureFailures);
}
}

Expand Down Expand Up @@ -206,6 +192,40 @@ public void deframeFailed(Throwable cause) {
handler.getWriteQueue().enqueue(new CancelServerStreamCommand(this, status), true);
}

private void onWriteFrameData(ChannelFuture future, int numMessages, int numBytes) {
// Remove the bytes from outbound flow control, optionally notifying
// the client that they can send more bytes.
// TODO(sergiitk): should onSentBytes be called only on success?
sergiitk marked this conversation as resolved.
Show resolved Hide resolved
onSentBytes(numBytes);
if (future.isSuccess()) {
getTransportTracer().reportMessageSent(numMessages);
} else {
handleWriteFutureFailures(future);
}
}

private void handleWriteFutureFailures(ChannelFuture future) {
// isStreamDeallocated() check protects from spamming stream resets by scheduling multiple
// CancelServerStreamCommand commands.
if (future.isSuccess() || isStreamDeallocated()) {
return;
}

// Future failed, fail RPC.
// Normally we don't need to do anything here because the cause of a failed future
// while writing DATA frames would be an IO error and the stream is already closed.
// However, we still need handle any unexpected failures raised in Netty.
http2ProcessingFailed(Utils.statusFromThrowable(future.cause()));
}

/**
* Called to process a failure in HTTP/2 processing.
*/
protected void http2ProcessingFailed(Status status) {
transportReportStatus(status);
handler.getWriteQueue().enqueue(new CancelServerStreamCommand(this, status), true);
}

void inboundDataReceived(ByteBuf frame, boolean endOfStream) {
super.inboundDataReceived(new NettyReadableBuffer(frame.retain()), endOfStream);
}
Expand Down
50 changes: 50 additions & 0 deletions netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC;
import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER;
import static io.grpc.netty.Utils.STATUS_OK;
import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Exception.connectionError;
import static io.netty.util.CharsetUtil.UTF_8;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
Expand All @@ -34,6 +36,7 @@
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
Expand Down Expand Up @@ -62,6 +65,7 @@
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.handler.codec.http2.DefaultHttp2Headers;
import io.netty.handler.codec.http2.Http2Exception;
import io.netty.handler.codec.http2.Http2Headers;
import io.netty.util.AsciiString;
import java.io.BufferedInputStream;
Expand All @@ -75,6 +79,7 @@
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.InOrder;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
Expand Down Expand Up @@ -205,6 +210,51 @@ public void writeMessageShouldSendRequestUnknownLength() throws Exception {
eq(true));
}

@Test
public void writeFrameFutureFailedShouldCancelRpc() {
Http2Exception h2Error = connectionError(PROTOCOL_ERROR, "Stream does not exist %d", STREAM_ID);
// Fail all SendGrpcFrameCommands command sent to the queue.
when(writeQueue.enqueue(any(SendGrpcFrameCommand.class), anyBoolean())).thenReturn(
new DefaultChannelPromise(channel).setFailure(h2Error));

// Write multiple messages to ensure multiple SendGrpcFrameCommand are enqueued. We set up all
// of them to fail, which allows us to assert that only a single cancel is sent, and the stream
// isn't spammed with multiple RST_STREAM.
stream().transportState().setId(STREAM_ID);
stream.writeMessage(new ByteArrayInputStream(smallMessage()));
stream.writeMessage(new ByteArrayInputStream(largeMessage()));
stream.flush();

InOrder inOrder = Mockito.inOrder(writeQueue);
// Normal stream create and write frame.
inOrder.verify(writeQueue).enqueue(any(CreateStreamCommand.class), eq(false));
inOrder.verify(writeQueue).enqueue(any(SendGrpcFrameCommand.class), eq(false));
// Verify that failed SendGrpcFrameCommand results in immediate CancelClientStreamCommand.
inOrder.verify(writeQueue).enqueue(any(CancelClientStreamCommand.class), eq(true));
// Verify that any other failures do not produce another CancelClientStreamCommand in the queue.
inOrder.verify(writeQueue, atLeast(1)).enqueue(any(SendGrpcFrameCommand.class), eq(false));
inOrder.verify(writeQueue).enqueue(any(SendGrpcFrameCommand.class), eq(true));
inOrder.verifyNoMoreInteractions();

// Get the CancelClientStreamCommand written to the queue. Above we verified that there is
// only one CancelClientStreamCommand enqueued, and is the third enqueued command (create,
// frame write failure, cancel).
CancelClientStreamCommand cancelCommand = Mockito.mockingDetails(writeQueue).getInvocations()
// Get enqueue() innovations only
.stream().filter(invocation -> invocation.getMethod().getName().equals("enqueue"))
// Get the third invocation of enqueue()
.skip(2).findFirst().get()
// Get the first argument (QueuedCommand command)
.getArgument(0);

Status cancelReason = cancelCommand.reason();
assertThat(cancelReason.getCode()).isEqualTo(Status.INTERNAL.getCode());
assertThat(cancelReason.getCause()).isEqualTo(h2Error);
// Verify listener closed.
// TODO(sergiitk): should we expect REFUSED/MISCARRIED instead?
verify(listener).closed(same(cancelReason), eq(PROCESSED), any(Metadata.class));
sergiitk marked this conversation as resolved.
Show resolved Hide resolved
}

@Test
public void setStatusWithOkShouldCloseStream() {
stream().transportState().setId(STREAM_ID);
Expand Down
Loading
Loading