Skip to content

Commit

Permalink
TESTS: Fix Buf Leaks in HttpReadWriteHandlerTests (#32377)
Browse files Browse the repository at this point in the history
* TESTS: Fix Buf Leaks in HttpReadWriteHandlerTests

* Release all ref counted things that weren't getting properly released
* Mannually force channel promise to be completed because mock channel doesn't do it and it prevents one `release` call in `io.netty.channel.ChannelOutboundHandlerAdapter#write` from firing
  • Loading branch information
original-brownbear authored Jul 26, 2018
1 parent 467a60b commit 48885d2
Showing 1 changed file with 125 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelPromise;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
Expand Down Expand Up @@ -116,38 +117,48 @@ public void testSuccessfulDecodeHttpRequest() throws IOException {

ByteBuf buf = requestEncoder.encode(httpRequest);
int slicePoint = randomInt(buf.writerIndex() - 1);

ByteBuf slicedBuf = buf.retainedSlice(0, slicePoint);
ByteBuf slicedBuf2 = buf.retainedSlice(slicePoint, buf.writerIndex());
handler.consumeReads(toChannelBuffer(slicedBuf));
try {
handler.consumeReads(toChannelBuffer(slicedBuf));

verify(transport, times(0)).incomingRequest(any(HttpRequest.class), any(NioHttpChannel.class));
verify(transport, times(0)).incomingRequest(any(HttpRequest.class), any(NioHttpChannel.class));

handler.consumeReads(toChannelBuffer(slicedBuf2));
handler.consumeReads(toChannelBuffer(slicedBuf2));

ArgumentCaptor<HttpRequest> requestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class));
ArgumentCaptor<HttpRequest> requestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class));

HttpRequest nioHttpRequest = requestCaptor.getValue();
assertEquals(HttpRequest.HttpVersion.HTTP_1_1, nioHttpRequest.protocolVersion());
assertEquals(RestRequest.Method.GET, nioHttpRequest.method());
HttpRequest nioHttpRequest = requestCaptor.getValue();
assertEquals(HttpRequest.HttpVersion.HTTP_1_1, nioHttpRequest.protocolVersion());
assertEquals(RestRequest.Method.GET, nioHttpRequest.method());
} finally {
handler.close();
buf.release();
slicedBuf.release();
slicedBuf2.release();
}
}

public void testDecodeHttpRequestError() throws IOException {
String uri = "localhost:9090/" + randomAlphaOfLength(8);
io.netty.handler.codec.http.HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri);

ByteBuf buf = requestEncoder.encode(httpRequest);
buf.setByte(0, ' ');
buf.setByte(1, ' ');
buf.setByte(2, ' ');
try {
buf.setByte(0, ' ');
buf.setByte(1, ' ');
buf.setByte(2, ' ');

handler.consumeReads(toChannelBuffer(buf));
handler.consumeReads(toChannelBuffer(buf));

ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
verify(transport).incomingRequestError(any(HttpRequest.class), any(NioHttpChannel.class), exceptionCaptor.capture());
ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
verify(transport).incomingRequestError(any(HttpRequest.class), any(NioHttpChannel.class), exceptionCaptor.capture());

assertTrue(exceptionCaptor.getValue() instanceof IllegalArgumentException);
assertTrue(exceptionCaptor.getValue() instanceof IllegalArgumentException);
} finally {
buf.release();
}
}

public void testDecodeHttpRequestContentLengthToLongGeneratesOutboundMessage() throws IOException {
Expand All @@ -157,9 +168,11 @@ public void testDecodeHttpRequestContentLengthToLongGeneratesOutboundMessage() t
HttpUtil.setKeepAlive(httpRequest, false);

ByteBuf buf = requestEncoder.encode(httpRequest);

handler.consumeReads(toChannelBuffer(buf));

try {
handler.consumeReads(toChannelBuffer(buf));
} finally {
buf.release();
}
verify(transport, times(0)).incomingRequestError(any(), any(), any());
verify(transport, times(0)).incomingRequest(any(), any());

Expand All @@ -168,13 +181,17 @@ public void testDecodeHttpRequestContentLengthToLongGeneratesOutboundMessage() t

FlushOperation flushOperation = flushOperations.get(0);
FullHttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperation.getBuffersToWrite()));
assertEquals(HttpVersion.HTTP_1_1, response.protocolVersion());
assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.status());

flushOperation.getListener().accept(null, null);
// Since we have keep-alive set to false, we should close the channel after the response has been
// flushed
verify(nioHttpChannel).close();
try {
assertEquals(HttpVersion.HTTP_1_1, response.protocolVersion());
assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.status());

flushOperation.getListener().accept(null, null);
// Since we have keep-alive set to false, we should close the channel after the response has been
// flushed
verify(nioHttpChannel).close();
} finally {
response.release();
}
}

@SuppressWarnings("unchecked")
Expand All @@ -189,21 +206,29 @@ public void testEncodeHttpResponse() throws IOException {
SocketChannelContext context = mock(SocketChannelContext.class);
HttpWriteOperation writeOperation = new HttpWriteOperation(context, httpResponse, mock(BiConsumer.class));
List<FlushOperation> flushOperations = handler.writeToBytes(writeOperation);

FullHttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperations.get(0).getBuffersToWrite()));

assertEquals(HttpResponseStatus.OK, response.status());
assertEquals(HttpVersion.HTTP_1_1, response.protocolVersion());
FlushOperation operation = flushOperations.get(0);
FullHttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(operation.getBuffersToWrite()));
((ChannelPromise) operation.getListener()).setSuccess();
try {
assertEquals(HttpResponseStatus.OK, response.status());
assertEquals(HttpVersion.HTTP_1_1, response.protocolVersion());
} finally {
response.release();
}
}

public void testCorsEnabledWithoutAllowOrigins() throws IOException {
// Set up a HTTP transport with only the CORS enabled setting
Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.build();
io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, "remote-host", "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue());
FullHttpResponse response = executeCorsRequest(settings, "remote-host", "request-host");
try {
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue());
} finally {
response.release();
}
}

public void testCorsEnabledWithAllowOrigins() throws IOException {
Expand All @@ -213,11 +238,15 @@ public void testCorsEnabledWithAllowOrigins() throws IOException {
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.build();
io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
FullHttpResponse response = executeCorsRequest(settings, originValue, "request-host");
try {
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
} finally {
response.release();
}
}

public void testCorsAllowOriginWithSameHost() throws IOException {
Expand All @@ -228,29 +257,44 @@ public void testCorsAllowOriginWithSameHost() throws IOException {
.put(SETTING_CORS_ENABLED.getKey(), true)
.build();
FullHttpResponse response = executeCorsRequest(settings, originValue, host);
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));

String allowedOrigins;
try {
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
} finally {
response.release();
}
originValue = "http://" + originValue;
response = executeCorsRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
try {
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
} finally {
response.release();
}

originValue = originValue + ":5555";
host = host + ":5555";
response = executeCorsRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));

try {
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
} finally {
response.release();
}
originValue = originValue.replace("http", "https");
response = executeCorsRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
try {
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
} finally {
response.release();
}
}

public void testThatStringLiteralWorksOnMatch() throws IOException {
Expand All @@ -261,12 +305,16 @@ public void testThatStringLiteralWorksOnMatch() throws IOException {
.put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post")
.put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
.build();
io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
FullHttpResponse response = executeCorsRequest(settings, originValue, "request-host");
try {
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
} finally {
response.release();
}
}

public void testThatAnyOriginWorks() throws IOException {
Expand All @@ -275,12 +323,16 @@ public void testThatAnyOriginWorks() throws IOException {
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.build();
io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue());
FullHttpResponse response = executeCorsRequest(settings, originValue, "request-host");
try {
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue());
} finally {
response.release();
}
}

private FullHttpResponse executeCorsRequest(final Settings settings, final String originValue, final String host) throws IOException {
Expand All @@ -300,8 +352,9 @@ private FullHttpResponse executeCorsRequest(final Settings settings, final Strin

SocketChannelContext context = mock(SocketChannelContext.class);
List<FlushOperation> flushOperations = handler.writeToBytes(handler.createWriteOperation(context, response, (v, e) -> {}));

handler.close();
FlushOperation flushOperation = flushOperations.get(0);
((ChannelPromise) flushOperation.getListener()).setSuccess();
return responseDecoder.decode(Unpooled.wrappedBuffer(flushOperation.getBuffersToWrite()));
}

Expand All @@ -314,8 +367,11 @@ private NioHttpRequest prepareHandlerForResponse(HttpReadWriteHandler handler) t

io.netty.handler.codec.http.HttpRequest request = new DefaultFullHttpRequest(version, method, uri);
ByteBuf buf = requestEncoder.encode(request);

handler.consumeReads(toChannelBuffer(buf));
try {
handler.consumeReads(toChannelBuffer(buf));
} finally {
buf.release();
}

ArgumentCaptor<NioHttpRequest> requestCaptor = ArgumentCaptor.forClass(NioHttpRequest.class);
verify(transport, atLeastOnce()).incomingRequest(requestCaptor.capture(), any(HttpChannel.class));
Expand Down

0 comments on commit 48885d2

Please sign in to comment.