diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRecord.cs index aa2e168c2f0a7..237b7b72a2719 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRecord.cs @@ -50,7 +50,7 @@ private protected ArrayRecord(ArrayInfo arrayInfo) internal long ValuesToRead { get; private protected set; } - private protected ArrayInfo ArrayInfo { get; } + internal ArrayInfo ArrayInfo { get; } internal bool IsJagged => ArrayInfo.ArrayType == BinaryArrayType.Jagged diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryArrayRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryArrayRecord.cs index be8fb01d9ea3b..5e55f1cfbf8d1 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryArrayRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryArrayRecord.cs @@ -191,24 +191,29 @@ private static long GetJaggedArrayFlattenedLength(BinaryArrayRecord jaggedArrayR Debug.Assert(jaggedArrayRecord.IsJagged); + // In theory somebody could create a payload that would represent + // a very nested array with total elements count > long.MaxValue. + // That is why this method is using checked arithmetic. + result = checked(result + jaggedArrayRecord.Length); // count the arrays themselves + foreach (object value in jaggedArrayRecord.Values) { - object item = value is MemberReferenceRecord referenceRecord - ? referenceRecord.GetReferencedRecord() - : value; - - if (item is not SerializationRecord record) + if (value is not SerializationRecord record) { - result++; continue; } + if (record.RecordType == SerializationRecordType.MemberReference) + { + record = ((MemberReferenceRecord)record).GetReferencedRecord(); + } + switch (record.RecordType) { - case SerializationRecordType.BinaryArray: case SerializationRecordType.ArraySinglePrimitive: case SerializationRecordType.ArraySingleObject: case SerializationRecordType.ArraySingleString: + case SerializationRecordType.BinaryArray: ArrayRecord nestedArrayRecord = (ArrayRecord)record; if (nestedArrayRecord.IsJagged) { @@ -216,26 +221,12 @@ private static long GetJaggedArrayFlattenedLength(BinaryArrayRecord jaggedArrayR } else { - Debug.Assert(nestedArrayRecord is not BinaryArrayRecord, "Ensure lack of recursive call"); - checked - { - // In theory somebody could create a payload that would represent - // a very nested array with total elements count > long.MaxValue. - result += nestedArrayRecord.FlattenedLength; - } - } - break; - case SerializationRecordType.ObjectNull: - case SerializationRecordType.ObjectNullMultiple256: - case SerializationRecordType.ObjectNullMultiple: - // All nulls need to be included, as it's another form of possible attack. - checked - { - result += ((NullsRecord)item).NullCount; + // Don't call nestedArrayRecord.FlattenedLength to avoid any potential recursion, + // just call nestedArrayRecord.ArrayInfo.FlattenedLength that returns pre-computed value. + result = checked(result + nestedArrayRecord.ArrayInfo.FlattenedLength); } break; default: - result++; break; } } diff --git a/src/libraries/System.Formats.Nrbf/tests/JaggedArraysTests.cs b/src/libraries/System.Formats.Nrbf/tests/JaggedArraysTests.cs index 488b7992d29e7..8bb844ff76a58 100644 --- a/src/libraries/System.Formats.Nrbf/tests/JaggedArraysTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/JaggedArraysTests.cs @@ -7,20 +7,25 @@ namespace System.Formats.Nrbf.Tests; public class JaggedArraysTests : ReadTests { - [Fact] - public void CanReadJaggedArraysOfPrimitiveTypes_2D() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void CanReadJaggedArraysOfPrimitiveTypes_2D(bool useReferences) { int[][] input = new int[7][]; + int[] same = [1, 2, 3]; for (int i = 0; i < input.Length; i++) { - input[i] = [i, i, i]; + input[i] = useReferences + ? same // reuse the same object (represented as a single record that is referenced multiple times) + : [i, i, i]; // create new array } var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); Verify(input, arrayRecord); Assert.Equal(input, arrayRecord.GetArray(input.GetType())); - Assert.Equal(input.Length * 3, arrayRecord.FlattenedLength); + Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength); } [Theory] @@ -42,13 +47,17 @@ public void FlattenedLengthIncludesNullArrays(int nullCount) public void ItIsPossibleToHaveBinaryArrayRecordsHaveAnElementTypeOfArrayWithoutBeingMarkedAsJagged() { int[][][] input = new int[3][][]; + long totalElementsCount = 0; for (int i = 0; i < input.Length; i++) { input[i] = new int[4][]; + totalElementsCount++; // count the arrays themselves for (int j = 0; j < input[i].Length; j++) { input[i][j] = [i, j, 0, 1, 2]; + totalElementsCount += input[i][j].Length; + totalElementsCount++; // count the arrays themselves } } @@ -67,17 +76,22 @@ public void ItIsPossibleToHaveBinaryArrayRecordsHaveAnElementTypeOfArrayWithoutB Verify(input, arrayRecord); Assert.Equal(input, arrayRecord.GetArray(input.GetType())); - Assert.Equal(3 * 4 * 5, arrayRecord.FlattenedLength); + Assert.Equal(3 + 3 * 4 + 3 * 4 * 5, totalElementsCount); + Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength); } [Fact] public void CanReadJaggedArraysOfPrimitiveTypes_3D() { int[][][] input = new int[7][][]; + long totalElementsCount = 0; for (int i = 0; i < input.Length; i++) { + totalElementsCount++; // count the arrays themselves input[i] = new int[1][]; + totalElementsCount++; // count the arrays themselves input[i][0] = [i, i, i]; + totalElementsCount += input[i][0].Length; } var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); @@ -85,7 +99,8 @@ public void CanReadJaggedArraysOfPrimitiveTypes_3D() Verify(input, arrayRecord); Assert.Equal(input, arrayRecord.GetArray(input.GetType())); Assert.Equal(1, arrayRecord.Rank); - Assert.Equal(input.Length * 1 * 3, arrayRecord.FlattenedLength); + Assert.Equal(7 + 7 * 1 + 7 * 1 * 3, totalElementsCount); + Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength); } [Fact] @@ -110,7 +125,7 @@ public void CanReadJaggedArrayOfRectangularArrays() Verify(input, arrayRecord); Assert.Equal(input, arrayRecord.GetArray(input.GetType())); Assert.Equal(1, arrayRecord.Rank); - Assert.Equal(input.Length * 3 * 3, arrayRecord.FlattenedLength); + Assert.Equal(input.Length + input.Length * 3 * 3, arrayRecord.FlattenedLength); } [Fact] @@ -126,7 +141,7 @@ public void CanReadJaggedArraysOfStrings() Verify(input, arrayRecord); Assert.Equal(input, arrayRecord.GetArray(input.GetType())); - Assert.Equal(input.Length * 3, arrayRecord.FlattenedLength); + Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength); } [Fact] @@ -142,7 +157,7 @@ public void CanReadJaggedArraysOfObjects() Verify(input, arrayRecord); Assert.Equal(input, arrayRecord.GetArray(input.GetType())); - Assert.Equal(input.Length * 3, arrayRecord.FlattenedLength); + Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength); } [Serializable] @@ -160,6 +175,7 @@ public void CanReadJaggedArraysOfComplexTypes() { input[i] = Enumerable.Range(0, i + 1).Select(j => new ComplexType { SomeField = j }).ToArray(); totalElementsCount += input[i].Length; + totalElementsCount++; // count the arrays themselves } var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input));