Fix a JsonEmitter leak and an over-release.
I've made the JsonEmitter implement AutoCloseable so that the underlying stream can be closed, which in turn will release the compositeByteBuf that it uses to buffer contents.
The NettyJsonBodySerializeHandler called fireChannelRead() and then immediately called release() on the same buffer that it had just sent downstream on the channel.  That latter release() has now been removed.
I've added a new unit test for the NettyJsonBodySerializeHandler to make sure that its buffer handling is sound.  I've also updated the JsonEmitter test so that it no longer leaks, bringing us closer to a zero-leaks target.
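
The fixed flow, condensed from the handler diff below (a sketch, not new API: DefaultHttpContent takes over each buffer's reference downstream, and JsonEmitter.close() releases the emitter's internal compositeByteBuf):

    // Serialize the payload in bounded chunks, handing each buffer downstream.
    try (var jsonEmitter = new JsonEmitter(ctx.alloc())) {
        var pac = jsonEmitter.getChunkAndContinuations(payload, NUM_BYTES_TO_ACCUMULATE_BEFORE_FIRING);
        while (true) {
            // Ownership of the chunk passes to DefaultHttpContent; no release() here.
            ctx.fireChannelRead(new DefaultHttpContent(pac.partialSerializedContents));
            if (pac.nextSupplier == null) {
                ctx.fireChannelRead(LastHttpContent.EMPTY_LAST_CONTENT);
                break;
            }
            pac = pac.nextSupplier.get();
        }
    }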

Signed-off-by: Greg Schohn <greg.schohn@gmail.com>
gregschohn committed Sep 13, 2023
1 parent 6de4dd4 commit c54e360
Showing 11 changed files with 121 additions and 59 deletions.
@@ -2,20 +2,13 @@

import com.google.protobuf.CodedOutputStream;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.ResourceLeakDetector;
import io.netty.util.ResourceLeakDetectorFactory;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.opensearch.migrations.testutils.CountingNettyResourceLeakDetector;
import org.opensearch.migrations.testutils.TestUtilities;
import org.opensearch.migrations.testutils.TestWithNettyLeakDetection;
import org.opensearch.migrations.testutils.WrapWithNettyLeakDetection;
import org.opensearch.migrations.trafficcapture.StreamChannelConnectionCaptureSerializer;
import org.opensearch.migrations.trafficcapture.protos.TrafficStream;

@@ -111,15 +104,15 @@ public void testThatAPostInTinyPacketsBlocksFutureActivity(boolean usePool) thro

@ParameterizedTest
@ValueSource(booleans = {true, false})
@TestWithNettyLeakDetection(repetitions = 16)
@WrapWithNettyLeakDetection(repetitions = 16)
public void testThatAPostInTinyPacketsBlocksFutureActivity_withLeakDetection(boolean usePool) throws Exception {
testThatAPostInTinyPacketsBlocksFutureActivity(usePool);
//MyResourceLeakDetector.dumpHeap("nettyWireLogging_"+COUNT+"_"+ Instant.now() +".hprof", true);
}

@ParameterizedTest
@ValueSource(booleans = {true, false})
@TestWithNettyLeakDetection(repetitions = 32)
@WrapWithNettyLeakDetection(repetitions = 32)
public void testThatAPostInASinglePacketBlocksFutureActivity_withLeakDetection(boolean usePool) throws Exception {
testThatAPostInASinglePacketBlocksFutureActivity(usePool);
//MyResourceLeakDetector.dumpHeap("nettyWireLogging_"+COUNT+"_"+ Instant.now() +".hprof", true);
@@ -6,7 +6,6 @@
import org.junit.jupiter.api.extension.ReflectiveInvocationContext;

import java.lang.reflect.Method;
import java.util.Optional;
import java.util.concurrent.Callable;

public class NettyLeakCheckTestExtension implements InvocationInterceptor {
@@ -15,7 +14,7 @@ public NettyLeakCheckTestExtension() {}
private void wrapWithLeakChecks(ExtensionContext extensionContext, Callable repeatCall, Callable finalCall)
throws Throwable {
int repetitions = extensionContext.getTestMethod()
.map(ec->ec.getAnnotation(TestWithNettyLeakDetection.class).repetitions())
.map(ec->ec.getAnnotation(WrapWithNettyLeakDetection.class).repetitions())
.orElseThrow(() -> new IllegalStateException("No test method present"));
CountingNettyResourceLeakDetector.activate();

@@ -10,7 +10,7 @@
@Target({ ElementType.TYPE, ElementType.METHOD })
@Retention(RetentionPolicy.RUNTIME)
@ExtendWith(NettyLeakCheckTestExtension.class)
public @interface TestWithNettyLeakDetection {
public @interface WrapWithNettyLeakDetection {
/**
* Some leaks might need to put a bit more stress on the GC for objects to get cleared out
* and trigger potential checks within any resource finalizers to determine if there have
@@ -28,7 +28,7 @@
* the memory load rather than expanding the working set for a large textual stream.
*/
@Slf4j
public class JsonEmitter {
public class JsonEmitter implements AutoCloseable {

public final static int NUM_SEGMENT_THRESHOLD = 256;

@@ -61,12 +61,12 @@ public LevelContext(Iterator<Object> iterator, Runnable onPopContinuation) {
}
}

static class ImmediateByteBufOutputStream extends OutputStream {
static class ChunkingByteBufOutputStream extends OutputStream {
private final ByteBufAllocator byteBufAllocator;
@Getter
CompositeByteBuf compositeByteBuf;

public ImmediateByteBufOutputStream(ByteBufAllocator byteBufAllocator) {
public ChunkingByteBufOutputStream(ByteBufAllocator byteBufAllocator) {
this.byteBufAllocator = byteBufAllocator;
compositeByteBuf = byteBufAllocator.compositeBuffer();
}
@@ -88,9 +88,13 @@ private void write(ByteBuf byteBuf) {
}

@Override
public void close() throws IOException {
public void close() {
compositeByteBuf.release();
super.close();
try {
super.close();
} catch (IOException e) {
throw new RuntimeException("Expected OutputStream::close() to be empty as per docs in Java 11", e);
}
}

/**
@@ -106,18 +110,23 @@ public ByteBuf recycleByteBufRetained() {
}

private JsonGenerator jsonGenerator;
private ImmediateByteBufOutputStream outputStream;
private ChunkingByteBufOutputStream outputStream;
private ObjectMapper objectMapper;
private Stack<LevelContext> cursorStack;

@SneakyThrows
public JsonEmitter(ByteBufAllocator byteBufAllocator) {
outputStream = new ImmediateByteBufOutputStream(byteBufAllocator);
outputStream = new ChunkingByteBufOutputStream(byteBufAllocator);
jsonGenerator = new JsonFactory().createGenerator(outputStream, JsonEncoding.UTF8);
objectMapper = new ObjectMapper();
cursorStack = new Stack<>();
}

@Override
public void close() {
outputStream.close();
}

/**
* This returns a ByteBuf block of serialized data plus a recursive generation function (Supplier)
* for the next block and continuation. The size of the ByteBuf that will be returned can be
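For orientation, a minimal consumption loop for this chunk-plus-continuation API is sketched here (assumptions: the caller owns each returned ByteBuf and must release it after use, as the updated JsonEmitterTest in this commit does; consume(...) and rootNode are hypothetical placeholders):

    try (var emitter = new JsonEmitter(ByteBufAllocator.DEFAULT)) {
        var pac = emitter.getChunkAndContinuations(rootNode, 1024); // rootNode: any Jackson JsonNode
        while (true) {
            consume(pac.partialSerializedContents.toString(StandardCharsets.UTF_8));
            pac.partialSerializedContents.release(); // caller-owned buffer
            if (pac.nextSupplier == null) {
                break;
            }
            pac = pac.nextSupplier.get(); // lazily serializes the next block
        }
    }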
@@ -4,18 +4,18 @@
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.DefaultHttpContent;
import io.netty.handler.codec.http.LastHttpContent;
import lombok.extern.slf4j.Slf4j;
import org.opensearch.migrations.replay.datahandlers.JsonEmitter;
import org.opensearch.migrations.replay.datahandlers.PayloadAccessFaultingMap;

import java.io.IOException;
import java.util.Map;

@Slf4j
public class NettyJsonBodySerializeHandler extends ChannelInboundHandlerAdapter {
public static final int NUM_BYTES_TO_ACCUMULATE_BEFORE_FIRING = 1024;
final JsonEmitter jsonEmitter;

public NettyJsonBodySerializeHandler(ChannelHandlerContext ctx) {
this.jsonEmitter = new JsonEmitter(ctx.alloc());
public NettyJsonBodySerializeHandler() {
}

@Override
@@ -34,16 +34,17 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
}
}

-    private void serializePayload(ChannelHandlerContext ctx, Map<String,Object> payload) throws IOException {
-        var pac = jsonEmitter.getChunkAndContinuations(payload, NUM_BYTES_TO_ACCUMULATE_BEFORE_FIRING);
-        while (true) {
-            ctx.fireChannelRead(new DefaultHttpContent(pac.partialSerializedContents));
-            pac.partialSerializedContents.release();
-            if (pac.nextSupplier == null) {
-                ctx.fireChannelRead(LastHttpContent.EMPTY_LAST_CONTENT);
-                break;
-            }
-            pac = pac.nextSupplier.get();
-        }
-    }
+    private void serializePayload(ChannelHandlerContext ctx, Map<String, Object> payload) throws IOException {
+        try (var jsonEmitter = new JsonEmitter(ctx.alloc())) {
+            var pac = jsonEmitter.getChunkAndContinuations(payload, NUM_BYTES_TO_ACCUMULATE_BEFORE_FIRING);
+            while (true) {
+                ctx.fireChannelRead(new DefaultHttpContent(pac.partialSerializedContents));
+                if (pac.nextSupplier == null) {
+                    ctx.fireChannelRead(LastHttpContent.EMPTY_LAST_CONTENT);
+                    break;
+                }
+                pac = pac.nextSupplier.get();
+            }
+        }
+    }
 }
@@ -123,7 +123,7 @@ void addContentParsingHandlers(ChannelHandlerContext ctx,
pipeline.addLast(new NettyJsonBodyConvertHandler(transformer));
// IN: Netty HttpRequest(2) + HttpJsonMessage(3) with headers AND payload
// OUT: Netty HttpRequest(2) + HttpJsonMessage(3) with headers only + HttpContent(3) blocks
pipeline.addLast(new NettyJsonBodySerializeHandler(ctx));
pipeline.addLast(new NettyJsonBodySerializeHandler());
addLoggingHandler(pipeline, "F");
}
if (authTransfomer != null) {
@@ -7,7 +7,7 @@
import org.junit.jupiter.params.provider.MethodSource;
import org.opensearch.migrations.testutils.CountingNettyResourceLeakDetector;
import org.opensearch.migrations.testutils.TestUtilities;
import org.opensearch.migrations.testutils.TestWithNettyLeakDetection;
import org.opensearch.migrations.testutils.WrapWithNettyLeakDetection;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
@@ -16,7 +16,7 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

@TestWithNettyLeakDetection
@WrapWithNettyLeakDetection
public class PrettyPrinterTest {

@BeforeAll
@@ -64,7 +64,7 @@ private static Stream<Arguments> makeCombos() {

@ParameterizedTest
@MethodSource("makeCombos")
@TestWithNettyLeakDetection(repetitions = 4)
@WrapWithNettyLeakDetection(repetitions = 4)
public void httpPacketBufsToString(PrettyPrinter.PacketPrintFormat format, BufferType bufferType) {
byte[] fullTrafficBytes = SAMPLE_REQUEST_STRING.getBytes(StandardCharsets.UTF_8);
var byteArrays = new ArrayList<byte[]>();
@@ -18,7 +18,7 @@ public class JsonAccumulatorTest {
public static final String LARGE_PACKED = "largeAndPacked";
GenerateRandomNestedJsonObject randomJsonGenerator = new GenerateRandomNestedJsonObject();

private static Object readJson(byte[] testFileBytes, int chunkBound) throws IOException {
static Object readJson(byte[] testFileBytes, int chunkBound) throws IOException {
var jsonParser = new JsonAccumulator();
Random r = new Random(2);
var entireByteBuffer = ByteBuffer.wrap(testFileBytes);
@@ -5,38 +5,39 @@
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.opensearch.migrations.testutils.WrapWithNettyLeakDetection;

import java.io.IOException;
import java.io.StringReader;
import java.io.StringWriter;
import java.nio.charset.StandardCharsets;

@Slf4j
-class JsonEmitterTest {
+public class JsonEmitterTest {
     @Test
+    @WrapWithNettyLeakDetection(repetitions = 2)
     public void testEmitterWorksRoundTrip() throws IOException {
-        JsonEmitter jse = new JsonEmitter(ByteBufAllocator.DEFAULT);
-        var mapper = new ObjectMapper();
-
-        //var originalTree = mapper.readTree(new File("bigfile.json"));
-        var originalTree = mapper.readTree(new StringReader("{\"index\":{\"_id\":\"1\"}}"));
-        var writer = new StringWriter();
-        var pac = jse.getChunkAndContinuations(originalTree, 10*1024);
-        while (true) {
-            var asBytes = new byte[pac.partialSerializedContents.readableBytes()];
-            log.info("Got "+asBytes.length+" bytes back");
-            pac.partialSerializedContents.readBytes(asBytes);
-            var nextChunkStr = new String(asBytes, StandardCharsets.UTF_8);
-            writer.append(nextChunkStr);
-            if (pac.nextSupplier == null) {
-                break;
-            }
-            pac = pac.nextSupplier.get();
-        }
-        writer.flush();
-        var streamedToStringRoundTripped =
-                mapper.writeValueAsString(mapper.readTree(new StringReader(writer.toString())));
-        var originalRoundTripped = mapper.writeValueAsString(originalTree);
-        Assertions.assertEquals(originalRoundTripped, streamedToStringRoundTripped);
+        try (JsonEmitter jse = new JsonEmitter(ByteBufAllocator.DEFAULT)) {
+            var mapper = new ObjectMapper();
+
+            var originalTree = mapper.readTree(new StringReader("{\"index\":{\"_id\":\"1\"}}"));
+            var writer = new StringWriter();
+            var pac = jse.getChunkAndContinuations(originalTree, 10 * 1024);
+            while (true) {
+                var chunk = pac.partialSerializedContents.toString(StandardCharsets.UTF_8);
+                pac.partialSerializedContents.release();
+                log.info("Got: " + chunk);
+                writer.append(chunk);
+                if (pac.nextSupplier == null) {
+                    break;
+                }
+                pac = pac.nextSupplier.get();
+            }
+            writer.flush();
+            var streamedToStringRoundTripped =
+                    mapper.writeValueAsString(mapper.readTree(new StringReader(writer.toString())));
+            var originalRoundTripped = mapper.writeValueAsString(originalTree);
+            Assertions.assertEquals(originalRoundTripped, streamedToStringRoundTripped);
+        }
     }
 }
@@ -0,0 +1,58 @@
package org.opensearch.migrations.replay.datahandlers.http;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.netty.buffer.ByteBuf;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.HttpContent;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.opensearch.migrations.replay.GenerateRandomNestedJsonObject;
import org.opensearch.migrations.replay.ReplayUtils;
import org.opensearch.migrations.replay.datahandlers.PayloadAccessFaultingMap;
import org.opensearch.migrations.testutils.WrapWithNettyLeakDetection;

import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Random;
import java.util.stream.Stream;

import static org.junit.jupiter.api.Assertions.*;

@Slf4j
public class NettyJsonBodySerializeHandlerTest {
@Test
@WrapWithNettyLeakDetection
public void testJsonSerializerHandler() throws Exception {
var randomJsonGenerator = new GenerateRandomNestedJsonObject();
var randomJson = randomJsonGenerator.makeRandomJsonObject(new Random(2), 2, 1);
var headers = new StrictCaseInsensitiveHttpHeadersMap();
headers.put(IHttpMessage.CONTENT_TYPE, List.of(IHttpMessage.APPLICATION_JSON));
var fullHttpMessageWithJsonBody = new HttpJsonMessageWithFaultingPayload(headers);
fullHttpMessageWithJsonBody.setPayloadFaultMap(new PayloadAccessFaultingMap(headers));
fullHttpMessageWithJsonBody.payload().put(PayloadAccessFaultingMap.INLINED_JSON_BODY_DOCUMENT_KEY, randomJson);

var channel = new EmbeddedChannel(new NettyJsonBodySerializeHandler());
channel.writeInbound(fullHttpMessageWithJsonBody);

var handlerAccumulatedStream =
ReplayUtils.byteBufsToInputStream(getByteBufStreamFromChannel(channel));

String originalTreeStr = new ObjectMapper().writeValueAsString(randomJson);
var reconstitutedTreeStr = new String(handlerAccumulatedStream.readAllBytes(), StandardCharsets.UTF_8);
Assertions.assertEquals(originalTreeStr, reconstitutedTreeStr);

getByteBufStreamFromChannel(channel).forEach(bb->bb.release());
}

private static Stream<ByteBuf> getByteBufStreamFromChannel(EmbeddedChannel channel) {
return channel.inboundMessages().stream()
.filter(x -> x instanceof HttpContent)
.map(x -> {
var rval = ((HttpContent) x).content();
log.info("refCnt=" + rval.refCnt() + " for " + x);
return rval;
});
}
}
@@ -28,6 +28,7 @@ class NettyJsonToByteBufHandlerTest {
public static final int ORIGINAL_NUMBER_OF_CHUNKS = 4;
public static final int ADDITIONAL_CHUNKS = 2;

// TODO - replace this with just an operation on the EmbeddedChannel.inboundMessage
static class ValidationHandler extends ChannelInboundHandlerAdapter {
final List<Integer> chunkSizesReceivedList;
int totalBytesRead;
