Skip to content

Commit

Permalink
[NRBF] Address issues discovered by Threat Model (dotnet#106629)
Browse files Browse the repository at this point in the history
* introduce ArrayRecord.FlattenedLength

* do not include invalid Type or Assembly names in the exception messages, as it's most likely corrupted/tampered/malicious data and could be used as a vector of attack.

* It is possible to have binary array records have an element type of array without being marked as jagged
  • Loading branch information
adamsitnik authored and sirntar committed Sep 30, 2024
1 parent 176877f commit ed5aafb
Show file tree
Hide file tree
Showing 13 changed files with 197 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ public abstract partial class ArrayRecord : System.Formats.Nrbf.SerializationRec
internal ArrayRecord() { }
public override System.Formats.Nrbf.SerializationRecordId Id { get { throw null; } }
public abstract System.ReadOnlySpan<int> Lengths { get; }
public virtual long FlattenedLength { get; }
public int Rank { get { throw null; } }
[System.Diagnostics.CodeAnalysis.RequiresDynamicCode("The code for an array of the specified type might not be available.")]
public System.Array GetArray(System.Type expectedArrayType, bool allowNulls = true) { throw null; }
Expand Down
11 changes: 4 additions & 7 deletions src/libraries/System.Formats.Nrbf/src/Resources/Strings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -126,26 +126,23 @@
<data name="Serialization_UnexpectedNullRecordCount" xml:space="preserve">
<value>Unexpected Null Record count.</value>
</data>
<data name="Serialization_MaxArrayLength" xml:space="preserve">
<value>The serialized array length ({0}) was larger than the configured limit {1}.</value>
</data>
<data name="NotSupported_RecordType" xml:space="preserve">
<value>{0} Record Type is not supported by design.</value>
</data>
<data name="Serialization_InvalidReference" xml:space="preserve">
<value>Invalid member reference.</value>
</data>
<data name="Serialization_InvalidTypeName" xml:space="preserve">
<value>Invalid type name: `{0}`.</value>
<value>Invalid type name.</value>
</data>
<data name="Serialization_TypeMismatch" xml:space="preserve">
<value>Expected the array to be of type {0}, but its element type was {1}.</value>
</data>
<data name="Serialization_InvalidTypeOrAssemblyName" xml:space="preserve">
<value>Invalid type or assembly name: `{0},{1}`.</value>
<value>Invalid type or assembly name.</value>
</data>
<data name="Serialization_DuplicateMemberName" xml:space="preserve">
<value>Duplicate member name: `{0}`.</value>
<value>Duplicate member name.</value>
</data>
<data name="Argument_NonSeekableStream" xml:space="preserve">
<value>Stream does not support seeking.</value>
Expand All @@ -160,7 +157,7 @@
<value>Only arrays with zero offsets are supported.</value>
</data>
<data name="Serialization_InvalidAssemblyName" xml:space="preserve">
<value>Invalid assembly name: `{0}`.</value>
<value>Invalid assembly name.</value>
</data>
<data name="Serialization_InvalidFormat" xml:space="preserve">
<value>Invalid format.</value>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,23 @@ internal readonly struct ArrayInfo
internal ArrayInfo(SerializationRecordId id, long totalElementsCount, BinaryArrayType arrayType = BinaryArrayType.Single, int rank = 1)
{
Id = id;
TotalElementsCount = totalElementsCount;
FlattenedLength = totalElementsCount;
ArrayType = arrayType;
Rank = rank;
}

internal SerializationRecordId Id { get; }

internal long TotalElementsCount { get; }
internal long FlattenedLength { get; }

internal BinaryArrayType ArrayType { get; }

internal int Rank { get; }

internal int GetSZArrayLength()
{
Debug.Assert(TotalElementsCount <= MaxArrayLength);
return (int)TotalElementsCount;
Debug.Assert(FlattenedLength <= MaxArrayLength);
return (int)FlattenedLength;
}

internal static ArrayInfo Decode(BinaryReader reader)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public abstract class ArrayRecord : SerializationRecord
private protected ArrayRecord(ArrayInfo arrayInfo)
{
ArrayInfo = arrayInfo;
ValuesToRead = arrayInfo.TotalElementsCount;
ValuesToRead = arrayInfo.FlattenedLength;
}

/// <summary>
Expand All @@ -27,6 +27,12 @@ private protected ArrayRecord(ArrayInfo arrayInfo)
/// <value>A buffer of integers that represent the number of elements in every dimension.</value>
public abstract ReadOnlySpan<int> Lengths { get; }

/// <summary>
/// When overridden in a derived class, gets the total number of all elements in every dimension.
/// </summary>
/// <value>A number that represent the total number of all elements in every dimension.</value>
public virtual long FlattenedLength => ArrayInfo.FlattenedLength;

/// <summary>
/// Gets the rank of the array.
/// </summary>
Expand All @@ -44,7 +50,12 @@ private protected ArrayRecord(ArrayInfo arrayInfo)

internal long ValuesToRead { get; private protected set; }

private protected ArrayInfo ArrayInfo { get; }
internal ArrayInfo ArrayInfo { get; }

internal bool IsJagged
=> ArrayInfo.ArrayType == BinaryArrayType.Jagged
// It is possible to have binary array records have an element type of array without being marked as jagged.
|| TypeName.GetElementType().IsArray;

/// <summary>
/// Allocates an array and fills it with the data provided in the serialized records (in case of primitive types like <see cref="string"/> or <see cref="int"/>) or the serialized records themselves.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,38 @@ internal sealed class BinaryArrayRecord : ArrayRecord
];

private TypeName? _typeName;
private long _totalElementsCount;

private BinaryArrayRecord(ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo)
: base(arrayInfo)
{
MemberTypeInfo = memberTypeInfo;
Values = [];
// We need to parse all elements of the jagged array to obtain total elements count.
_totalElementsCount = -1;
}

public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray;

/// <inheritdoc/>
public override ReadOnlySpan<int> Lengths => new int[1] { Length };

/// <inheritdoc/>
public override long FlattenedLength
{
get
{
if (_totalElementsCount < 0)
{
_totalElementsCount = IsJagged
? GetJaggedArrayFlattenedLength(this)
: ArrayInfo.FlattenedLength;
}

return _totalElementsCount;
}
}

public override TypeName TypeName
=> _typeName ??= MemberTypeInfo.GetArrayTypeName(ArrayInfo);

Expand Down Expand Up @@ -174,6 +193,65 @@ internal static ArrayRecord Decode(BinaryReader reader, RecordMap recordMap, Pay
: new BinaryArrayRecord(arrayInfo, memberTypeInfo);
}

private static long GetJaggedArrayFlattenedLength(BinaryArrayRecord jaggedArrayRecord)
{
long result = 0;
Queue<BinaryArrayRecord>? jaggedArrayRecords = null;

do
{
if (jaggedArrayRecords is not null)
{
jaggedArrayRecord = jaggedArrayRecords.Dequeue();
}

Debug.Assert(jaggedArrayRecord.IsJagged);

// In theory somebody could create a payload that would represent
// a very nested array with total elements count > long.MaxValue.
// That is why this method is using checked arithmetic.
result = checked(result + jaggedArrayRecord.Length); // count the arrays themselves

foreach (object value in jaggedArrayRecord.Values)
{
if (value is not SerializationRecord record)
{
continue;
}

if (record.RecordType == SerializationRecordType.MemberReference)
{
record = ((MemberReferenceRecord)record).GetReferencedRecord();
}

switch (record.RecordType)
{
case SerializationRecordType.ArraySinglePrimitive:
case SerializationRecordType.ArraySingleObject:
case SerializationRecordType.ArraySingleString:
case SerializationRecordType.BinaryArray:
ArrayRecord nestedArrayRecord = (ArrayRecord)record;
if (nestedArrayRecord.IsJagged)
{
(jaggedArrayRecords ??= new()).Enqueue((BinaryArrayRecord)nestedArrayRecord);
}
else
{
// Don't call nestedArrayRecord.FlattenedLength to avoid any potential recursion,
// just call nestedArrayRecord.ArrayInfo.FlattenedLength that returns pre-computed value.
result = checked(result + nestedArrayRecord.ArrayInfo.FlattenedLength);
}
break;
default:
break;
}
}
}
while (jaggedArrayRecords is not null && jaggedArrayRecords.Count > 0);

return result;
}

private protected override void AddValue(object value) => Values.Add(value);

internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ internal static BinaryLibraryRecord Decode(BinaryReader reader, PayloadOptions o
}
else if (!options.UndoTruncatedTypeNames)
{
ThrowHelper.ThrowInvalidAssemblyName(rawName);
ThrowHelper.ThrowInvalidAssemblyName();
}

return new BinaryLibraryRecord(id, rawName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ internal static ClassInfo Decode(BinaryReader reader)
continue;
}
#endif
throw new SerializationException(SR.Format(SR.Serialization_DuplicateMemberName, memberName));
ThrowHelper.ThrowDuplicateMemberName();
}

return new ClassInfo(id, typeName, memberNames);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,11 @@ internal static RectangularArrayRecord Create(BinaryReader reader, ArrayInfo arr
// to encountering an EOF if we realize later that we actually need to read more bytes in
// order to fully populate the char[,,,...] array. Any such allocation is still linearly
// proportional to the length of the incoming payload, so it's not a DoS vector.
// The multiplication below is guaranteed not to overflow because TotalElementsCount is bounded
// to <= uint.MaxValue (see BinaryArrayRecord.Decode) and sizeOfSingleValue is at most 8.
Debug.Assert(arrayInfo.TotalElementsCount >= 0 && arrayInfo.TotalElementsCount <= long.MaxValue / sizeOfSingleValue);
// The multiplication below is guaranteed not to overflow because FlattenedLength is bounded
// to <= Array.MaxLength (see BinaryArrayRecord.Decode) and sizeOfSingleValue is at most 8.
Debug.Assert(arrayInfo.FlattenedLength >= 0 && arrayInfo.FlattenedLength <= long.MaxValue / sizeOfSingleValue);

long size = arrayInfo.TotalElementsCount * sizeOfSingleValue;
long size = arrayInfo.FlattenedLength * sizeOfSingleValue;
bool? isDataAvailable = reader.IsDataAvailable(size);
if (isDataAvailable.HasValue)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,30 @@

namespace System.Formats.Nrbf.Utils;

// The exception messages do not contain member/type/assembly names on purpose,
// as it's most likely corrupted/tampered/malicious data.
internal static class ThrowHelper
{
internal static void ThrowInvalidValue(object value)
internal static void ThrowDuplicateMemberName()
=> throw new SerializationException(SR.Serialization_DuplicateMemberName);

internal static void ThrowInvalidValue(int value)
=> throw new SerializationException(SR.Format(SR.Serialization_InvalidValue, value));

internal static void ThrowInvalidReference()
=> throw new SerializationException(SR.Serialization_InvalidReference);

internal static void ThrowInvalidTypeName(string name)
=> throw new SerializationException(SR.Format(SR.Serialization_InvalidTypeName, name));
internal static void ThrowInvalidTypeName()
=> throw new SerializationException(SR.Serialization_InvalidTypeName);

internal static void ThrowUnexpectedNullRecordCount()
=> throw new SerializationException(SR.Serialization_UnexpectedNullRecordCount);

internal static void ThrowMaxArrayLength(long limit, long actual)
=> throw new SerializationException(SR.Format(SR.Serialization_MaxArrayLength, actual, limit));

internal static void ThrowArrayContainedNulls()
=> throw new SerializationException(SR.Serialization_ArrayContainedNulls);

internal static void ThrowInvalidAssemblyName(string rawName)
=> throw new SerializationException(SR.Format(SR.Serialization_InvalidAssemblyName, rawName));
internal static void ThrowInvalidAssemblyName()
=> throw new SerializationException(SR.Serialization_InvalidAssemblyName);

internal static void ThrowInvalidFormat()
=> throw new SerializationException(SR.Serialization_InvalidFormat);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ internal static TypeName ParseNonSystemClassRecordTypeName(this string rawName,

if (typeName is null)
{
throw new SerializationException(SR.Format(SR.Serialization_InvalidTypeOrAssemblyName, rawName, libraryRecord.RawLibraryName));
throw new SerializationException(SR.Serialization_InvalidTypeOrAssemblyName);
}

if (typeName.AssemblyName is null)
Expand Down Expand Up @@ -183,7 +183,7 @@ private static TypeName With(this TypeName typeName, AssemblyNameInfo assemblyNa
else
{
// BinaryFormatter can not serialize pointers or references.
ThrowHelper.ThrowInvalidTypeName(typeName.FullName);
ThrowHelper.ThrowInvalidTypeName();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ private void Test<T>(int size, bool canSeek)
SZArrayRecord<T> arrayRecord = (SZArrayRecord<T>)NrbfDecoder.Decode(stream);

Assert.Equal(size, arrayRecord.Length);
Assert.Equal(size, arrayRecord.FlattenedLength);
T?[] output = arrayRecord.GetArray();
Assert.Equal(input, output);
Assert.Same(output, arrayRecord.GetArray());
Expand Down
Loading

0 comments on commit ed5aafb

Please sign in to comment.