Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TESTS: Fix Buf Leaks in HttpReadWriteHandlerTests #32377

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelPromise;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
Expand Down Expand Up @@ -116,38 +117,48 @@ public void testSuccessfulDecodeHttpRequest() throws IOException {

ByteBuf buf = requestEncoder.encode(httpRequest);
int slicePoint = randomInt(buf.writerIndex() - 1);

ByteBuf slicedBuf = buf.retainedSlice(0, slicePoint);
ByteBuf slicedBuf2 = buf.retainedSlice(slicePoint, buf.writerIndex());
handler.consumeReads(toChannelBuffer(slicedBuf));
try {
handler.consumeReads(toChannelBuffer(slicedBuf));

verify(transport, times(0)).incomingRequest(any(HttpRequest.class), any(NioHttpChannel.class));
verify(transport, times(0)).incomingRequest(any(HttpRequest.class), any(NioHttpChannel.class));

handler.consumeReads(toChannelBuffer(slicedBuf2));
handler.consumeReads(toChannelBuffer(slicedBuf2));

ArgumentCaptor<HttpRequest> requestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class));
ArgumentCaptor<HttpRequest> requestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class));

HttpRequest nioHttpRequest = requestCaptor.getValue();
assertEquals(HttpRequest.HttpVersion.HTTP_1_1, nioHttpRequest.protocolVersion());
assertEquals(RestRequest.Method.GET, nioHttpRequest.method());
HttpRequest nioHttpRequest = requestCaptor.getValue();
assertEquals(HttpRequest.HttpVersion.HTTP_1_1, nioHttpRequest.protocolVersion());
assertEquals(RestRequest.Method.GET, nioHttpRequest.method());
} finally {
handler.close();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know that we care about closing the handler. It probably does not matter too much, but there should not be any resources hanging around if we properly consume all the requests.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not closing it triggers:

Created at:
	io.netty.buffer.AbstractByteBufAllocator.compositeHeapBuffer(AbstractByteBufAllocator.java:207)
	io.netty.buffer.AbstractByteBufAllocator.compositeBuffer(AbstractByteBufAllocator.java:197)
	io.netty.handler.codec.ByteToMessageDecoder$2.cumulate(ByteToMessageDecoder.java:122)
	io.netty.handler.codec.ByteToMessageDecoder.channelRead(ByteToMessageDecoder.java:263)
	io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:362)
	io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:348)
	io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:340)
	io.netty.channel.DefaultChannelPipeline$HeadContext.channelRead(DefaultChannelPipeline.java:1359)
	io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:362)
	io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:348)
	io.netty.channel.DefaultChannelPipeline.fireChannelRead(DefaultChannelPipeline.java:935)
	io.netty.channel.embedded.EmbeddedChannel.writeInbound(EmbeddedChannel.java:318)
	org.elasticsearch.http.nio.NettyAdaptor.read(NettyAdaptor.java:104)
	org.elasticsearch.http.nio.HttpReadWriteHandler.consumeReads(HttpReadWriteHandler.java:81)
	org.elasticsearch.http.nio.HttpReadWriteHandlerTests.testSuccessfulDecodeHttpRequest(HttpReadWriteHandlerTests.java:129)
	jdk.internal.reflect.GeneratedMethodAccessor10.invoke(Unknown Source)
	java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	java.base/java.lang.reflect.Method.invoke(Method.java:564)
	com.carrotsearch.randomizedtesting.RandomizedRunner.invoke(RandomizedRunner.java:1713)
	com.carrotsearch.randomizedtesting.RandomizedRunner$8.evaluate(RandomizedRunner.java:907)
	com.carrotsearch.randomizedtesting.RandomizedRunner$9.evaluate(RandomizedRunner.java:943)
	com.carrotsearch.randomizedtesting.RandomizedRunner$10.evaluate(RandomizedRunner.java:957)
	com.carrotsearch.randomizedtesting.rules.StatementAdapter.evaluate(StatementAdapter.java:36)
	org.apache.lucene.util.TestRuleSetupTeardownChained$1.evaluate(TestRuleSetupTeardownChained.java:49)
	org.apache.lucene.util.AbstractBeforeAfterRule$1.evaluate(AbstractBeforeAfterRule.java:45)
	org.apache.lucene.util.TestRuleThreadAndTestName$1.evaluate(TestRuleThreadAndTestName.java:48)
	org.apache.lucene.util.TestRuleIgnoreAfterMaxFailures$1.evaluate(TestRuleIgnoreAfterMaxFailures.java:64)
	org.apache.lucene.util.TestRuleMarkFailure$1.evaluate(TestRuleMarkFailure.java:47)
	com.carrotsearch.randomizedtesting.rules.StatementAdapter.evaluate(StatementAdapter.java:36)
	com.carrotsearch.randomizedtesting.ThreadLeakControl$StatementRunner.run(ThreadLeakControl.java:368)
	com.carrotsearch.randomizedtesting.ThreadLeakControl.forkTimeoutingTask(ThreadLeakControl.java:817)
	com.carrotsearch.randomizedtesting.ThreadLeakControl$3.evaluate(ThreadLeakControl.java:468)
	com.carrotsearch.randomizedtesting.RandomizedRunner.runSingleTest(RandomizedRunner.java:916)
	com.carrotsearch.randomizedtesting.RandomizedRunner$5.evaluate(RandomizedRunner.java:802)
	com.carrotsearch.randomizedtesting.RandomizedRunner$6.evaluate(RandomizedRunner.java:852)
	com.carrotsearch.randomizedtesting.RandomizedRunner$7.evaluate(RandomizedRunner.java:863)
	org.apache.lucene.util.AbstractBeforeAfterRule$1.evaluate(AbstractBeforeAfterRule.java:45)
	com.carrotsearch.randomizedtesting.rules.StatementAdapter.evaluate(StatementAdapter.java:36)
	org.apache.lucene.util.TestRuleStoreClassName$1.evaluate(TestRuleStoreClassName.java:41)
	com.carrotsearch.randomizedtesting.rules.NoShadowingOrOverridesOnMethodsRule$1.evaluate(NoShadowingOrOverridesOnMethodsRule.java:40)
	com.carrotsearch.randomizedtesting.rules.NoShadowingOrOverridesOnMethodsRule$1.evaluate(NoShadowingOrOverridesOnMethodsRule.java:40)
	com.carrotsearch.randomizedtesting.rules.StatementAdapter.evaluate(StatementAdapter.java:36)
	com.carrotsearch.randomizedtesting.rules.StatementAdapter.evaluate(StatementAdapter.java:36)
	org.apache.lucene.util.TestRuleAssertionsRequired$1.evaluate(TestRuleAssertionsRequired.java:53)
	org.apache.lucene.util.TestRuleMarkFailure$1.evaluate(TestRuleMarkFailure.java:47)
	org.apache.lucene.util.TestRuleIgnoreAfterMaxFailures$1.evaluate(TestRuleIgnoreAfterMaxFailures.java:64)
	org.apache.lucene.util.TestRuleIgnoreTestSuites$1.evaluate(TestRuleIgnoreTestSuites.java:54)
	com.carrotsearch.randomizedtesting.rules.StatementAdapter.evaluate(StatementAdapter.java:36)
	com.carrotsearch.randomizedtesting.ThreadLeakControl$StatementRunner.run(ThreadLeakControl.java:368)
	java.base/java.lang.Thread.run(Thread.java:844)

buf.release();
slicedBuf.release();
slicedBuf2.release();
}
}

public void testDecodeHttpRequestError() throws IOException {
String uri = "localhost:9090/" + randomAlphaOfLength(8);
io.netty.handler.codec.http.HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri);

ByteBuf buf = requestEncoder.encode(httpRequest);
buf.setByte(0, ' ');
buf.setByte(1, ' ');
buf.setByte(2, ' ');
try {
buf.setByte(0, ' ');
buf.setByte(1, ' ');
buf.setByte(2, ' ');

handler.consumeReads(toChannelBuffer(buf));
handler.consumeReads(toChannelBuffer(buf));

ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
verify(transport).incomingRequestError(any(HttpRequest.class), any(NioHttpChannel.class), exceptionCaptor.capture());
ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
verify(transport).incomingRequestError(any(HttpRequest.class), any(NioHttpChannel.class), exceptionCaptor.capture());

assertTrue(exceptionCaptor.getValue() instanceof IllegalArgumentException);
assertTrue(exceptionCaptor.getValue() instanceof IllegalArgumentException);
} finally {
buf.release();
}
}

public void testDecodeHttpRequestContentLengthToLongGeneratesOutboundMessage() throws IOException {
Expand All @@ -157,9 +168,11 @@ public void testDecodeHttpRequestContentLengthToLongGeneratesOutboundMessage() t
HttpUtil.setKeepAlive(httpRequest, false);

ByteBuf buf = requestEncoder.encode(httpRequest);

handler.consumeReads(toChannelBuffer(buf));

try {
handler.consumeReads(toChannelBuffer(buf));
} finally {
buf.release();
}
verify(transport, times(0)).incomingRequestError(any(), any(), any());
verify(transport, times(0)).incomingRequest(any(), any());

Expand All @@ -168,13 +181,17 @@ public void testDecodeHttpRequestContentLengthToLongGeneratesOutboundMessage() t

FlushOperation flushOperation = flushOperations.get(0);
FullHttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperation.getBuffersToWrite()));
assertEquals(HttpVersion.HTTP_1_1, response.protocolVersion());
assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.status());

flushOperation.getListener().accept(null, null);
// Since we have keep-alive set to false, we should close the channel after the response has been
// flushed
verify(nioHttpChannel).close();
try {
assertEquals(HttpVersion.HTTP_1_1, response.protocolVersion());
assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.status());

flushOperation.getListener().accept(null, null);
// Since we have keep-alive set to false, we should close the channel after the response has been
// flushed
verify(nioHttpChannel).close();
} finally {
response.release();
}
}

@SuppressWarnings("unchecked")
Expand All @@ -189,21 +206,29 @@ public void testEncodeHttpResponse() throws IOException {
SocketChannelContext context = mock(SocketChannelContext.class);
HttpWriteOperation writeOperation = new HttpWriteOperation(context, httpResponse, mock(BiConsumer.class));
List<FlushOperation> flushOperations = handler.writeToBytes(writeOperation);

FullHttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperations.get(0).getBuffersToWrite()));

assertEquals(HttpResponseStatus.OK, response.status());
assertEquals(HttpVersion.HTTP_1_1, response.protocolVersion());
FlushOperation operation = flushOperations.get(0);
FullHttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(operation.getBuffersToWrite()));
((ChannelPromise) operation.getListener()).setSuccess();
try {
assertEquals(HttpResponseStatus.OK, response.status());
assertEquals(HttpVersion.HTTP_1_1, response.protocolVersion());
} finally {
response.release();
}
}

public void testCorsEnabledWithoutAllowOrigins() throws IOException {
// Set up a HTTP transport with only the CORS enabled setting
Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.build();
io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, "remote-host", "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue());
FullHttpResponse response = executeCorsRequest(settings, "remote-host", "request-host");
try {
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue());
} finally {
response.release();
}
}

public void testCorsEnabledWithAllowOrigins() throws IOException {
Expand All @@ -213,11 +238,15 @@ public void testCorsEnabledWithAllowOrigins() throws IOException {
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.build();
io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
FullHttpResponse response = executeCorsRequest(settings, originValue, "request-host");
try {
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
} finally {
response.release();
}
}

public void testCorsAllowOriginWithSameHost() throws IOException {
Expand All @@ -228,29 +257,44 @@ public void testCorsAllowOriginWithSameHost() throws IOException {
.put(SETTING_CORS_ENABLED.getKey(), true)
.build();
FullHttpResponse response = executeCorsRequest(settings, originValue, host);
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));

String allowedOrigins;
try {
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
} finally {
response.release();
}
originValue = "http://" + originValue;
response = executeCorsRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
try {
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
} finally {
response.release();
}

originValue = originValue + ":5555";
host = host + ":5555";
response = executeCorsRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));

try {
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
} finally {
response.release();
}
originValue = originValue.replace("http", "https");
response = executeCorsRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
try {
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
} finally {
response.release();
}
}

public void testThatStringLiteralWorksOnMatch() throws IOException {
Expand All @@ -261,12 +305,16 @@ public void testThatStringLiteralWorksOnMatch() throws IOException {
.put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post")
.put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
.build();
io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
FullHttpResponse response = executeCorsRequest(settings, originValue, "request-host");
try {
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
} finally {
response.release();
}
}

public void testThatAnyOriginWorks() throws IOException {
Expand All @@ -275,12 +323,16 @@ public void testThatAnyOriginWorks() throws IOException {
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.build();
io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue());
FullHttpResponse response = executeCorsRequest(settings, originValue, "request-host");
try {
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue());
} finally {
response.release();
}
}

private FullHttpResponse executeCorsRequest(final Settings settings, final String originValue, final String host) throws IOException {
Expand All @@ -300,8 +352,9 @@ private FullHttpResponse executeCorsRequest(final Settings settings, final Strin

SocketChannelContext context = mock(SocketChannelContext.class);
List<FlushOperation> flushOperations = handler.writeToBytes(handler.createWriteOperation(context, response, (v, e) -> {}));

handler.close();
FlushOperation flushOperation = flushOperations.get(0);
((ChannelPromise) flushOperation.getListener()).setSuccess();
return responseDecoder.decode(Unpooled.wrappedBuffer(flushOperation.getBuffersToWrite()));
}

Expand All @@ -314,8 +367,11 @@ private NioHttpRequest prepareHandlerForResponse(HttpReadWriteHandler handler) t

io.netty.handler.codec.http.HttpRequest request = new DefaultFullHttpRequest(version, method, uri);
ByteBuf buf = requestEncoder.encode(request);

handler.consumeReads(toChannelBuffer(buf));
try {
handler.consumeReads(toChannelBuffer(buf));
} finally {
buf.release();
}

ArgumentCaptor<NioHttpRequest> requestCaptor = ArgumentCaptor.forClass(NioHttpRequest.class);
verify(transport, atLeastOnce()).incomingRequest(requestCaptor.capture(), any(HttpChannel.class));
Expand Down