Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NRBF] More bug fixes #107682

Merged
merged 15 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace System.Formats.Nrbf;
/// <remarks>
/// ArrayInfo structures are described in <see href="https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/8fac763f-e46d-43a1-b360-80eb83d2c5fb">[MS-NRBF] 2.4.2.1</see>.
/// </remarks>
[DebuggerDisplay("Length={Length}, {ArrayType}, rank={Rank}")]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This type no longer provides a Length property.

[DebuggerDisplay("{ArrayType}, rank={Rank}")]
internal readonly struct ArrayInfo
{
internal const int MaxArrayLength = 2147483591; // Array.MaxLength
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,9 @@ internal ArraySinglePrimitiveRecord(ArrayInfo arrayInfo, IReadOnlyList<T> values
public override T[] GetArray(bool allowNulls = true)
=> (T[])(_arrayNullsNotAllowed ??= (Values is T[] array ? array : Values.ToArray()));

internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType()
{
Debug.Fail("GetAllowedRecordType should never be called on ArraySinglePrimitiveRecord");
throw new InvalidOperationException();
}
internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType() => throw new InvalidOperationException();

private protected override void AddValue(object value)
{
Debug.Fail("AddValue should never be called on ArraySinglePrimitiveRecord");
throw new InvalidOperationException();
}
private protected override void AddValue(object value) => throw new InvalidOperationException();

internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int count)
{
Expand Down Expand Up @@ -94,7 +86,7 @@ internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int c
#if NET
reader.BaseStream.ReadExactly(resultAsBytes);
#else
byte[] bytes = ArrayPool<byte>.Shared.Rent(Math.Min(count * Unsafe.SizeOf<T>(), 256_000));
byte[] bytes = ArrayPool<byte>.Shared.Rent((int)Math.Min(requiredBytes, 256_000));
JeremyKuhne marked this conversation as resolved.
Show resolved Hide resolved

while (!resultAsBytes.IsEmpty)
{
Expand Down Expand Up @@ -159,31 +151,10 @@ internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int c
private static List<decimal> DecodeDecimals(BinaryReader reader, int count)
{
List<decimal> values = new();
#if NET
Span<byte> buffer = stackalloc byte[256];
for (int i = 0; i < count; i++)
{
int stringLength = reader.Read7BitEncodedInt();
if (!(stringLength > 0 && stringLength <= buffer.Length))
{
ThrowHelper.ThrowInvalidValue(stringLength);
}

reader.BaseStream.ReadExactly(buffer.Slice(0, stringLength));

if (!decimal.TryParse(buffer.Slice(0, stringLength), NumberStyles.Number, CultureInfo.InvariantCulture, out decimal value))
{
ThrowHelper.ThrowInvalidFormat();
}

values.Add(value);
}
#else
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This perf optimization was wrong because it did not check whether the given input is a valid UTF-8 string.

for (int i = 0; i < count; i++)
{
values.Add(reader.ParseDecimal());
}
#endif
return values;
}

Expand Down Expand Up @@ -244,12 +215,14 @@ private static List<T> DecodeFromNonSeekableStream(BinaryReader reader, int coun
{
values.Add((T)(object)Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadUInt64()));
}
else
else if (typeof(T) == typeof(TimeSpan))
{
Debug.Assert(typeof(T) == typeof(TimeSpan));

values.Add((T)(object)new TimeSpan(reader.ReadInt64()));
}
else
{
throw new InvalidOperationException();
}
}

return values;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ internal static ArrayRecord Decode(BinaryReader reader, RecordMap recordMap, Pay
lengths[i] = ArrayInfo.ParseValidArrayLength(reader);
totalElementCount *= lengths[i];

if (totalElementCount > uint.MaxValue)
if (totalElementCount > ArrayInfo.MaxArrayLength)
JeremyKuhne marked this conversation as resolved.
Show resolved Hide resolved
{
ThrowHelper.ThrowInvalidValue(lengths[i]); // max array size exceeded
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,7 @@ private BinaryLibraryRecord(SerializationRecordId libraryId, AssemblyNameInfo li

public override SerializationRecordType RecordType => SerializationRecordType.BinaryLibrary;

public override TypeName TypeName
{
get
{
Debug.Fail("TypeName should never be called on BinaryLibraryRecord");
return TypeName.Parse(nameof(BinaryLibraryRecord).AsSpan());
}
}
public override TypeName TypeName => TypeName.Parse(nameof(BinaryLibraryRecord).AsSpan());

internal string? RawLibraryName { get; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,14 @@ internal static MemberTypeInfo Decode(BinaryReader reader, int count, PayloadOpt
case BinaryType.Class:
info[i] = (type, ClassTypeInfo.Decode(reader, options, recordMap));
break;
default:
// Other types have no additional data.
Debug.Assert(type is BinaryType.String or BinaryType.ObjectArray or BinaryType.StringArray or BinaryType.Object);
case BinaryType.String:
case BinaryType.StringArray:
case BinaryType.Object:
case BinaryType.ObjectArray:
// These types have no additional data.
break;
default:
throw new InvalidOperationException();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, InvalidOperationException is for when a method is invoked on an object and the object has state that doesn't support it. For example, calling Stream.Read when Stream.CanRead says false. The object is not in a state where you can call the method.

Many of yours are probably more likely InvalidDataException, where the default message is "Found invalid data while decoding."

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Immo corrected me. Stream.Read when Stream.CanRead is false is NotSupportedException (you should have checked the property on the stream). InvalidOperationException is for controllable state that was controlled wrong, like an object where you have to set 3 properties before calling "DoStuff", and you called DoStuff without setting all 3.

But this is still most likely InvalidDataException, because the operation made sense to do, just the data that it read back was not internally consistent.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I wanted to achieve is:

  • ensure all methods that parse enums throw SerializationException when they encounter an invalid value (out of range, not in the spec)
  • assume that the above is not true and always handle all enum values in switch statements explicitly, throwing in the default case (so if there is a bug in the parsing logic, we don't get undefined behavior but just an exception)

My initial idea was to use UnreachableException, but it's not part of .NET Standard 2.0. I decided to use InvalidOperationException, and I am open to changing it.

It's being tested by the fuzzer, which does not catch InvalidOperationException and treats it as a bug:

}
catch (SerializationException) { /* Reading from the stream encountered invalid NRBF data.*/ }
catch (NotSupportedException) { /* Reading from the stream encountered unsupported records */ }
catch (DecoderFallbackException) { /* Reading from the stream encountered an invalid UTF8 sequence. */ }
catch (EndOfStreamException) { /* The end of the stream was reached before reading SerializationRecordType.MessageEnd record. */ }
catch (IOException) { /* An I/O error occurred. */ }
}

(as shown by the fuzzer comments in this PR, which discovered places where it was missing)

}
}

Expand Down Expand Up @@ -97,7 +101,8 @@ internal static MemberTypeInfo Decode(BinaryReader reader, int count, PayloadOpt
BinaryType.PrimitiveArray => (PrimitiveArray, default),
BinaryType.Class => (NonSystemClass, default),
BinaryType.SystemClass => (SystemClass, default),
_ => (ObjectArray, default)
BinaryType.ObjectArray => (ObjectArray, default),
_ => throw new InvalidOperationException()
};
}

Expand Down Expand Up @@ -152,7 +157,7 @@ internal TypeName GetArrayTypeName(ArrayInfo arrayInfo)
BinaryType.ObjectArray => TypeNameHelpers.GetPrimitiveSZArrayTypeName(TypeNameHelpers.ObjectPrimitiveType),
BinaryType.SystemClass => (TypeName)additionalInfo!,
BinaryType.Class => ((ClassTypeInfo)additionalInfo!).TypeName,
_ => throw new ArgumentOutOfRangeException(paramName: nameof(binaryType), actualValue: binaryType, message: null)
_ => throw new InvalidOperationException()
};

// In general, arrayRank == 1 may have two different meanings:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,5 @@ private MessageEndRecord()

public override SerializationRecordId Id => SerializationRecordId.NoId;

public override TypeName TypeName
{
get
{
Debug.Fail("TypeName should never be called on MessageEndRecord");
return TypeName.Parse(nameof(MessageEndRecord).AsSpan());
}
}
public override TypeName TypeName => TypeName.Parse(nameof(MessageEndRecord).AsSpan());
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,28 +69,22 @@ public static bool StartsWithPayloadHeader(Stream stream)
return false;
}

try
byte[] buffer = new byte[SerializedStreamHeaderRecord.Size];
int offset = 0;
while (offset < buffer.Length)
{
#if NET
Span<byte> buffer = stackalloc byte[SerializedStreamHeaderRecord.Size];
stream.ReadExactly(buffer);
#else
byte[] buffer = new byte[SerializedStreamHeaderRecord.Size];
int offset = 0;
while (offset < buffer.Length)
int read = stream.Read(buffer, offset, buffer.Length - offset);
if (read == 0)
{
int read = stream.Read(buffer, offset, buffer.Length - offset);
if (read == 0)
throw new EndOfStreamException();
offset += read;
stream.Position = beginning;
return false;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The goal of this change is to simply return false rather than throw EndOfStreamException.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to set the stream position back to the beginning before we return here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to set the stream position back to the beginning before we return here?

Great catch!

}
#endif
return StartsWithPayloadHeader(buffer);
}
finally
{
stream.Position = beginning;
offset += read;
}

bool result = StartsWithPayloadHeader(buffer);
stream.Position = beginning;
return result;
}

/// <summary>
Expand Down Expand Up @@ -241,7 +235,8 @@ private static SerializationRecord DecodeNext(BinaryReader reader, RecordMap rec
SerializationRecordType.ObjectNullMultiple => ObjectNullMultipleRecord.Decode(reader),
SerializationRecordType.ObjectNullMultiple256 => ObjectNullMultiple256Record.Decode(reader),
SerializationRecordType.SerializedStreamHeader => SerializedStreamHeaderRecord.Decode(reader),
_ => SystemClassWithMembersAndTypesRecord.Decode(reader, recordMap, options),
SerializationRecordType.SystemClassWithMembersAndTypes => SystemClassWithMembersAndTypesRecord.Decode(reader, recordMap, options),
_ => throw new InvalidOperationException()
};

recordMap.Add(record);
Expand Down Expand Up @@ -269,8 +264,10 @@ private static SerializationRecord DecodeMemberPrimitiveTypedRecord(BinaryReader
PrimitiveType.Double => new MemberPrimitiveTypedRecord<double>(reader.ReadDouble()),
PrimitiveType.Decimal => new MemberPrimitiveTypedRecord<decimal>(reader.ParseDecimal()),
PrimitiveType.DateTime => new MemberPrimitiveTypedRecord<DateTime>(Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadUInt64())),
// String is handled with a record, never on it's own
_ => new MemberPrimitiveTypedRecord<TimeSpan>(new TimeSpan(reader.ReadInt64())),
PrimitiveType.TimeSpan => new MemberPrimitiveTypedRecord<TimeSpan>(new TimeSpan(reader.ReadInt64())),
// PrimitiveType.String is handled with a record, never on it's own
PrimitiveType.String => throw new SerializationException(SR.Format(SR.Serialization_InvalidValue, primitiveType)),
_ => throw new InvalidOperationException()
};
}

Expand All @@ -295,7 +292,10 @@ private static SerializationRecord DecodeArraySinglePrimitiveRecord(BinaryReader
PrimitiveType.Double => Decode<double>(info, reader),
PrimitiveType.Decimal => Decode<decimal>(info, reader),
PrimitiveType.DateTime => Decode<DateTime>(info, reader),
_ => Decode<TimeSpan>(info, reader),
PrimitiveType.TimeSpan => Decode<TimeSpan>(info, reader),
// PrimitiveType.String is handled with a record, never on it's own
PrimitiveType.String => throw new SerializationException(SR.Format(SR.Serialization_InvalidValue, primitiveType)),
_ => throw new InvalidOperationException()
};

static SerializationRecord Decode<T>(ArrayInfo info, BinaryReader reader) where T : unmanaged
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,5 @@ internal abstract class NullsRecord : SerializationRecord

public override SerializationRecordId Id => SerializationRecordId.NoId;

public override TypeName TypeName
{
get
{
Debug.Fail($"TypeName should never be called on {GetType().Name}");
return TypeName.Parse(GetType().Name.AsSpan());
}
}
public override TypeName TypeName => TypeName.Parse(GetType().Name.AsSpan());
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ internal void Add(SerializationRecord record)
return;
}
#endif
throw new SerializationException(SR.Format(SR.Serialization_DuplicateSerializationRecordId, record.Id));
throw new SerializationException(SR.Format(SR.Serialization_DuplicateSerializationRecordId, record.Id._id));
JeremyKuhne marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace System.Formats.Nrbf;
internal sealed class RectangularArrayRecord : ArrayRecord
{
private readonly int[] _lengths;
private readonly ICollection<object> _values;
private readonly List<object> _values;
private TypeName? _typeName;

private RectangularArrayRecord(Type elementType, ArrayInfo arrayInfo,
Expand All @@ -24,18 +24,8 @@ private RectangularArrayRecord(Type elementType, ArrayInfo arrayInfo,
MemberTypeInfo = memberTypeInfo;
_lengths = lengths;

// A List<T> can hold as many objects as an array, so for multi-dimensional arrays
// with more elements than Array.MaxLength we use LinkedList.
// Testing that many elements takes a LOT of time, so to ensure that both code paths are tested,
// we always use LinkedList code path for Debug builds.
#if DEBUG
_values = new LinkedList<object>();
#else
_values = arrayInfo.TotalElementsCount <= ArrayInfo.MaxArrayLength
? new List<object>(canPreAllocate ? arrayInfo.GetSZArrayLength() : Math.Min(4, arrayInfo.GetSZArrayLength()))
: new LinkedList<object>();
#endif

// ArrayInfo.GetSZArrayLength ensures to return a value <= Array.MaxLength
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we no longer allow any multi-dimensional array with more than Array.MaxLength elements, we don't need to use LinkedList anymore.

_values = new List<object>(canPreAllocate ? arrayInfo.GetSZArrayLength() : Math.Min(4, arrayInfo.GetSZArrayLength()));
}

public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray;
Expand Down Expand Up @@ -108,6 +98,7 @@ private protected override Array Deserialize(Type arrayType, bool allowNulls)
else if (ElementType == typeof(TimeSpan)) CopyTo<TimeSpan>(_values, result);
else if (ElementType == typeof(DateTime)) CopyTo<DateTime>(_values, result);
else if (ElementType == typeof(decimal)) CopyTo<decimal>(_values, result);
else throw new InvalidOperationException();
}
else
{
Expand All @@ -116,7 +107,7 @@ private protected override Array Deserialize(Type arrayType, bool allowNulls)

return result;

static void CopyTo<T>(ICollection<object> list, Array array)
static void CopyTo<T>(List<object> list, Array array)
{
ref byte arrayDataRef = ref MemoryMarshal.GetArrayDataReference(array);
ref T firstElementRef = ref Unsafe.As<byte, T>(ref arrayDataRef);
Expand Down Expand Up @@ -176,7 +167,11 @@ internal static RectangularArrayRecord Create(BinaryReader reader, ArrayInfo arr
PrimitiveType.Int64 => sizeof(long),
PrimitiveType.UInt64 => sizeof(ulong),
PrimitiveType.Double => sizeof(double),
_ => -1
PrimitiveType.TimeSpan => sizeof(ulong),
PrimitiveType.DateTime => sizeof(ulong),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an improvement that I came across while making all the switch statements throw in the default case: we were not handling TimeSpan and DateTime properly before.

PrimitiveType.Decimal => -1, // represented as variable-length string
PrimitiveType.String => -1, // can't predict required size based on the number of strings
_ => throw new InvalidOperationException()
};

if (sizeOfSingleValue > 0)
Expand Down Expand Up @@ -215,7 +210,8 @@ private static Type MapPrimitive(PrimitiveType primitiveType)
PrimitiveType.DateTime => typeof(DateTime),
PrimitiveType.UInt16 => typeof(ushort),
PrimitiveType.UInt32 => typeof(uint),
_ => typeof(ulong)
PrimitiveType.UInt64 => typeof(ulong),
_ => throw new InvalidOperationException()
};

private static Type MapPrimitiveArray(PrimitiveType primitiveType)
Expand All @@ -235,7 +231,8 @@ private static Type MapPrimitiveArray(PrimitiveType primitiveType)
PrimitiveType.DateTime => typeof(DateTime[]),
PrimitiveType.UInt16 => typeof(ushort[]),
PrimitiveType.UInt32 => typeof(uint[]),
_ => typeof(ulong[]),
PrimitiveType.UInt64 => typeof(ulong[]),
_ => throw new InvalidOperationException()
};

private static object? GetActualValue(object value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ private static bool Matches(Type type, TypeName typeName)
internal virtual object? GetValue() => this;

internal virtual void HandleNextRecord(SerializationRecord nextRecord, NextInfo info)
=> Debug.Fail($"HandleNextRecord should not have been called for '{GetType().Name}'");
=> throw new InvalidOperationException();

internal virtual void HandleNextValue(object value, NextInfo info)
=> Debug.Fail($"HandleNextValue should not have been called for '{GetType().Name}'");
=> throw new InvalidOperationException();
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Formats.Nrbf.Utils;
using System.IO;
using System.Linq;
Expand All @@ -15,6 +16,7 @@ namespace System.Formats.Nrbf;
/// <summary>
/// The ID of <see cref="SerializationRecord" />.
/// </summary>
[DebuggerDisplay("{_id}")]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's just a user experience improvement.

public readonly struct SerializationRecordId : IEquatable<SerializationRecordId>
{
#pragma warning disable CS0649 // the default value is used on purpose
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,8 @@ internal sealed class SerializedStreamHeaderRecord : SerializationRecord

public override SerializationRecordType RecordType => SerializationRecordType.SerializedStreamHeader;

public override TypeName TypeName
{
get
{
Debug.Fail("TypeName should never be called on SerializedStreamHeaderRecord");
return TypeName.Parse(nameof(SerializedStreamHeaderRecord).AsSpan());
}
}
public override TypeName TypeName => TypeName.Parse(nameof(SerializedStreamHeaderRecord).AsSpan());

public override SerializationRecordId Id => SerializationRecordId.NoId;

internal SerializationRecordId RootId { get; }
Expand Down
Loading
Loading