Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NRBF] Address issues discovered by Threat Model #106629

Merged
merged 13 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ public abstract partial class ArrayRecord : System.Formats.Nrbf.SerializationRec
internal ArrayRecord() { }
public override System.Formats.Nrbf.SerializationRecordId Id { get { throw null; } }
public abstract System.ReadOnlySpan<int> Lengths { get; }
public virtual long TotalElementsCount { get; }
public int Rank { get { throw null; } }
[System.Diagnostics.CodeAnalysis.RequiresDynamicCode("The code for an array of the specified type might not be available.")]
public System.Array GetArray(System.Type expectedArrayType, bool allowNulls = true) { throw null; }
Expand Down
6 changes: 3 additions & 3 deletions src/libraries/System.Formats.Nrbf/src/Resources/Strings.resx
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on what we have discussed and what @bartonjs wrote here: #103713 (comment)

I believe the type names and assembly names should not be provided in the exception messages.

Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,13 @@
<value>Member reference was pointing to a record of unexpected type.</value>
</data>
<data name="Serialization_InvalidTypeName" xml:space="preserve">
<value>Invalid type name: `{0}`.</value>
<value>Invalid type name.</value>
</data>
<data name="Serialization_TypeMismatch" xml:space="preserve">
<value>Expected the array to be of type {0}, but its element type was {1}.</value>
</data>
<data name="Serialization_InvalidTypeOrAssemblyName" xml:space="preserve">
<value>Invalid type or assembly name: `{0},{1}`.</value>
<value>Invalid type or assembly name.</value>
</data>
<data name="Serialization_DuplicateMemberName" xml:space="preserve">
<value>Duplicate member name: `{0}`.</value>
Expand All @@ -160,6 +160,6 @@
<value>Only arrays with zero offsets are supported.</value>
</data>
<data name="Serialization_InvalidAssemblyName" xml:space="preserve">
<value>Invalid assembly name: `{0}`.</value>
<value>Invalid assembly name.</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ private protected ArrayRecord(ArrayInfo arrayInfo)
/// <value>A buffer of integers that represent the number of elements in every dimension.</value>
public abstract ReadOnlySpan<int> Lengths { get; }

/// <summary>
/// When overridden in a derived class, gets the total number of all elements in every dimension.
/// </summary>
/// <value>A number that represent the total number of all elements in every dimension.</value>
public virtual long TotalElementsCount => ArrayInfo.TotalElementsCount;

/// <summary>
/// Gets the rank of the array.
/// </summary>
Expand All @@ -46,6 +52,11 @@ private protected ArrayRecord(ArrayInfo arrayInfo)

private protected ArrayInfo ArrayInfo { get; }

internal bool IsJagged
=> ArrayInfo.ArrayType == BinaryArrayType.Jagged
// It is possible to have binary array records have an element type of array without being marked as jagged.
|| TypeName.GetElementType().IsArray;
Comment on lines +57 to +58
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JeremyKuhne this handles the scenario you have mentioned offline (there is also a test for that). Thank you again for pointing this out!


/// <summary>
/// Allocates an array and fills it with the data provided in the serialized records (in case of primitive types like <see cref="string"/> or <see cref="int"/>) or the serialized records themselves.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.IO;
using System.Reflection.Metadata;
using System.Formats.Nrbf.Utils;
using System.Diagnostics;

namespace System.Formats.Nrbf;

Expand All @@ -27,19 +28,38 @@ internal sealed class BinaryArrayRecord : ArrayRecord
];

private TypeName? _typeName;
private long _totalElementsCount;

private BinaryArrayRecord(ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo)
: base(arrayInfo)
{
MemberTypeInfo = memberTypeInfo;
Values = [];
// We need to parse all elements of the jagged array to obtain total elements count.
_totalElementsCount = -1;
}

public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray;

/// <inheritdoc/>
public override ReadOnlySpan<int> Lengths => new int[1] { Length };

/// <inheritdoc/>
public override long TotalElementsCount
{
get
{
if (_totalElementsCount < 0)
{
_totalElementsCount = IsJagged
? GetJaggedArrayTotalElementsCount(this)
: ArrayInfo.TotalElementsCount;
}

return _totalElementsCount;
}
}

public override TypeName TypeName
=> _typeName ??= MemberTypeInfo.GetArrayTypeName(ArrayInfo);

Expand Down Expand Up @@ -157,6 +177,65 @@ internal static ArrayRecord Decode(BinaryReader reader, RecordMap recordMap, Pay
: new BinaryArrayRecord(arrayInfo, memberTypeInfo);
}

private static long GetJaggedArrayTotalElementsCount(BinaryArrayRecord jaggedArrayRecord)
{
long result = 0;
Queue<BinaryArrayRecord>? jaggedArrayRecords = null;

do
{
if (jaggedArrayRecords is not null)
{
jaggedArrayRecord = jaggedArrayRecords.Dequeue();
}

Debug.Assert(jaggedArrayRecord.IsJagged);

foreach (object value in jaggedArrayRecord.Values)
{
object item = value is MemberReferenceRecord referenceRecord
? referenceRecord.GetReferencedRecord()
: value;

if (item is not SerializationRecord record)
{
result++;
continue;
}

switch (record.RecordType)
{
case SerializationRecordType.BinaryArray:
case SerializationRecordType.ArraySinglePrimitive:
case SerializationRecordType.ArraySingleObject:
case SerializationRecordType.ArraySingleString:
ArrayRecord nestedArrayRecord = (ArrayRecord)record;
if (nestedArrayRecord.IsJagged)
{
(jaggedArrayRecords ??= new()).Enqueue((BinaryArrayRecord)nestedArrayRecord);
}
else
{
result += nestedArrayRecord.TotalElementsCount;
adamsitnik marked this conversation as resolved.
Show resolved Hide resolved
}
break;
case SerializationRecordType.ObjectNull:
case SerializationRecordType.ObjectNullMultiple256:
case SerializationRecordType.ObjectNullMultiple:
// Null Records nested inside jagged array do not increase total elements count.
// Example: "int[][] input = [[1, 2, 3], null]" is just 3 elements in total.
adamsitnik marked this conversation as resolved.
Show resolved Hide resolved
break;
default:
result++;
break;
}
}
}
while (jaggedArrayRecords is not null && jaggedArrayRecords.Count > 0);

return result;
}

private protected override void AddValue(object value) => Values.Add(value);

internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ internal static BinaryLibraryRecord Decode(BinaryReader reader, PayloadOptions o
}
else if (!options.UndoTruncatedTypeNames)
{
ThrowHelper.ThrowInvalidAssemblyName(rawName);
ThrowHelper.ThrowInvalidAssemblyName();
}

return new BinaryLibraryRecord(id, rawName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

namespace System.Formats.Nrbf.Utils;

// The exception messages do not contain type or assembly names on purpose,
// as it's most likely corrupted/tampered/malicious data.
internal static class ThrowHelper
{
internal static void ThrowInvalidValue(object value)
Expand All @@ -14,8 +16,8 @@ internal static void ThrowInvalidValue(object value)
internal static void ThrowInvalidReference()
=> throw new SerializationException(SR.Serialization_InvalidReference);

internal static void ThrowInvalidTypeName(string name)
=> throw new SerializationException(SR.Format(SR.Serialization_InvalidTypeName, name));
internal static void ThrowInvalidTypeName()
=> throw new SerializationException(SR.Serialization_InvalidTypeName);

internal static void ThrowUnexpectedNullRecordCount()
=> throw new SerializationException(SR.Serialization_UnexpectedNullRecordCount);
Expand All @@ -26,8 +28,8 @@ internal static void ThrowMaxArrayLength(long limit, long actual)
internal static void ThrowArrayContainedNulls()
=> throw new SerializationException(SR.Serialization_ArrayContainedNulls);

internal static void ThrowInvalidAssemblyName(string rawName)
=> throw new SerializationException(SR.Format(SR.Serialization_InvalidAssemblyName, rawName));
internal static void ThrowInvalidAssemblyName()
=> throw new SerializationException(SR.Serialization_InvalidAssemblyName);

internal static void ThrowEndOfStreamException()
=> throw new EndOfStreamException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ internal static TypeName ParseNonSystemClassRecordTypeName(this string rawName,

if (typeName is null)
{
throw new SerializationException(SR.Format(SR.Serialization_InvalidTypeOrAssemblyName, rawName, libraryRecord.RawLibraryName));
throw new SerializationException(SR.Serialization_InvalidTypeOrAssemblyName);
}

if (typeName.AssemblyName is null)
Expand Down Expand Up @@ -169,7 +169,7 @@ private static TypeName With(this TypeName typeName, AssemblyNameInfo assemblyNa
else
{
// BinaryFormatter can not serialize pointers or references.
ThrowHelper.ThrowInvalidTypeName(typeName.FullName);
ThrowHelper.ThrowInvalidTypeName();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ private void Test<T>(int size, bool canSeek)
SZArrayRecord<T> arrayRecord = (SZArrayRecord<T>)NrbfDecoder.Decode(stream);

Assert.Equal(size, arrayRecord.Length);
Assert.Equal(size, arrayRecord.TotalElementsCount);
T?[] output = arrayRecord.GetArray();
Assert.Equal(input, output);
Assert.Same(output, arrayRecord.GetArray());
Expand Down
53 changes: 53 additions & 0 deletions src/libraries/System.Formats.Nrbf/tests/JaggedArraysTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Formats.Nrbf.Utils;
using System.IO;
using System.Linq;
using Xunit;

Expand All @@ -19,6 +20,51 @@ public void CanReadJaggedArraysOfPrimitiveTypes_2D()

Verify(input, arrayRecord);
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
Assert.Equal(input.Length * 3, arrayRecord.TotalElementsCount);
}

[Fact]
public void TotalElementsCountDoesNotIncludeNullArrays()
{
int[][] input = [[1, 2, 3], null];

var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input));

Verify(input, arrayRecord);
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
Assert.Equal(3, arrayRecord.TotalElementsCount);
}

[Fact]
public void ItIsPossibleToHaveBinaryArrayRecordsHaveAnElementTypeOfArrayWithoutBeingMarkedAsJagged()
{
int[][][] input = new int[3][][];
for (int i = 0; i < input.Length; i++)
{
input[i] = new int[4][];

for (int j = 0; j < input[i].Length; j++)
{
input[i][j] = [i, j, 0, 1, 2];
}
}

byte[] serialized = Serialize(input).ToArray();
const int ArrayTypeByteIndex =
sizeof(byte) + sizeof(int) * 4 + // stream header
sizeof(byte) + // SerializationRecordType.BinaryArray
sizeof(int); // SerializationRecordId

Assert.Equal((byte)BinaryArrayType.Jagged, serialized[ArrayTypeByteIndex]);

// change the reported array type
serialized[ArrayTypeByteIndex] = (byte)BinaryArrayType.Single;

var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(new MemoryStream(serialized));

Verify(input, arrayRecord);
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
Assert.Equal(3 * 4 * 5, arrayRecord.TotalElementsCount);
}

[Fact]
Expand All @@ -36,6 +82,7 @@ public void CanReadJaggedArraysOfPrimitiveTypes_3D()
Verify(input, arrayRecord);
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
Assert.Equal(1, arrayRecord.Rank);
Assert.Equal(input.Length * 1 * 3, arrayRecord.TotalElementsCount);
}

[Fact]
Expand All @@ -60,6 +107,7 @@ public void CanReadJaggedArrayOfRectangularArrays()
Verify(input, arrayRecord);
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
Assert.Equal(1, arrayRecord.Rank);
Assert.Equal(input.Length * 3 * 3, arrayRecord.TotalElementsCount);
}

[Fact]
Expand All @@ -75,6 +123,7 @@ public void CanReadJaggedArraysOfStrings()

Verify(input, arrayRecord);
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
Assert.Equal(input.Length * 3, arrayRecord.TotalElementsCount);
}

[Fact]
Expand All @@ -90,6 +139,7 @@ public void CanReadJaggedArraysOfObjects()

Verify(input, arrayRecord);
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
Assert.Equal(input.Length * 3, arrayRecord.TotalElementsCount);
}

[Serializable]
Expand All @@ -102,14 +152,17 @@ public class ComplexType
public void CanReadJaggedArraysOfComplexTypes()
{
ComplexType[][] input = new ComplexType[3][];
long totalElementsCount = 0;
for (int i = 0; i < input.Length; i++)
{
input[i] = Enumerable.Range(0, i + 1).Select(j => new ComplexType { SomeField = j }).ToArray();
totalElementsCount += input[i].Length;
}

var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input));

Verify(input, arrayRecord);
Assert.Equal(totalElementsCount, arrayRecord.TotalElementsCount);
var output = (ClassRecord?[][])arrayRecord.GetArray(input.GetType());
for (int i = 0; i < input.Length; i++)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,13 @@ public void CanReadRectangularArraysOfComplexTypes_3D()
internal static void Verify(Array input, ArrayRecord arrayRecord)
{
Assert.Equal(input.Rank, arrayRecord.Lengths.Length);
long totalElementsCount = 1;
for (int i = 0; i < input.Rank; i++)
{
Assert.Equal(input.GetLength(i), arrayRecord.Lengths[i]);
totalElementsCount *= input.GetLength(i);
}
Assert.Equal(totalElementsCount, arrayRecord.TotalElementsCount);
Assert.Equal(input.GetType().FullName, arrayRecord.TypeName.FullName);
Assert.Equal(input.GetType().GetAssemblyNameIncludingTypeForwards(), arrayRecord.TypeName.AssemblyName!.FullName);
}
Expand Down
Loading