Skip to content

Commit

Permalink
address code review feedback: count the arrays themselves
Browse files Browse the repository at this point in the history
  • Loading branch information
adamsitnik committed Sep 13, 2024
1 parent a5a38fb commit 94f6c04
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ private protected ArrayRecord(ArrayInfo arrayInfo)

internal long ValuesToRead { get; private protected set; }

private protected ArrayInfo ArrayInfo { get; }
internal ArrayInfo ArrayInfo { get; }

internal bool IsJagged
=> ArrayInfo.ArrayType == BinaryArrayType.Jagged
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,51 +191,42 @@ private static long GetJaggedArrayFlattenedLength(BinaryArrayRecord jaggedArrayR

Debug.Assert(jaggedArrayRecord.IsJagged);

// In theory somebody could create a payload that would represent
// a very nested array with total elements count > long.MaxValue.
// That is why this method is using checked arithmetic.
result = checked(result + jaggedArrayRecord.Length); // count the arrays themselves

foreach (object value in jaggedArrayRecord.Values)
{
object item = value is MemberReferenceRecord referenceRecord
? referenceRecord.GetReferencedRecord()
: value;

if (item is not SerializationRecord record)
if (value is not SerializationRecord record)
{
result++;
continue;
}

if (record.RecordType == SerializationRecordType.MemberReference)
{
record = ((MemberReferenceRecord)record).GetReferencedRecord();
}

switch (record.RecordType)
{
case SerializationRecordType.BinaryArray:
case SerializationRecordType.ArraySinglePrimitive:
case SerializationRecordType.ArraySingleObject:
case SerializationRecordType.ArraySingleString:
result = checked(result + ((ArrayRecord)record).FlattenedLength);
break;
case SerializationRecordType.BinaryArray:
ArrayRecord nestedArrayRecord = (ArrayRecord)record;
if (nestedArrayRecord.IsJagged)
{
(jaggedArrayRecords ??= new()).Enqueue((BinaryArrayRecord)nestedArrayRecord);
}
else
{
Debug.Assert(nestedArrayRecord is not BinaryArrayRecord, "Ensure lack of recursive call");
checked
{
// In theory somebody could create a payload that would represent
// a very nested array with total elements count > long.MaxValue.
result += nestedArrayRecord.FlattenedLength;
}
}
break;
case SerializationRecordType.ObjectNull:
case SerializationRecordType.ObjectNullMultiple256:
case SerializationRecordType.ObjectNullMultiple:
// All nulls need to be included, as it's another form of possible attack.
checked
{
result += ((NullsRecord)item).NullCount;
result = checked(result + nestedArrayRecord.ArrayInfo.FlattenedLength);
}
break;
default:
result++;
break;
}
}
Expand Down
34 changes: 25 additions & 9 deletions src/libraries/System.Formats.Nrbf/tests/JaggedArraysTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,25 @@ namespace System.Formats.Nrbf.Tests;

public class JaggedArraysTests : ReadTests
{
[Fact]
public void CanReadJaggedArraysOfPrimitiveTypes_2D()
[Theory]
[InlineData(true)]
[InlineData(false)]
public void CanReadJaggedArraysOfPrimitiveTypes_2D(bool useReferences)
{
int[][] input = new int[7][];
int[] same = [1, 2, 3];
for (int i = 0; i < input.Length; i++)
{
input[i] = [i, i, i];
input[i] = useReferences
? same // reuse the same object (represented as a single record that is referenced multiple times)
: [i, i, i]; // create new array
}

var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input));

Verify(input, arrayRecord);
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
Assert.Equal(input.Length * 3, arrayRecord.FlattenedLength);
Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength);
}

[Theory]
Expand All @@ -42,13 +47,17 @@ public void FlattenedLengthIncludesNullArrays(int nullCount)
public void ItIsPossibleToHaveBinaryArrayRecordsHaveAnElementTypeOfArrayWithoutBeingMarkedAsJagged()
{
int[][][] input = new int[3][][];
long totalElementsCount = 0;
for (int i = 0; i < input.Length; i++)
{
input[i] = new int[4][];
totalElementsCount++; // count the arrays themselves

for (int j = 0; j < input[i].Length; j++)
{
input[i][j] = [i, j, 0, 1, 2];
totalElementsCount += input[i][j].Length;
totalElementsCount++; // count the arrays themselves
}
}

Expand All @@ -67,25 +76,31 @@ public void ItIsPossibleToHaveBinaryArrayRecordsHaveAnElementTypeOfArrayWithoutB

Verify(input, arrayRecord);
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
Assert.Equal(3 * 4 * 5, arrayRecord.FlattenedLength);
Assert.Equal(3 + 3 * 4 + 3 * 4 * 5, totalElementsCount);
Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength);
}

[Fact]
public void CanReadJaggedArraysOfPrimitiveTypes_3D()
{
int[][][] input = new int[7][][];
long totalElementsCount = 0;
for (int i = 0; i < input.Length; i++)
{
totalElementsCount++; // count the arrays themselves
input[i] = new int[1][];
totalElementsCount++; // count the arrays themselves
input[i][0] = [i, i, i];
totalElementsCount += input[i][0].Length;
}

var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input));

Verify(input, arrayRecord);
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
Assert.Equal(1, arrayRecord.Rank);
Assert.Equal(input.Length * 1 * 3, arrayRecord.FlattenedLength);
Assert.Equal(7 + 7 * 1 + 7 * 1 * 3, totalElementsCount);
Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength);
}

[Fact]
Expand All @@ -110,7 +125,7 @@ public void CanReadJaggedArrayOfRectangularArrays()
Verify(input, arrayRecord);
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
Assert.Equal(1, arrayRecord.Rank);
Assert.Equal(input.Length * 3 * 3, arrayRecord.FlattenedLength);
Assert.Equal(input.Length + input.Length * 3 * 3, arrayRecord.FlattenedLength);
}

[Fact]
Expand All @@ -126,7 +141,7 @@ public void CanReadJaggedArraysOfStrings()

Verify(input, arrayRecord);
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
Assert.Equal(input.Length * 3, arrayRecord.FlattenedLength);
Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength);
}

[Fact]
Expand All @@ -142,7 +157,7 @@ public void CanReadJaggedArraysOfObjects()

Verify(input, arrayRecord);
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
Assert.Equal(input.Length * 3, arrayRecord.FlattenedLength);
Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength);
}

[Serializable]
Expand All @@ -160,6 +175,7 @@ public void CanReadJaggedArraysOfComplexTypes()
{
input[i] = Enumerable.Range(0, i + 1).Select(j => new ComplexType { SomeField = j }).ToArray();
totalElementsCount += input[i].Length;
totalElementsCount++; // count the arrays themselves
}

var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input));
Expand Down

0 comments on commit 94f6c04

Please sign in to comment.