Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NRBF] More bug fixes #107682

Merged
merged 15 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace System.Formats.Nrbf;
/// <remarks>
/// ArrayInfo structures are described in <see href="https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/8fac763f-e46d-43a1-b360-80eb83d2c5fb">[MS-NRBF] 2.4.2.1</see>.
/// </remarks>
[DebuggerDisplay("Length={Length}, {ArrayType}, rank={Rank}")]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This type no longer provides Length property.

[DebuggerDisplay("{ArrayType}, rank={Rank}")]
internal readonly struct ArrayInfo
{
internal const int MaxArrayLength = 2147483591; // Array.MaxLength
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,9 @@ internal ArraySinglePrimitiveRecord(ArrayInfo arrayInfo, IReadOnlyList<T> values
public override T[] GetArray(bool allowNulls = true)
=> (T[])(_arrayNullsNotAllowed ??= (Values is T[] array ? array : Values.ToArray()));

internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType()
{
Debug.Fail("GetAllowedRecordType should never be called on ArraySinglePrimitiveRecord");
throw new InvalidOperationException();
}
internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType() => throw new InvalidOperationException();

private protected override void AddValue(object value)
{
Debug.Fail("AddValue should never be called on ArraySinglePrimitiveRecord");
throw new InvalidOperationException();
}
private protected override void AddValue(object value) => throw new InvalidOperationException();

internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int count)
{
Expand Down Expand Up @@ -94,7 +86,7 @@ internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int c
#if NET
reader.BaseStream.ReadExactly(resultAsBytes);
#else
byte[] bytes = ArrayPool<byte>.Shared.Rent(Math.Min(count * Unsafe.SizeOf<T>(), 256_000));
byte[] bytes = ArrayPool<byte>.Shared.Rent((int)Math.Min(requiredBytes, 256_000));
JeremyKuhne marked this conversation as resolved.
Show resolved Hide resolved

while (!resultAsBytes.IsEmpty)
{
Expand Down Expand Up @@ -159,31 +151,10 @@ internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int c
private static List<decimal> DecodeDecimals(BinaryReader reader, int count)
{
List<decimal> values = new();
#if NET
Span<byte> buffer = stackalloc byte[256];
for (int i = 0; i < count; i++)
{
int stringLength = reader.Read7BitEncodedInt();
if (!(stringLength > 0 && stringLength <= buffer.Length))
{
ThrowHelper.ThrowInvalidValue(stringLength);
}

reader.BaseStream.ReadExactly(buffer.Slice(0, stringLength));

if (!decimal.TryParse(buffer.Slice(0, stringLength), NumberStyles.Number, CultureInfo.InvariantCulture, out decimal value))
{
ThrowHelper.ThrowInvalidFormat();
}

values.Add(value);
}
#else
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this perf optimization was wrong because it was not checking if given string input is a valid utf8 string

for (int i = 0; i < count; i++)
{
values.Add(reader.ParseDecimal());
}
#endif
return values;
}

Expand Down Expand Up @@ -244,12 +215,14 @@ private static List<T> DecodeFromNonSeekableStream(BinaryReader reader, int coun
{
values.Add((T)(object)Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadUInt64()));
}
else
else if (typeof(T) == typeof(TimeSpan))
{
Debug.Assert(typeof(T) == typeof(TimeSpan));

values.Add((T)(object)new TimeSpan(reader.ReadInt64()));
}
else
{
throw new InvalidOperationException();
}
}

return values;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ internal sealed class ArraySingleStringRecord : SZArrayRecord<string?>
public override SerializationRecordType RecordType => SerializationRecordType.ArraySingleString;

/// <inheritdoc />
public override TypeName TypeName => TypeNameHelpers.GetPrimitiveSZArrayTypeName(PrimitiveType.String);
public override TypeName TypeName => TypeNameHelpers.GetPrimitiveSZArrayTypeName(TypeNameHelpers.StringPrimitiveType);

private List<SerializationRecord> Records { get; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ internal static ArrayRecord Decode(BinaryReader reader, RecordMap recordMap, Pay
lengths[i] = ArrayInfo.ParseValidArrayLength(reader);
totalElementCount *= lengths[i];

if (totalElementCount > uint.MaxValue)
if (totalElementCount > ArrayInfo.MaxArrayLength)
JeremyKuhne marked this conversation as resolved.
Show resolved Hide resolved
{
ThrowHelper.ThrowInvalidValue(lengths[i]); // max array size exceeded
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,7 @@ private BinaryLibraryRecord(SerializationRecordId libraryId, AssemblyNameInfo li

public override SerializationRecordType RecordType => SerializationRecordType.BinaryLibrary;

public override TypeName TypeName
{
get
{
Debug.Fail("TypeName should never be called on BinaryLibraryRecord");
return TypeName.Parse(nameof(BinaryLibraryRecord).AsSpan());
}
}
public override TypeName TypeName => TypeName.Parse(nameof(BinaryLibraryRecord).AsSpan());

internal string? RawLibraryName { get; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,14 @@ internal static MemberTypeInfo Decode(BinaryReader reader, int count, PayloadOpt
case BinaryType.Class:
info[i] = (type, ClassTypeInfo.Decode(reader, options, recordMap));
break;
default:
// Other types have no additional data.
Debug.Assert(type is BinaryType.String or BinaryType.ObjectArray or BinaryType.StringArray or BinaryType.Object);
case BinaryType.String:
case BinaryType.StringArray:
case BinaryType.Object:
case BinaryType.ObjectArray:
// These types have no additional data.
break;
default:
throw new InvalidOperationException();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, InvalidOperationException is for when a method is invoked on an object and the object has state that doesn't support it. For example, calling Stream.Read when Stream.CanRead says false. The object is not in a state where you can call the method.

Many of yours are probably more likely InvalidDataException, where the default message is "Found invalid data while decoding."

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Immo corrected me. Stream.Read when Stream.CanRead is false is NotSupportedException (you should have checked the property on the stream). InvalidOperationException is for controllable state that was controlled wrong, like an object where you have to set 3 properties before calling "DoStuff", and you called DoStuff without setting all 3.

But this is still most likely InvalidDataException, because the operation made sense to do, just the data that it read back was not internally consistent.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I wanted to achieve is:

  • ensure all methods that parse enums throw SerializationException when they encounter invalid value (out of range, not in the spec)
  • assume that the above is not true and always handle all enum values in switch statements in explicit way, throw in the default case (so if there is a bug in parsing logic, we don't get undefined behavior but just an exception)

My initial idea was to use UnreachableException, but it's not part of NS2.0. I decided to use InvalidOperationException and I am open to changing it.

It's being tested by the Fuzzer which does not catch InvalidOperationException and treats it as a bug:

}
catch (SerializationException) { /* Reading from the stream encountered invalid NRBF data.*/ }
catch (NotSupportedException) { /* Reading from the stream encountered unsupported records */ }
catch (DecoderFallbackException) { /* Reading from the stream encountered an invalid UTF8 sequence. */ }
catch (EndOfStreamException) { /* The end of the stream was reached before reading SerializationRecordType.MessageEnd record. */ }
catch (IOException) { /* An I/O error occurred. */ }
}

(as shown in the fuzzer comments in this PR that has discovered places where it was missing)

}
}

Expand Down Expand Up @@ -97,7 +101,8 @@ internal static MemberTypeInfo Decode(BinaryReader reader, int count, PayloadOpt
BinaryType.PrimitiveArray => (PrimitiveArray, default),
BinaryType.Class => (NonSystemClass, default),
BinaryType.SystemClass => (SystemClass, default),
_ => (ObjectArray, default)
BinaryType.ObjectArray => (ObjectArray, default),
_ => throw new InvalidOperationException()
};
}

Expand Down Expand Up @@ -144,15 +149,15 @@ internal TypeName GetArrayTypeName(ArrayInfo arrayInfo)

TypeName elementTypeName = binaryType switch
{
BinaryType.String => TypeNameHelpers.GetPrimitiveTypeName(PrimitiveType.String),
BinaryType.StringArray => TypeNameHelpers.GetPrimitiveSZArrayTypeName(PrimitiveType.String),
BinaryType.String => TypeNameHelpers.GetPrimitiveTypeName(TypeNameHelpers.StringPrimitiveType),
BinaryType.StringArray => TypeNameHelpers.GetPrimitiveSZArrayTypeName(TypeNameHelpers.StringPrimitiveType),
BinaryType.Primitive => TypeNameHelpers.GetPrimitiveTypeName((PrimitiveType)additionalInfo!),
BinaryType.PrimitiveArray => TypeNameHelpers.GetPrimitiveSZArrayTypeName((PrimitiveType)additionalInfo!),
BinaryType.Object => TypeNameHelpers.GetPrimitiveTypeName(TypeNameHelpers.ObjectPrimitiveType),
BinaryType.ObjectArray => TypeNameHelpers.GetPrimitiveSZArrayTypeName(TypeNameHelpers.ObjectPrimitiveType),
BinaryType.SystemClass => (TypeName)additionalInfo!,
BinaryType.Class => ((ClassTypeInfo)additionalInfo!).TypeName,
_ => throw new ArgumentOutOfRangeException(paramName: nameof(binaryType), actualValue: binaryType, message: null)
_ => throw new InvalidOperationException()
};

// In general, arrayRank == 1 may have two different meanings:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,5 @@ private MessageEndRecord()

public override SerializationRecordId Id => SerializationRecordId.NoId;

public override TypeName TypeName
{
get
{
Debug.Fail("TypeName should never be called on MessageEndRecord");
return TypeName.Parse(nameof(MessageEndRecord).AsSpan());
}
}
public override TypeName TypeName => TypeName.Parse(nameof(MessageEndRecord).AsSpan());
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,28 +69,22 @@ public static bool StartsWithPayloadHeader(Stream stream)
return false;
}

try
byte[] buffer = new byte[SerializedStreamHeaderRecord.Size];
int offset = 0;
while (offset < buffer.Length)
{
#if NET
Span<byte> buffer = stackalloc byte[SerializedStreamHeaderRecord.Size];
stream.ReadExactly(buffer);
#else
byte[] buffer = new byte[SerializedStreamHeaderRecord.Size];
int offset = 0;
while (offset < buffer.Length)
int read = stream.Read(buffer, offset, buffer.Length - offset);
if (read == 0)
{
int read = stream.Read(buffer, offset, buffer.Length - offset);
if (read == 0)
throw new EndOfStreamException();
offset += read;
stream.Position = beginning;
return false;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the goal of this change is to simply return false rather than throw EOSE

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to set the stream position back to the beginning before we return here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to set the stream position back to the beginning before we return here?

Great catch!

}
#endif
return StartsWithPayloadHeader(buffer);
}
finally
{
stream.Position = beginning;
offset += read;
}

bool result = StartsWithPayloadHeader(buffer);
stream.Position = beginning;
return result;
}

/// <summary>
Expand Down Expand Up @@ -241,7 +235,8 @@ private static SerializationRecord DecodeNext(BinaryReader reader, RecordMap rec
SerializationRecordType.ObjectNullMultiple => ObjectNullMultipleRecord.Decode(reader),
SerializationRecordType.ObjectNullMultiple256 => ObjectNullMultiple256Record.Decode(reader),
SerializationRecordType.SerializedStreamHeader => SerializedStreamHeaderRecord.Decode(reader),
_ => SystemClassWithMembersAndTypesRecord.Decode(reader, recordMap, options),
SerializationRecordType.SystemClassWithMembersAndTypes => SystemClassWithMembersAndTypesRecord.Decode(reader, recordMap, options),
_ => throw new InvalidOperationException()
};

recordMap.Add(record);
Expand Down Expand Up @@ -269,8 +264,8 @@ private static SerializationRecord DecodeMemberPrimitiveTypedRecord(BinaryReader
PrimitiveType.Double => new MemberPrimitiveTypedRecord<double>(reader.ReadDouble()),
PrimitiveType.Decimal => new MemberPrimitiveTypedRecord<decimal>(reader.ParseDecimal()),
PrimitiveType.DateTime => new MemberPrimitiveTypedRecord<DateTime>(Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadUInt64())),
// String is handled with a record, never on it's own
_ => new MemberPrimitiveTypedRecord<TimeSpan>(new TimeSpan(reader.ReadInt64())),
PrimitiveType.TimeSpan => new MemberPrimitiveTypedRecord<TimeSpan>(new TimeSpan(reader.ReadInt64())),
_ => throw new InvalidOperationException()
};
}

Expand All @@ -295,7 +290,8 @@ private static SerializationRecord DecodeArraySinglePrimitiveRecord(BinaryReader
PrimitiveType.Double => Decode<double>(info, reader),
PrimitiveType.Decimal => Decode<decimal>(info, reader),
PrimitiveType.DateTime => Decode<DateTime>(info, reader),
_ => Decode<TimeSpan>(info, reader),
PrimitiveType.TimeSpan => Decode<TimeSpan>(info, reader),
_ => throw new InvalidOperationException()
};

static SerializationRecord Decode<T>(ArrayInfo info, BinaryReader reader) where T : unmanaged
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,5 @@ internal abstract class NullsRecord : SerializationRecord

public override SerializationRecordId Id => SerializationRecordId.NoId;

public override TypeName TypeName
{
get
{
Debug.Fail($"TypeName should never be called on {GetType().Name}");
return TypeName.Parse(GetType().Name.AsSpan());
}
}
public override TypeName TypeName => TypeName.Parse(GetType().Name.AsSpan());
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@ namespace System.Formats.Nrbf;
/// </remarks>
internal enum PrimitiveType : byte
{
/// <summary>
/// Used internally to express no value
/// </summary>
None = 0,
Boolean = 1,
Byte = 2,
Char = 3,
Expand All @@ -30,7 +26,19 @@ internal enum PrimitiveType : byte
DateTime = 13,
UInt16 = 14,
UInt32 = 15,
UInt64 = 16,
Null = 17,
String = 18
UInt64 = 16
// This internal enum no longer contains Null and String as they were always illegal:
// - In case of BinaryArray (NRBF 2.4.3.1):
// "If the BinaryTypeEnum value is Primitive, the PrimitiveTypeEnumeration
// value in AdditionalTypeInfo MUST NOT be Null (17) or String (18)."
// - In case of MemberPrimitiveTyped (NRBF 2.5.1):
// "PrimitiveTypeEnum (1 byte): A PrimitiveTypeEnumeration
// value that specifies the Primitive Type of data that is being transmitted.
// This field MUST NOT contain a value of 17 (Null) or 18 (String)."
// - In case of ArraySinglePrimitive (NRBF 2.4.3.3):
// "A PrimitiveTypeEnumeration value that identifies the Primitive Type
// of the items of the Array. The value MUST NOT be 17 (Null) or 18 (String)."
// - In case of MemberTypeInfo (NRBF 2.3.1.2):
// "When the BinaryTypeEnum value is Primitive, the PrimitiveTypeEnumeration
// value in AdditionalInfo MUST NOT be Null (17) or String (18)."
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ internal void Add(SerializationRecord record)
return;
}
#endif
throw new SerializationException(SR.Format(SR.Serialization_DuplicateSerializationRecordId, record.Id));
throw new SerializationException(SR.Format(SR.Serialization_DuplicateSerializationRecordId, record.Id._id));
JeremyKuhne marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace System.Formats.Nrbf;
internal sealed class RectangularArrayRecord : ArrayRecord
{
private readonly int[] _lengths;
private readonly ICollection<object> _values;
private readonly List<object> _values;
private TypeName? _typeName;

private RectangularArrayRecord(Type elementType, ArrayInfo arrayInfo,
Expand All @@ -24,18 +24,8 @@ private RectangularArrayRecord(Type elementType, ArrayInfo arrayInfo,
MemberTypeInfo = memberTypeInfo;
_lengths = lengths;

// A List<T> can hold as many objects as an array, so for multi-dimensional arrays
// with more elements than Array.MaxLength we use LinkedList.
// Testing that many elements takes a LOT of time, so to ensure that both code paths are tested,
// we always use LinkedList code path for Debug builds.
#if DEBUG
_values = new LinkedList<object>();
#else
_values = arrayInfo.TotalElementsCount <= ArrayInfo.MaxArrayLength
? new List<object>(canPreAllocate ? arrayInfo.GetSZArrayLength() : Math.Min(4, arrayInfo.GetSZArrayLength()))
: new LinkedList<object>();
#endif

// ArrayInfo.GetSZArrayLength ensures to return a value <= Array.MaxLength
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since now we don't allow for any MD array with number of elements > Array.MaxLength, we don't need to use LinkedList anymore

_values = new List<object>(canPreAllocate ? arrayInfo.GetSZArrayLength() : Math.Min(4, arrayInfo.GetSZArrayLength()));
}

public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray;
Expand Down Expand Up @@ -108,6 +98,7 @@ private protected override Array Deserialize(Type arrayType, bool allowNulls)
else if (ElementType == typeof(TimeSpan)) CopyTo<TimeSpan>(_values, result);
else if (ElementType == typeof(DateTime)) CopyTo<DateTime>(_values, result);
else if (ElementType == typeof(decimal)) CopyTo<decimal>(_values, result);
else throw new InvalidOperationException();
}
else
{
Expand All @@ -116,7 +107,7 @@ private protected override Array Deserialize(Type arrayType, bool allowNulls)

return result;

static void CopyTo<T>(ICollection<object> list, Array array)
static void CopyTo<T>(List<object> list, Array array)
{
ref byte arrayDataRef = ref MemoryMarshal.GetArrayDataReference(array);
ref T firstElementRef = ref Unsafe.As<byte, T>(ref arrayDataRef);
Expand Down Expand Up @@ -176,7 +167,10 @@ internal static RectangularArrayRecord Create(BinaryReader reader, ArrayInfo arr
PrimitiveType.Int64 => sizeof(long),
PrimitiveType.UInt64 => sizeof(ulong),
PrimitiveType.Double => sizeof(double),
_ => -1
PrimitiveType.TimeSpan => sizeof(ulong),
PrimitiveType.DateTime => sizeof(ulong),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is an improvement that I got into when I was making all the switch statements throw in the default case, we were not handling TimeSpan and DateTime properly before

PrimitiveType.Decimal => -1, // represented as variable-length string
_ => throw new InvalidOperationException()
};

if (sizeOfSingleValue > 0)
Expand Down Expand Up @@ -215,7 +209,8 @@ private static Type MapPrimitive(PrimitiveType primitiveType)
PrimitiveType.DateTime => typeof(DateTime),
PrimitiveType.UInt16 => typeof(ushort),
PrimitiveType.UInt32 => typeof(uint),
_ => typeof(ulong)
PrimitiveType.UInt64 => typeof(ulong),
_ => throw new InvalidOperationException()
};

private static Type MapPrimitiveArray(PrimitiveType primitiveType)
Expand All @@ -235,7 +230,8 @@ private static Type MapPrimitiveArray(PrimitiveType primitiveType)
PrimitiveType.DateTime => typeof(DateTime[]),
PrimitiveType.UInt16 => typeof(ushort[]),
PrimitiveType.UInt32 => typeof(uint[]),
_ => typeof(ulong[]),
PrimitiveType.UInt64 => typeof(ulong[]),
_ => throw new InvalidOperationException()
};

private static object? GetActualValue(object value)
Expand Down
Loading
Loading