Skip to content

Commit

Permalink
Merge pull request #18387 from protocolbuffers/cp-lp-25
Browse files Browse the repository at this point in the history
Add recursion check when parsing unknown fields in Java.
  • Loading branch information
zhangskz authored Sep 18, 2024
2 parents e673479 + b5a7cf7 commit 4a197e7
Show file tree
Hide file tree
Showing 10 changed files with 469 additions and 95 deletions.
31 changes: 29 additions & 2 deletions java/core/src/main/java/com/google/protobuf/ArrayDecoders.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@
*/
@CheckReturnValue
final class ArrayDecoders {
static final int DEFAULT_RECURSION_LIMIT = 100;

private ArrayDecoders() {
}
@SuppressWarnings("NonFinalStaticField")
private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT;

private ArrayDecoders() {}

/**
* A helper used to return multiple values in a Java function. Java doesn't natively support
Expand All @@ -38,6 +41,7 @@ static final class Registers {
public long long1;
public Object object1;
public final ExtensionRegistryLite extensionRegistry;
public int recursionDepth;

Registers() {
this.extensionRegistry = ExtensionRegistryLite.getEmptyRegistry();
Expand Down Expand Up @@ -245,7 +249,10 @@ static int mergeMessageField(
if (length < 0 || length > limit - position) {
throw InvalidProtocolBufferException.truncatedMessage();
}
registers.recursionDepth++;
checkRecursionLimit(registers.recursionDepth);
schema.mergeFrom(msg, data, position, position + length, registers);
registers.recursionDepth--;
registers.object1 = msg;
return position + length;
}
Expand All @@ -263,8 +270,11 @@ static int mergeGroupField(
// A group field must has a MessageSchema (the only other subclass of Schema is MessageSetSchema
// and it can't be used in group fields).
final MessageSchema messageSchema = (MessageSchema) schema;
registers.recursionDepth++;
checkRecursionLimit(registers.recursionDepth);
final int endPosition =
messageSchema.parseMessage(msg, data, position, limit, endGroup, registers);
registers.recursionDepth--;
registers.object1 = msg;
return endPosition;
}
Expand Down Expand Up @@ -1025,6 +1035,8 @@ static int decodeUnknownField(
final UnknownFieldSetLite child = UnknownFieldSetLite.newInstance();
final int endGroup = (tag & ~0x7) | WireFormat.WIRETYPE_END_GROUP;
int lastTag = 0;
registers.recursionDepth++;
checkRecursionLimit(registers.recursionDepth);
while (position < limit) {
position = decodeVarint32(data, position, registers);
lastTag = registers.int1;
Expand All @@ -1033,6 +1045,7 @@ static int decodeUnknownField(
}
position = decodeUnknownField(lastTag, data, position, limit, child, registers);
}
registers.recursionDepth--;
if (position > limit || lastTag != endGroup) {
throw InvalidProtocolBufferException.parseFailure();
}
Expand Down Expand Up @@ -1079,4 +1092,18 @@ static int skipField(int tag, byte[] data, int position, int limit, Registers re
throw InvalidProtocolBufferException.invalidTag();
}
}

/**
* Set the maximum recursion limit that ArrayDecoders will allow. An exception will be thrown if
* the depth of the message exceeds this limit.
*/
public static void setRecursionLimit(int limit) {
recursionLimit = limit;
}

private static void checkRecursionLimit(int depth) throws InvalidProtocolBufferException {
if (depth >= recursionLimit) {
throw InvalidProtocolBufferException.recursionLimitExceeded();
}
}
}
112 changes: 30 additions & 82 deletions java/core/src/main/java/com/google/protobuf/CodedInputStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,41 @@ public abstract boolean skipField(final int tag, final CodedOutputStream output)
* Reads and discards an entire message. This will read either until EOF or until an endgroup tag,
* whichever comes first.
*/
public abstract void skipMessage() throws IOException;
public void skipMessage() throws IOException {
while (true) {
final int tag = readTag();
if (tag == 0) {
return;
}
checkRecursionLimit();
++recursionDepth;
boolean fieldSkipped = skipField(tag);
--recursionDepth;
if (!fieldSkipped) {
return;
}
}
}

/**
* Reads an entire message and writes it to output in wire format. This will read either until EOF
* or until an endgroup tag, whichever comes first.
*/
public abstract void skipMessage(CodedOutputStream output) throws IOException;
public void skipMessage(CodedOutputStream output) throws IOException {
while (true) {
final int tag = readTag();
if (tag == 0) {
return;
}
checkRecursionLimit();
++recursionDepth;
boolean fieldSkipped = skipField(tag, output);
--recursionDepth;
if (!fieldSkipped) {
return;
}
}
}

// -----------------------------------------------------------------

Expand Down Expand Up @@ -699,26 +727,6 @@ public boolean skipField(final int tag, final CodedOutputStream output) throws I
}
}

@Override
public void skipMessage() throws IOException {
while (true) {
final int tag = readTag();
if (tag == 0 || !skipField(tag)) {
return;
}
}
}

@Override
public void skipMessage(CodedOutputStream output) throws IOException {
while (true) {
final int tag = readTag();
if (tag == 0 || !skipField(tag, output)) {
return;
}
}
}

// -----------------------------------------------------------------

@Override
Expand Down Expand Up @@ -1411,26 +1419,6 @@ public boolean skipField(final int tag, final CodedOutputStream output) throws I
}
}

@Override
public void skipMessage() throws IOException {
while (true) {
final int tag = readTag();
if (tag == 0 || !skipField(tag)) {
return;
}
}
}

@Override
public void skipMessage(CodedOutputStream output) throws IOException {
while (true) {
final int tag = readTag();
if (tag == 0 || !skipField(tag, output)) {
return;
}
}
}

// -----------------------------------------------------------------

@Override
Expand Down Expand Up @@ -2176,26 +2164,6 @@ public boolean skipField(final int tag, final CodedOutputStream output) throws I
}
}

@Override
public void skipMessage() throws IOException {
while (true) {
final int tag = readTag();
if (tag == 0 || !skipField(tag)) {
return;
}
}
}

@Override
public void skipMessage(CodedOutputStream output) throws IOException {
while (true) {
final int tag = readTag();
if (tag == 0 || !skipField(tag, output)) {
return;
}
}
}

/** Collects the bytes skipped and returns the data in a ByteBuffer. */
private class SkippedDataSink implements RefillCallback {
private int lastPos = pos;
Expand Down Expand Up @@ -3307,26 +3275,6 @@ public boolean skipField(final int tag, final CodedOutputStream output) throws I
}
}

@Override
public void skipMessage() throws IOException {
while (true) {
final int tag = readTag();
if (tag == 0 || !skipField(tag)) {
return;
}
}
}

@Override
public void skipMessage(CodedOutputStream output) throws IOException {
while (true) {
final int tag = readTag();
if (tag == 0 || !skipField(tag, output)) {
return;
}
}
}

// -----------------------------------------------------------------

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ public InvalidWireTypeException(String description) {
static InvalidProtocolBufferException recursionLimitExceeded() {
return new InvalidProtocolBufferException(
"Protocol message had too many levels of nesting. May be malicious. "
+ "Use CodedInputStream.setRecursionLimit() to increase the depth limit.");
+ "Use setRecursionLimit() to increase the recursion depth limit.");
}

static InvalidProtocolBufferException sizeLimitExceeded() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3006,7 +3006,8 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
}
// Unknown field.
if (unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
if (unknownFieldSchema.mergeOneFieldFrom(
unknownFields, reader, /* currentDepth= */ 0)) {
continue;
}
}
Expand Down Expand Up @@ -3381,7 +3382,8 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
if (unknownFields == null) {
unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
}
if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
if (!unknownFieldSchema.mergeOneFieldFrom(
unknownFields, reader, /* currentDepth= */ 0)) {
return;
}
break;
Expand All @@ -3397,7 +3399,8 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
if (unknownFields == null) {
unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
}
if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
if (!unknownFieldSchema.mergeOneFieldFrom(
unknownFields, reader, /* currentDepth= */ 0)) {
return;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ boolean parseMessageSetItemOrUnknownField(
reader, extension, extensionRegistry, extensions);
return true;
} else {
return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader);
return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader, /* currentDepth= */ 0);
}
} else {
return reader.skipField();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
@CheckReturnValue
abstract class UnknownFieldSchema<T, B> {

static final int DEFAULT_RECURSION_LIMIT = 100;

@SuppressWarnings("NonFinalStaticField")
private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT;

/** Whether unknown fields should be dropped. */
abstract boolean shouldDiscardUnknownFields(Reader reader);

Expand Down Expand Up @@ -56,7 +61,8 @@ abstract class UnknownFieldSchema<T, B> {
abstract void makeImmutable(Object message);

/** Merges one field into the unknown fields. */
final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOException {
final boolean mergeOneFieldFrom(B unknownFields, Reader reader, int currentDepth)
throws IOException {
int tag = reader.getTag();
int fieldNumber = WireFormat.getTagFieldNumber(tag);
switch (WireFormat.getTagWireType(tag)) {
Expand All @@ -75,7 +81,12 @@ final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOExcepti
case WireFormat.WIRETYPE_START_GROUP:
final B subFields = newBuilder();
int endGroupTag = WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP);
mergeFrom(subFields, reader);
currentDepth++;
if (currentDepth >= recursionLimit) {
throw InvalidProtocolBufferException.recursionLimitExceeded();
}
mergeFrom(subFields, reader, currentDepth);
currentDepth--;
if (endGroupTag != reader.getTag()) {
throw InvalidProtocolBufferException.invalidEndTag();
}
Expand All @@ -88,10 +99,11 @@ final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOExcepti
}
}

final void mergeFrom(B unknownFields, Reader reader) throws IOException {
private final void mergeFrom(B unknownFields, Reader reader, int currentDepth)
throws IOException {
while (true) {
if (reader.getFieldNumber() == Reader.READ_DONE
|| !mergeOneFieldFrom(unknownFields, reader)) {
|| !mergeOneFieldFrom(unknownFields, reader, currentDepth)) {
break;
}
}
Expand All @@ -108,4 +120,12 @@ final void mergeFrom(B unknownFields, Reader reader) throws IOException {
abstract int getSerializedSizeAsMessageSet(T message);

abstract int getSerializedSize(T unknowns);

/**
* Set the maximum recursion limit that ArrayDecoders will allow. An exception will be thrown if
* the depth of the message exceeds this limit.
*/
public void setRecursionLimit(int limit) {
recursionLimit = limit;
}
}
Loading

0 comments on commit 4a197e7

Please sign in to comment.