From d6c82fc55a76481c676f541a255571e8950bb8c3 Mon Sep 17 00:00:00 2001 From: zhangskz Date: Wed, 18 Sep 2024 15:22:47 -0400 Subject: [PATCH] Add recursion check when parsing unknown fields in Java. (#18388) * Internal change PiperOrigin-RevId: 663919912 * Internal change PiperOrigin-RevId: 653615736 * Add recursion check when parsing unknown fields in Java. PiperOrigin-RevId: 675657198 --------- Co-authored-by: Protobuf Team Bot --- java/core/BUILD.bazel | 2 + .../com/google/protobuf/ArrayDecoders.java | 31 ++- .../com/google/protobuf/CodedInputStream.java | 112 +++------ .../InvalidProtocolBufferException.java | 2 +- .../com/google/protobuf/MessageSchema.java | 9 +- .../com/google/protobuf/MessageSetSchema.java | 2 +- .../google/protobuf/UnknownFieldSchema.java | 28 ++- .../google/protobuf/CodedInputStreamTest.java | 158 ++++++++++++ .../com/google/protobuf/map_lite_test.proto | 1 + .../proto/com/google/protobuf/map_test.proto | 1 + .../java/com/google/protobuf/LiteTest.java | 236 ++++++++++++++++++ src/google/protobuf/unittest_lite.proto | 4 + 12 files changed, 493 insertions(+), 93 deletions(-) diff --git a/java/core/BUILD.bazel b/java/core/BUILD.bazel index 6fe4ccff6af46..1f7b0cb5b8bbb 100644 --- a/java/core/BUILD.bazel +++ b/java/core/BUILD.bazel @@ -608,6 +608,7 @@ junit_tests( "src/test/java/com/google/protobuf/DescriptorsTest.java", "src/test/java/com/google/protobuf/DebugFormatTest.java", "src/test/java/com/google/protobuf/CodedOutputStreamTest.java", + "src/test/java/com/google/protobuf/CodedInputStreamTest.java", "src/test/java/com/google/protobuf/ProtobufToStringOutputTest.java", # Excluded in core_tests "src/test/java/com/google/protobuf/DecodeUtf8Test.java", @@ -656,6 +657,7 @@ junit_tests( "src/test/java/com/google/protobuf/DescriptorsTest.java", "src/test/java/com/google/protobuf/DebugFormatTest.java", "src/test/java/com/google/protobuf/CodedOutputStreamTest.java", + "src/test/java/com/google/protobuf/CodedInputStreamTest.java", "src/test/java/com/google/protobuf/ProtobufToStringOutputTest.java", # Excluded in core_tests "src/test/java/com/google/protobuf/DecodeUtf8Test.java", diff --git a/java/core/src/main/java/com/google/protobuf/ArrayDecoders.java b/java/core/src/main/java/com/google/protobuf/ArrayDecoders.java index f3241de5095c0..bf5f922b073c5 100644 --- a/java/core/src/main/java/com/google/protobuf/ArrayDecoders.java +++ b/java/core/src/main/java/com/google/protobuf/ArrayDecoders.java @@ -23,9 +23,12 @@ */ @CheckReturnValue final class ArrayDecoders { + static final int DEFAULT_RECURSION_LIMIT = 100; - private ArrayDecoders() { - } + @SuppressWarnings("NonFinalStaticField") + private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT; + + private ArrayDecoders() {} /** * A helper used to return multiple values in a Java function. Java doesn't natively support @@ -38,6 +41,7 @@ static final class Registers { public long long1; public Object object1; public final ExtensionRegistryLite extensionRegistry; + public int recursionDepth; Registers() { this.extensionRegistry = ExtensionRegistryLite.getEmptyRegistry(); @@ -245,7 +249,10 @@ static int mergeMessageField( if (length < 0 || length > limit - position) { throw InvalidProtocolBufferException.truncatedMessage(); } + registers.recursionDepth++; + checkRecursionLimit(registers.recursionDepth); schema.mergeFrom(msg, data, position, position + length, registers); + registers.recursionDepth--; registers.object1 = msg; return position + length; } @@ -263,8 +270,11 @@ static int mergeGroupField( // A group field must has a MessageSchema (the only other subclass of Schema is MessageSetSchema // and it can't be used in group fields). final MessageSchema messageSchema = (MessageSchema) schema; + registers.recursionDepth++; + checkRecursionLimit(registers.recursionDepth); final int endPosition = messageSchema.parseMessage(msg, data, position, limit, endGroup, registers); + registers.recursionDepth--; registers.object1 = msg; return endPosition; } @@ -1025,6 +1035,8 @@ static int decodeUnknownField( final UnknownFieldSetLite child = UnknownFieldSetLite.newInstance(); final int endGroup = (tag & ~0x7) | WireFormat.WIRETYPE_END_GROUP; int lastTag = 0; + registers.recursionDepth++; + checkRecursionLimit(registers.recursionDepth); while (position < limit) { position = decodeVarint32(data, position, registers); lastTag = registers.int1; @@ -1033,6 +1045,7 @@ static int decodeUnknownField( } position = decodeUnknownField(lastTag, data, position, limit, child, registers); } + registers.recursionDepth--; if (position > limit || lastTag != endGroup) { throw InvalidProtocolBufferException.parseFailure(); } @@ -1079,4 +1092,18 @@ static int skipField(int tag, byte[] data, int position, int limit, Registers re throw InvalidProtocolBufferException.invalidTag(); } } + + /** + * Set the maximum recursion limit that ArrayDecoders will allow. An exception will be thrown if + * the depth of the message exceeds this limit. + */ + public static void setRecursionLimit(int limit) { + recursionLimit = limit; + } + + private static void checkRecursionLimit(int depth) throws InvalidProtocolBufferException { + if (depth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + } } diff --git a/java/core/src/main/java/com/google/protobuf/CodedInputStream.java b/java/core/src/main/java/com/google/protobuf/CodedInputStream.java index 81da417783362..6b3573bf92f10 100644 --- a/java/core/src/main/java/com/google/protobuf/CodedInputStream.java +++ b/java/core/src/main/java/com/google/protobuf/CodedInputStream.java @@ -224,13 +224,41 @@ public abstract boolean skipField(final int tag, final CodedOutputStream output) * Reads and discards an entire message. This will read either until EOF or until an endgroup tag, * whichever comes first. */ - public abstract void skipMessage() throws IOException; + public void skipMessage() throws IOException { + while (true) { + final int tag = readTag(); + if (tag == 0) { + return; + } + checkRecursionLimit(); + ++recursionDepth; + boolean fieldSkipped = skipField(tag); + --recursionDepth; + if (!fieldSkipped) { + return; + } + } + } /** * Reads an entire message and writes it to output in wire format. This will read either until EOF * or until an endgroup tag, whichever comes first. */ - public abstract void skipMessage(CodedOutputStream output) throws IOException; + public void skipMessage(CodedOutputStream output) throws IOException { + while (true) { + final int tag = readTag(); + if (tag == 0) { + return; + } + checkRecursionLimit(); + ++recursionDepth; + boolean fieldSkipped = skipField(tag, output); + --recursionDepth; + if (!fieldSkipped) { + return; + } + } + } // ----------------------------------------------------------------- @@ -700,26 +728,6 @@ public boolean skipField(final int tag, final CodedOutputStream output) throws I } } - @Override - public void skipMessage() throws IOException { - while (true) { - final int tag = readTag(); - if (tag == 0 || !skipField(tag)) { - return; - } - } - } - - @Override - public void skipMessage(CodedOutputStream output) throws IOException { - while (true) { - final int tag = readTag(); - if (tag == 0 || !skipField(tag, output)) { - return; - } - } - } - // ----------------------------------------------------------------- @Override @@ -1412,26 +1420,6 @@ public boolean skipField(final int tag, final CodedOutputStream output) throws I } } - @Override - public void skipMessage() throws IOException { - while (true) { - final int tag = readTag(); - if (tag == 0 || !skipField(tag)) { - return; - } - } - } - - @Override - public void skipMessage(CodedOutputStream output) throws IOException { - while (true) { - final int tag = readTag(); - if (tag == 0 || !skipField(tag, output)) { - return; - } - } - } - // ----------------------------------------------------------------- @Override @@ -2178,26 +2166,6 @@ public boolean skipField(final int tag, final CodedOutputStream output) throws I } } - @Override - public void skipMessage() throws IOException { - while (true) { - final int tag = readTag(); - if (tag == 0 || !skipField(tag)) { - return; - } - } - } - - @Override - public void skipMessage(CodedOutputStream output) throws IOException { - while (true) { - final int tag = readTag(); - if (tag == 0 || !skipField(tag, output)) { - return; - } - } - } - /** Collects the bytes skipped and returns the data in a ByteBuffer. */ private class SkippedDataSink implements RefillCallback { private int lastPos = pos; @@ -3322,26 +3290,6 @@ public boolean skipField(final int tag, final CodedOutputStream output) throws I } } - @Override - public void skipMessage() throws IOException { - while (true) { - final int tag = readTag(); - if (tag == 0 || !skipField(tag)) { - return; - } - } - } - - @Override - public void skipMessage(CodedOutputStream output) throws IOException { - while (true) { - final int tag = readTag(); - if (tag == 0 || !skipField(tag, output)) { - return; - } - } - } - // ----------------------------------------------------------------- @Override diff --git a/java/core/src/main/java/com/google/protobuf/InvalidProtocolBufferException.java b/java/core/src/main/java/com/google/protobuf/InvalidProtocolBufferException.java index 5d10e48884c1a..dbcb9e899dbee 100644 --- a/java/core/src/main/java/com/google/protobuf/InvalidProtocolBufferException.java +++ b/java/core/src/main/java/com/google/protobuf/InvalidProtocolBufferException.java @@ -132,7 +132,7 @@ public InvalidWireTypeException(String description) { static InvalidProtocolBufferException recursionLimitExceeded() { return new InvalidProtocolBufferException( "Protocol message had too many levels of nesting. May be malicious. " - + "Use CodedInputStream.setRecursionLimit() to increase the depth limit."); + + "Use setRecursionLimit() to increase the recursion depth limit."); } static InvalidProtocolBufferException sizeLimitExceeded() { diff --git a/java/core/src/main/java/com/google/protobuf/MessageSchema.java b/java/core/src/main/java/com/google/protobuf/MessageSchema.java index de3890f7023aa..5ad6762b0dc4e 100644 --- a/java/core/src/main/java/com/google/protobuf/MessageSchema.java +++ b/java/core/src/main/java/com/google/protobuf/MessageSchema.java @@ -3006,7 +3006,8 @@ private > void mergeFromHelper( unknownFields = unknownFieldSchema.getBuilderFromMessage(message); } // Unknown field. - if (unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) { + if (unknownFieldSchema.mergeOneFieldFrom( + unknownFields, reader, /* currentDepth= */ 0)) { continue; } } @@ -3381,7 +3382,8 @@ private > void mergeFromHelper( if (unknownFields == null) { unknownFields = unknownFieldSchema.getBuilderFromMessage(message); } - if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) { + if (!unknownFieldSchema.mergeOneFieldFrom( + unknownFields, reader, /* currentDepth= */ 0)) { return; } break; @@ -3397,7 +3399,8 @@ private > void mergeFromHelper( if (unknownFields == null) { unknownFields = unknownFieldSchema.getBuilderFromMessage(message); } - if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) { + if (!unknownFieldSchema.mergeOneFieldFrom( + unknownFields, reader, /* currentDepth= */ 0)) { return; } } diff --git a/java/core/src/main/java/com/google/protobuf/MessageSetSchema.java b/java/core/src/main/java/com/google/protobuf/MessageSetSchema.java index eec3acd35ca79..ec37d41f98c5a 100644 --- a/java/core/src/main/java/com/google/protobuf/MessageSetSchema.java +++ b/java/core/src/main/java/com/google/protobuf/MessageSetSchema.java @@ -278,7 +278,7 @@ boolean parseMessageSetItemOrUnknownField( reader, extension, extensionRegistry, extensions); return true; } else { - return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader); + return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader, /* currentDepth= */ 0); } } else { return reader.skipField(); diff --git a/java/core/src/main/java/com/google/protobuf/UnknownFieldSchema.java b/java/core/src/main/java/com/google/protobuf/UnknownFieldSchema.java index c4ec645bf7fc8..80602b16359aa 100644 --- a/java/core/src/main/java/com/google/protobuf/UnknownFieldSchema.java +++ b/java/core/src/main/java/com/google/protobuf/UnknownFieldSchema.java @@ -13,6 +13,11 @@ @CheckReturnValue abstract class UnknownFieldSchema { + static final int DEFAULT_RECURSION_LIMIT = 100; + + @SuppressWarnings("NonFinalStaticField") + private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT; + /** Whether unknown fields should be dropped. */ abstract boolean shouldDiscardUnknownFields(Reader reader); @@ -56,7 +61,8 @@ abstract class UnknownFieldSchema { abstract void makeImmutable(Object message); /** Merges one field into the unknown fields. */ - final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOException { + final boolean mergeOneFieldFrom(B unknownFields, Reader reader, int currentDepth) + throws IOException { int tag = reader.getTag(); int fieldNumber = WireFormat.getTagFieldNumber(tag); switch (WireFormat.getTagWireType(tag)) { @@ -75,7 +81,12 @@ final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOExcepti case WireFormat.WIRETYPE_START_GROUP: final B subFields = newBuilder(); int endGroupTag = WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP); - mergeFrom(subFields, reader); + currentDepth++; + if (currentDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + mergeFrom(subFields, reader, currentDepth); + currentDepth--; if (endGroupTag != reader.getTag()) { throw InvalidProtocolBufferException.invalidEndTag(); } @@ -88,10 +99,11 @@ final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOExcepti } } - final void mergeFrom(B unknownFields, Reader reader) throws IOException { + private final void mergeFrom(B unknownFields, Reader reader, int currentDepth) + throws IOException { while (true) { if (reader.getFieldNumber() == Reader.READ_DONE - || !mergeOneFieldFrom(unknownFields, reader)) { + || !mergeOneFieldFrom(unknownFields, reader, currentDepth)) { break; } } @@ -108,4 +120,12 @@ final void mergeFrom(B unknownFields, Reader reader) throws IOException { abstract int getSerializedSizeAsMessageSet(T message); abstract int getSerializedSize(T unknowns); + + /** + * Set the maximum recursion limit that ArrayDecoders will allow. An exception will be thrown if + * the depth of the message exceeds this limit. + */ + public void setRecursionLimit(int limit) { + recursionLimit = limit; + } } diff --git a/java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java b/java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java index ff700587a1601..f73cb3b0eecf3 100644 --- a/java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java +++ b/java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java @@ -11,6 +11,9 @@ import static com.google.common.truth.Truth.assertWithMessage; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertThrows; + +import com.google.common.primitives.Bytes; +import map_test.MapTestProto.MapContainer; import protobuf_unittest.UnittestProto.BoolMessage; import protobuf_unittest.UnittestProto.Int32Message; import protobuf_unittest.UnittestProto.Int64Message; @@ -35,6 +38,13 @@ public class CodedInputStreamTest { private static final int DEFAULT_BLOCK_SIZE = 4096; + private static final int GROUP_TAP = WireFormat.makeTag(3, WireFormat.WIRETYPE_START_GROUP); + + private static final byte[] NESTING_SGROUP = generateSGroupTags(); + + private static final byte[] NESTING_SGROUP_WITH_INITIAL_BYTES = generateSGroupTagsForMapField(); + + private enum InputType { ARRAY { @Override @@ -117,6 +127,17 @@ private byte[] bytes(int... bytesAsInts) { return bytes; } + private static byte[] generateSGroupTags() { + byte[] bytes = new byte[100000]; + Arrays.fill(bytes, (byte) GROUP_TAP); + return bytes; + } + + private static byte[] generateSGroupTagsForMapField() { + byte[] initialBytes = {18, 1, 75, 26, (byte) 198, (byte) 154, 12}; + return Bytes.concat(initialBytes, NESTING_SGROUP); + } + /** * An InputStream which limits the number of bytes it reads at a time. We use this to make sure * that CodedInputStream doesn't screw up when reading in small blocks. @@ -740,6 +761,143 @@ public void testMaliciousRecursion() throws Exception { } } + @Test + public void testMaliciousRecursion_unknownFields() throws Exception { + Throwable thrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> TestRecursiveMessage.parseFrom(NESTING_SGROUP)); + + assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); + } + + @Test + public void testMaliciousRecursion_skippingUnknownField() throws Exception { + Throwable thrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> + DiscardUnknownFieldsParser.wrap(TestRecursiveMessage.parser()) + .parseFrom(NESTING_SGROUP)); + + assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); + } + + @Test + public void testMaliciousSGroupTagsWithMapField_fromInputStream() throws Exception { + Throwable parseFromThrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> + MapContainer.parseFrom( + new ByteArrayInputStream(NESTING_SGROUP_WITH_INITIAL_BYTES))); + Throwable mergeFromThrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> + MapContainer.newBuilder() + .mergeFrom(new ByteArrayInputStream(NESTING_SGROUP_WITH_INITIAL_BYTES))); + + assertThat(parseFromThrown) + .hasMessageThat() + .contains("Protocol message had too many levels of nesting"); + assertThat(mergeFromThrown) + .hasMessageThat() + .contains("Protocol message had too many levels of nesting"); + } + + @Test + public void testMaliciousSGroupTags_inputStream_skipMessage() throws Exception { + ByteArrayInputStream inputSteam = new ByteArrayInputStream(NESTING_SGROUP); + CodedInputStream input = CodedInputStream.newInstance(inputSteam); + CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]); + + Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage); + Throwable thrown2 = + assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output)); + + assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); + assertThat(thrown2) + .hasMessageThat() + .contains("Protocol message had too many levels of nesting"); + } + + @Test + public void testMaliciousSGroupTagsWithMapField_fromByteArray() throws Exception { + Throwable parseFromThrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> MapContainer.parseFrom(NESTING_SGROUP_WITH_INITIAL_BYTES)); + Throwable mergeFromThrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> MapContainer.newBuilder().mergeFrom(NESTING_SGROUP_WITH_INITIAL_BYTES)); + + assertThat(parseFromThrown) + .hasMessageThat() + .contains("the input ended unexpectedly in the middle of a field"); + assertThat(mergeFromThrown) + .hasMessageThat() + .contains("the input ended unexpectedly in the middle of a field"); + } + + @Test + public void testMaliciousSGroupTags_arrayDecoder_skipMessage() throws Exception { + CodedInputStream input = CodedInputStream.newInstance(NESTING_SGROUP); + CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]); + + Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage); + Throwable thrown2 = + assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output)); + + assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); + assertThat(thrown2) + .hasMessageThat() + .contains("Protocol message had too many levels of nesting"); + } + + @Test + public void testMaliciousSGroupTagsWithMapField_fromByteBuffer() throws Exception { + Throwable thrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> MapContainer.parseFrom(ByteBuffer.wrap(NESTING_SGROUP_WITH_INITIAL_BYTES))); + + assertThat(thrown) + .hasMessageThat() + .contains("the input ended unexpectedly in the middle of a field"); + } + + @Test + public void testMaliciousSGroupTags_byteBuffer_skipMessage() throws Exception { + CodedInputStream input = InputType.NIO_DIRECT.newDecoder(NESTING_SGROUP); + CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]); + + Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage); + Throwable thrown2 = + assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output)); + + assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); + assertThat(thrown2) + .hasMessageThat() + .contains("Protocol message had too many levels of nesting"); + } + + @Test + public void testMaliciousSGroupTags_iterableByteBuffer() throws Exception { + CodedInputStream input = InputType.ITER_DIRECT.newDecoder(NESTING_SGROUP); + CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]); + + Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage); + Throwable thrown2 = + assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output)); + + assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); + assertThat(thrown2) + .hasMessageThat() + .contains("Protocol message had too many levels of nesting"); + } + private void checkSizeLimitExceeded(InvalidProtocolBufferException e) { assertThat(e) .hasMessageThat() diff --git a/java/core/src/test/proto/com/google/protobuf/map_lite_test.proto b/java/core/src/test/proto/com/google/protobuf/map_lite_test.proto index 610bef866852e..c40ea5461efac 100644 --- a/java/core/src/test/proto/com/google/protobuf/map_lite_test.proto +++ b/java/core/src/test/proto/com/google/protobuf/map_lite_test.proto @@ -115,4 +115,5 @@ message ReservedAsMapFieldWithEnumValue { // https://github.com/protocolbuffers/protobuf/issues/9785 message MapContainer { map my_map = 1; + map m = 3; } diff --git a/java/core/src/test/proto/com/google/protobuf/map_test.proto b/java/core/src/test/proto/com/google/protobuf/map_test.proto index 4d1ac677e6387..80bdc0d8532d8 100644 --- a/java/core/src/test/proto/com/google/protobuf/map_test.proto +++ b/java/core/src/test/proto/com/google/protobuf/map_test.proto @@ -114,4 +114,5 @@ message ReservedAsMapFieldWithEnumValue { // https://github.com/protocolbuffers/protobuf/issues/9785 message MapContainer { map my_map = 1; + map m = 3; } diff --git a/java/lite/src/test/java/com/google/protobuf/LiteTest.java b/java/lite/src/test/java/com/google/protobuf/LiteTest.java index 754ed7d5fc7c5..1fc9c56071609 100644 --- a/java/lite/src/test/java/com/google/protobuf/LiteTest.java +++ b/java/lite/src/test/java/com/google/protobuf/LiteTest.java @@ -10,12 +10,14 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; import static java.util.Collections.singletonList; +import static org.junit.Assert.assertThrows; import com.google.protobuf.FieldPresenceTestProto.TestAllTypes; import com.google.protobuf.UnittestImportLite.ImportEnumLite; import com.google.protobuf.UnittestImportPublicLite.PublicImportMessageLite; import com.google.protobuf.UnittestLite.ForeignEnumLite; import com.google.protobuf.UnittestLite.ForeignMessageLite; +import com.google.protobuf.UnittestLite.RecursiveGroup; import com.google.protobuf.UnittestLite.RecursiveMessage; import com.google.protobuf.UnittestLite.TestAllExtensionsLite; import com.google.protobuf.UnittestLite.TestAllTypesLite; @@ -29,6 +31,7 @@ import com.google.protobuf.UnittestLite.TestHugeFieldNumbersLite; import com.google.protobuf.UnittestLite.TestNestedExtensionLite; import com.google.protobuf.testing.Proto3TestingLite.Proto3MessageLite; +import map_lite_test.MapTestProto.MapContainer; import map_lite_test.MapTestProto.TestMap; import map_lite_test.MapTestProto.TestMap.MessageValue; import protobuf_unittest.NestedExtensionLite; @@ -50,6 +53,7 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -2459,6 +2463,211 @@ public void testParseFromByteBufferThrows() { } } + @Test + public void testParseFromInputStream_concurrent_nestingUnknownGroups() throws Exception { + int numThreads = 200; + ArrayList threads = new ArrayList<>(); + + ByteString byteString = generateNestingGroups(99); + AtomicBoolean thrown = new AtomicBoolean(false); + + for (int i = 0; i < numThreads; i++) { + Thread thread = + new Thread( + () -> { + try { + TestAllTypesLite unused = TestAllTypesLite.parseFrom(byteString); + } catch (IOException e) { + if (e.getMessage().contains("Protocol message had too many levels of nesting")) { + thrown.set(true); + } + } + }); + thread.start(); + threads.add(thread); + } + + for (Thread thread : threads) { + thread.join(); + } + + assertThat(thrown.get()).isFalse(); + } + + @Test + public void testParseFromInputStream_nestingUnknownGroups() throws IOException { + ByteString byteString = generateNestingGroups(99); + + Throwable thrown = + assertThrows( + InvalidProtocolBufferException.class, () -> TestAllTypesLite.parseFrom(byteString)); + assertThat(thrown) + .hasMessageThat() + .doesNotContain("Protocol message had too many levels of nesting"); + } + + @Test + public void testParseFromInputStream_nestingUnknownGroups_exception() throws IOException { + ByteString byteString = generateNestingGroups(100); + + Throwable thrown = + assertThrows( + InvalidProtocolBufferException.class, () -> TestAllTypesLite.parseFrom(byteString)); + assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); + } + + @Test + public void testParseFromInputStream_setRecursionLimit_exception() throws IOException { + ByteString byteString = generateNestingGroups(199); + UnknownFieldSchema schema = SchemaUtil.unknownFieldSetLiteSchema(); + schema.setRecursionLimit(200); + + Throwable thrown = + assertThrows( + InvalidProtocolBufferException.class, () -> TestAllTypesLite.parseFrom(byteString)); + assertThat(thrown) + .hasMessageThat() + .doesNotContain("Protocol message had too many levels of nesting"); + schema.setRecursionLimit(UnknownFieldSchema.DEFAULT_RECURSION_LIMIT); + } + + @Test + public void testParseFromBytes_concurrent_nestingUnknownGroups() throws Exception { + int numThreads = 200; + ArrayList threads = new ArrayList<>(); + + ByteString byteString = generateNestingGroups(99); + AtomicBoolean thrown = new AtomicBoolean(false); + + for (int i = 0; i < numThreads; i++) { + Thread thread = + new Thread( + () -> { + try { + // Should pass in byte[] instead of ByteString to go into ArrayDecoders. + TestAllTypesLite unused = TestAllTypesLite.parseFrom(byteString.toByteArray()); + } catch (InvalidProtocolBufferException e) { + if (e.getMessage().contains("Protocol message had too many levels of nesting")) { + thrown.set(true); + } + } + }); + thread.start(); + threads.add(thread); + } + + for (Thread thread : threads) { + thread.join(); + } + + assertThat(thrown.get()).isFalse(); + } + + @Test + public void testParseFromBytes_nestingUnknownGroups() throws IOException { + ByteString byteString = generateNestingGroups(99); + + Throwable thrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> TestAllTypesLite.parseFrom(byteString.toByteArray())); + assertThat(thrown) + .hasMessageThat() + .doesNotContain("Protocol message had too many levels of nesting"); + } + + @Test + public void testParseFromBytes_nestingUnknownGroups_exception() throws IOException { + ByteString byteString = generateNestingGroups(100); + + Throwable thrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> TestAllTypesLite.parseFrom(byteString.toByteArray())); + assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); + } + + @Test + public void testParseFromBytes_setRecursionLimit_exception() throws IOException { + ByteString byteString = generateNestingGroups(199); + ArrayDecoders.setRecursionLimit(200); + + Throwable thrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> TestAllTypesLite.parseFrom(byteString.toByteArray())); + assertThat(thrown) + .hasMessageThat() + .doesNotContain("Protocol message had too many levels of nesting"); + ArrayDecoders.setRecursionLimit(ArrayDecoders.DEFAULT_RECURSION_LIMIT); + } + + @Test + public void testParseFromBytes_recursiveMessages() throws Exception { + byte[] data99 = makeRecursiveMessage(99).toByteArray(); + byte[] data100 = makeRecursiveMessage(100).toByteArray(); + + RecursiveMessage unused = RecursiveMessage.parseFrom(data99); + Throwable thrown = + assertThrows( + InvalidProtocolBufferException.class, () -> RecursiveMessage.parseFrom(data100)); + assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); + } + + @Test + public void testParseFromBytes_recursiveKnownGroups() throws Exception { + byte[] data99 = makeRecursiveGroup(99).toByteArray(); + byte[] data100 = makeRecursiveGroup(100).toByteArray(); + + RecursiveGroup unused = RecursiveGroup.parseFrom(data99); + Throwable thrown = + assertThrows(InvalidProtocolBufferException.class, () -> RecursiveGroup.parseFrom(data100)); + assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); + } + + @Test + @SuppressWarnings("ProtoParseFromByteString") + public void testMaliciousSGroupTagsWithMapField_fromByteArray() throws Exception { + ByteString byteString = generateNestingGroups(102); + + Throwable parseFromThrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> MapContainer.parseFrom(byteString.toByteArray())); + Throwable mergeFromThrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> MapContainer.newBuilder().mergeFrom(byteString.toByteArray())); + + assertThat(parseFromThrown) + .hasMessageThat() + .contains("Protocol message had too many levels of nesting"); + assertThat(mergeFromThrown) + .hasMessageThat() + .contains("Protocol message had too many levels of nesting"); + } + + @Test + public void testMaliciousSGroupTagsWithMapField_fromInputStream() throws Exception { + byte[] bytes = generateNestingGroups(101).toByteArray(); + + Throwable parseFromThrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> MapContainer.parseFrom(new ByteArrayInputStream(bytes))); + Throwable mergeFromThrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> MapContainer.newBuilder().mergeFrom(new ByteArrayInputStream(bytes))); + + assertThat(parseFromThrown) + .hasMessageThat() + .contains("Protocol message had too many levels of nesting"); + assertThat(mergeFromThrown) + .hasMessageThat() + .contains("Protocol message had too many levels of nesting"); + } + @Test public void testParseFromByteBuffer_extensions() throws Exception { TestAllExtensionsLite message = @@ -2815,4 +3024,31 @@ private static boolean contains(ByteString a, ByteString b) { } return false; } + + private static ByteString generateNestingGroups(int num) throws IOException { + int groupTap = WireFormat.makeTag(3, WireFormat.WIRETYPE_START_GROUP); + ByteString.Output byteStringOutput = ByteString.newOutput(); + CodedOutputStream codedOutput = CodedOutputStream.newInstance(byteStringOutput); + for (int i = 0; i < num; i++) { + codedOutput.writeInt32NoTag(groupTap); + } + codedOutput.flush(); + return byteStringOutput.toByteString(); + } + + private static RecursiveMessage makeRecursiveMessage(int num) { + if (num == 0) { + return RecursiveMessage.getDefaultInstance(); + } else { + return RecursiveMessage.newBuilder().setRecurse(makeRecursiveMessage(num - 1)).build(); + } + } + + private static RecursiveGroup makeRecursiveGroup(int num) { + if (num == 0) { + return RecursiveGroup.getDefaultInstance(); + } else { + return RecursiveGroup.newBuilder().setRecurse(makeRecursiveGroup(num - 1)).build(); + } + } } diff --git a/src/google/protobuf/unittest_lite.proto b/src/google/protobuf/unittest_lite.proto index cdafcf7933ba9..0a6c0456c3321 100644 --- a/src/google/protobuf/unittest_lite.proto +++ b/src/google/protobuf/unittest_lite.proto @@ -627,3 +627,7 @@ message RecursiveMessage { RecursiveMessage recurse = 1; bytes payload = 2; } + +message RecursiveGroup { + RecursiveGroup recurse = 1 [features.message_encoding = DELIMITED]; +}