Skip to content

Commit

Permalink
[NRBF] Fix bugs discovered by the fuzzer (dotnet#107368)
Browse files Browse the repository at this point in the history
* bug #1: don't allow for values out of the SerializationRecordType enum range

* bug #2: throw SerializationException rather than KeyNotFoundException when the referenced record is missing or it points to a record of different type

* bug #3: throw SerializationException rather than FormatException when it's being thrown by BinaryReader (or sth else that we use)

* bug #4: document the fact that IOException can be thrown

* bug #5: throw SerializationException rather than OverflowException when parsing the decimal fails

* bug #6: 0 and 17 are illegal values for PrimitiveType enum

* bug #7: throw SerializationException when a surrogate character is read (so far an ArgumentException was thrown)
  • Loading branch information
adamsitnik committed Sep 6, 2024
1 parent dc5dbab commit e79426e
Show file tree
Hide file tree
Showing 13 changed files with 197 additions and 25 deletions.
8 changes: 7 additions & 1 deletion src/libraries/System.Formats.Nrbf/src/Resources/Strings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@
<value>{0} Record Type is not supported by design.</value>
</data>
<data name="Serialization_InvalidReference" xml:space="preserve">
<value>Member reference was pointing to a record of unexpected type.</value>
<value>Invalid member reference.</value>
</data>
<data name="Serialization_InvalidTypeName" xml:space="preserve">
<value>Invalid type name: `{0}`.</value>
Expand Down Expand Up @@ -162,4 +162,10 @@
<data name="Serialization_InvalidAssemblyName" xml:space="preserve">
<value>Invalid assembly name: `{0}`.</value>
</data>
<data name="Serialization_InvalidFormat" xml:space="preserve">
<value>Invalid format.</value>
</data>
<data name="Serialization_SurrogateCharacter" xml:space="preserve">
<value>A surrogate character was read.</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,17 @@ private static List<decimal> DecodeDecimals(BinaryReader reader, int count)

reader.BaseStream.ReadExactly(buffer.Slice(0, stringLength));

values.Add(decimal.Parse(buffer.Slice(0, stringLength), CultureInfo.InvariantCulture));
if (!decimal.TryParse(buffer.Slice(0, stringLength), NumberStyles.Number, CultureInfo.InvariantCulture, out decimal value))
{
ThrowHelper.ThrowInvalidFormat();
}

values.Add(value);
}
#else
for (int i = 0; i < count; i++)
{
values.Add(decimal.Parse(reader.ReadString(), CultureInfo.InvariantCulture));
values.Add(reader.ParseDecimal());
}
#endif
return values;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ internal static ClassTypeInfo Decode(BinaryReader reader, PayloadOptions options
string rawName = reader.ReadString();
SerializationRecordId libraryId = SerializationRecordId.Decode(reader);

BinaryLibraryRecord library = (BinaryLibraryRecord)recordMap[libraryId];
BinaryLibraryRecord library = recordMap.GetRecord<BinaryLibraryRecord>(libraryId);

return new ClassTypeInfo(rawName.ParseNonSystemClassRecordTypeName(library, options));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,7 @@ internal static ClassWithIdRecord Decode(
SerializationRecordId id = SerializationRecordId.Decode(reader);
SerializationRecordId metadataId = SerializationRecordId.Decode(reader);

if (recordMap[metadataId] is not ClassRecord referencedRecord)
{
throw new SerializationException(SR.Serialization_InvalidReference);
}
ClassRecord referencedRecord = recordMap.GetRecord<ClassRecord>(metadataId);

return new ClassWithIdRecord(id, referencedRecord);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ internal static ClassWithMembersAndTypesRecord Decode(BinaryReader reader, Recor
MemberTypeInfo memberTypeInfo = MemberTypeInfo.Decode(reader, classInfo.MemberNames.Count, options, recordMap);
SerializationRecordId libraryId = SerializationRecordId.Decode(reader);

BinaryLibraryRecord library = (BinaryLibraryRecord)recordMap[libraryId];
BinaryLibraryRecord library = recordMap.GetRecord<BinaryLibraryRecord>(libraryId);
classInfo.LoadTypeName(library, options);

return new ClassWithMembersAndTypesRecord(classInfo, memberTypeInfo);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,5 @@ private MemberReferenceRecord(SerializationRecordId reference, RecordMap recordM
internal static MemberReferenceRecord Decode(BinaryReader reader, RecordMap recordMap)
=> new(SerializationRecordId.Decode(reader), recordMap);

internal SerializationRecord GetReferencedRecord() => RecordMap[Reference];
internal SerializationRecord GetReferencedRecord() => RecordMap.GetRecord(Reference);
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public static bool StartsWithPayloadHeader(ReadOnlySpan<byte> bytes)
/// <exception cref="ArgumentNullException"><paramref name="stream" /> is <see langword="null" />.</exception>
/// <exception cref="NotSupportedException">The stream does not support reading or seeking.</exception>
/// <exception cref="ObjectDisposedException">The stream was closed.</exception>
/// <exception cref="IOException">An I/O error occurred.</exception>
/// <remarks>When this method returns, <paramref name="stream" /> is restored to its original position.</remarks>
public static bool StartsWithPayloadHeader(Stream stream)
{
Expand Down Expand Up @@ -107,6 +108,7 @@ public static bool StartsWithPayloadHeader(Stream stream)
/// <exception cref="ArgumentNullException"><paramref name="payload"/> is <see langword="null" />.</exception>
/// <exception cref="ArgumentException"><paramref name="payload"/> does not support reading or is already closed.</exception>
/// <exception cref="SerializationException">Reading from <paramref name="payload"/> encountered invalid NRBF data.</exception>
/// <exception cref="IOException">An I/O error occurred.</exception>
/// <exception cref="NotSupportedException">
/// Reading from <paramref name="payload"/> encountered unsupported records,
/// for example, arrays with non-zero offset or unsupported record types
Expand Down Expand Up @@ -142,7 +144,14 @@ public static SerializationRecord Decode(Stream payload, out IReadOnlyDictionary
#endif

using BinaryReader reader = new(payload, ThrowOnInvalidUtf8Encoding, leaveOpen: leaveOpen);
return Decode(reader, options ?? new(), out recordMap);
try
{
return Decode(reader, options ?? new(), out recordMap);
}
catch (FormatException) // can be thrown by various BinaryReader methods
{
throw new SerializationException(SR.Serialization_InvalidFormat);
}
}

/// <summary>
Expand Down Expand Up @@ -213,12 +222,7 @@ private static SerializationRecord Decode(BinaryReader reader, PayloadOptions op
private static SerializationRecord DecodeNext(BinaryReader reader, RecordMap recordMap,
AllowedRecordTypes allowed, PayloadOptions options, out SerializationRecordType recordType)
{
byte nextByte = reader.ReadByte();
if (((uint)allowed & (1u << nextByte)) == 0)
{
ThrowHelper.ThrowForUnexpectedRecordType(nextByte);
}
recordType = (SerializationRecordType)nextByte;
recordType = reader.ReadSerializationRecordType(allowed);

SerializationRecord record = recordType switch
{
Expand Down Expand Up @@ -254,7 +258,7 @@ private static SerializationRecord DecodeMemberPrimitiveTypedRecord(BinaryReader
PrimitiveType.Boolean => new MemberPrimitiveTypedRecord<bool>(reader.ReadBoolean()),
PrimitiveType.Byte => new MemberPrimitiveTypedRecord<byte>(reader.ReadByte()),
PrimitiveType.SByte => new MemberPrimitiveTypedRecord<sbyte>(reader.ReadSByte()),
PrimitiveType.Char => new MemberPrimitiveTypedRecord<char>(reader.ReadChar()),
PrimitiveType.Char => new MemberPrimitiveTypedRecord<char>(reader.ParseChar()),
PrimitiveType.Int16 => new MemberPrimitiveTypedRecord<short>(reader.ReadInt16()),
PrimitiveType.UInt16 => new MemberPrimitiveTypedRecord<ushort>(reader.ReadUInt16()),
PrimitiveType.Int32 => new MemberPrimitiveTypedRecord<int>(reader.ReadInt32()),
Expand All @@ -263,7 +267,7 @@ private static SerializationRecord DecodeMemberPrimitiveTypedRecord(BinaryReader
PrimitiveType.UInt64 => new MemberPrimitiveTypedRecord<ulong>(reader.ReadUInt64()),
PrimitiveType.Single => new MemberPrimitiveTypedRecord<float>(reader.ReadSingle()),
PrimitiveType.Double => new MemberPrimitiveTypedRecord<double>(reader.ReadDouble()),
PrimitiveType.Decimal => new MemberPrimitiveTypedRecord<decimal>(decimal.Parse(reader.ReadString(), CultureInfo.InvariantCulture)),
PrimitiveType.Decimal => new MemberPrimitiveTypedRecord<decimal>(reader.ParseDecimal()),
PrimitiveType.DateTime => new MemberPrimitiveTypedRecord<DateTime>(Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadUInt64())),
// String is handled with a record, never on it's own
_ => new MemberPrimitiveTypedRecord<TimeSpan>(new TimeSpan(reader.ReadInt64())),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ internal void Add(SerializationRecord record)

internal SerializationRecord GetRootRecord(SerializedStreamHeaderRecord header)
{
SerializationRecord rootRecord = _map[header.RootId];
SerializationRecord rootRecord = GetRecord(header.RootId);

if (rootRecord is SystemClassWithMembersAndTypesRecord systemClass)
{
// update the record map, so it's visible also to those who access it via Id
Expand All @@ -72,4 +73,14 @@ internal SerializationRecord GetRootRecord(SerializedStreamHeaderRecord header)

return rootRecord;
}

internal SerializationRecord GetRecord(SerializationRecordId recordId)
=> _map.TryGetValue(recordId, out SerializationRecord? record)
? record
: throw new SerializationException(SR.Serialization_InvalidReference);

internal T GetRecord<T>(SerializationRecordId recordId) where T : SerializationRecord
=> _map.TryGetValue(recordId, out SerializationRecord? record) && record is T casted
? casted
: throw new SerializationException(SR.Serialization_InvalidReference);
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,19 @@ internal static class BinaryReaderExtensions
{
private static object? s_baseAmbiguousDstDateTime;

internal static SerializationRecordType ReadSerializationRecordType(this BinaryReader reader, AllowedRecordTypes allowed)
{
byte nextByte = reader.ReadByte();
if (nextByte > (byte)SerializationRecordType.MethodReturn // MethodReturn is the last defined value.
|| (nextByte > (byte)SerializationRecordType.ArraySingleString && nextByte < (byte)SerializationRecordType.MethodCall) // not part of the spec
|| ((uint)allowed & (1u << nextByte)) == 0) // valid, but not allowed
{
ThrowHelper.ThrowForUnexpectedRecordType(nextByte);
}

return (SerializationRecordType)nextByte;
}

internal static BinaryArrayType ReadArrayType(this BinaryReader reader)
{
byte arrayType = reader.ReadByte();
Expand Down Expand Up @@ -48,7 +61,7 @@ internal static PrimitiveType ReadPrimitiveType(this BinaryReader reader)
{
byte primitiveType = reader.ReadByte();
// String is the last defined value, 4 is not used at all.
if (primitiveType is 4 or > (byte)PrimitiveType.String)
if (primitiveType is 0 or 4 or (byte)PrimitiveType.Null or > (byte)PrimitiveType.String)
{
ThrowHelper.ThrowInvalidValue(primitiveType);
}
Expand All @@ -64,7 +77,7 @@ internal static object ReadPrimitiveValue(this BinaryReader reader, PrimitiveTyp
PrimitiveType.Boolean => reader.ReadBoolean(),
PrimitiveType.Byte => reader.ReadByte(),
PrimitiveType.SByte => reader.ReadSByte(),
PrimitiveType.Char => reader.ReadChar(),
PrimitiveType.Char => reader.ParseChar(),
PrimitiveType.Int16 => reader.ReadInt16(),
PrimitiveType.UInt16 => reader.ReadUInt16(),
PrimitiveType.Int32 => reader.ReadInt32(),
Expand All @@ -73,11 +86,35 @@ internal static object ReadPrimitiveValue(this BinaryReader reader, PrimitiveTyp
PrimitiveType.UInt64 => reader.ReadUInt64(),
PrimitiveType.Single => reader.ReadSingle(),
PrimitiveType.Double => reader.ReadDouble(),
PrimitiveType.Decimal => decimal.Parse(reader.ReadString(), CultureInfo.InvariantCulture),
PrimitiveType.Decimal => reader.ParseDecimal(),
PrimitiveType.DateTime => CreateDateTimeFromData(reader.ReadUInt64()),
_ => new TimeSpan(reader.ReadInt64()),
};

// BinaryFormatter serializes decimals as strings and we can't BinaryReader.ReadDecimal.
internal static decimal ParseDecimal(this BinaryReader reader)
{
string text = reader.ReadString();
if (!decimal.TryParse(text, NumberStyles.Number, CultureInfo.InvariantCulture, out decimal result))
{
ThrowHelper.ThrowInvalidFormat();
}

return result;
}

internal static char ParseChar(this BinaryReader reader)
{
try
{
return reader.ReadChar();
}
catch (ArgumentException) // A surrogate character was read.
{
throw new SerializationException(SR.Serialization_SurrogateCharacter);
}
}

/// <summary>
/// Creates a <see cref="DateTime"/> object from raw data with validation.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ internal static void ThrowArrayContainedNulls()
internal static void ThrowInvalidAssemblyName(string rawName)
=> throw new SerializationException(SR.Format(SR.Serialization_InvalidAssemblyName, rawName));

internal static void ThrowInvalidFormat()
=> throw new SerializationException(SR.Serialization_InvalidFormat);

internal static void ThrowEndOfStreamException()
=> throw new EndOfStreamException();

Expand Down
107 changes: 107 additions & 0 deletions src/libraries/System.Formats.Nrbf/tests/InvalidInputTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,9 @@ public static IEnumerable<object[]> ThrowsForInvalidPrimitiveType_Arguments()
{
foreach (byte binaryType in new byte[] { (byte)0 /* BinaryType.Primitive */, (byte)7 /* BinaryType.PrimitiveArray */ })
{
yield return new object[] { recordType, binaryType, (byte)0 }; // value not used by the spec
yield return new object[] { recordType, binaryType, (byte)4 }; // value not used by the spec
yield return new object[] { recordType, binaryType, (byte)17 }; // used by the spec, but illegal in given context
yield return new object[] { recordType, binaryType, (byte)19 };
}
}
Expand Down Expand Up @@ -478,4 +480,109 @@ public void ThrowsOnInvalidArrayType()
stream.Position = 0;
Assert.Throws<SerializationException>(() => NrbfDecoder.Decode(stream));
}

[Theory]
[InlineData(18, typeof(NotSupportedException))] // not part of the spec, but still less than max allowed value (22)
[InlineData(19, typeof(NotSupportedException))] // same as above
[InlineData(20, typeof(NotSupportedException))] // same as above
[InlineData(23, typeof(SerializationException))] // not part of the spec and more than max allowed value (22)
[InlineData(64, typeof(SerializationException))] // same as above but also matches AllowedRecordTypes.SerializedStreamHeader
public void InvalidSerializationRecordType(byte recordType, Type expectedException)
{
using MemoryStream stream = new();
BinaryWriter writer = new(stream, Encoding.UTF8);

WriteSerializedStreamHeader(writer);
writer.Write(recordType); // SerializationRecordType
writer.Write((byte)SerializationRecordType.MessageEnd);

stream.Position = 0;

Assert.Throws(expectedException, () => NrbfDecoder.Decode(stream));
}

[Fact]
public void MissingRootRecord()
{
const int RootRecordId = 1;
using MemoryStream stream = new();
BinaryWriter writer = new(stream, Encoding.UTF8);

WriteSerializedStreamHeader(writer, rootId: RootRecordId);
writer.Write((byte)SerializationRecordType.BinaryObjectString);
writer.Write(RootRecordId + 1); // a different ID
writer.Write("theString");
writer.Write((byte)SerializationRecordType.MessageEnd);

stream.Position = 0;

Assert.Throws<SerializationException>(() => NrbfDecoder.Decode(stream));
}

[Fact]
public void Invalid7BitEncodedStringLength()
{
// The highest bit of the last byte is set (so it's invalid).
byte[] invalidLength = [byte.MaxValue, byte.MaxValue, byte.MaxValue, byte.MaxValue, byte.MaxValue];

using MemoryStream stream = new();
BinaryWriter writer = new(stream, Encoding.UTF8);

WriteSerializedStreamHeader(writer);
writer.Write((byte)SerializationRecordType.BinaryObjectString);
writer.Write(1); // root record Id
writer.Write(invalidLength); // the length prefix
writer.Write(Encoding.UTF8.GetBytes("theString"));
writer.Write((byte)SerializationRecordType.MessageEnd);

stream.Position = 0;

Assert.Throws<SerializationException>(() => NrbfDecoder.Decode(stream));
}

[Theory]
[InlineData("79228162514264337593543950336")] // invalid format (decimal.MaxValue + 1)
[InlineData("1111111111111111111111111111111111111111111111111")] // overflow
public void InvalidDecimal(string textRepresentation)
{
using MemoryStream stream = new();
BinaryWriter writer = new(stream, Encoding.UTF8);

WriteSerializedStreamHeader(writer);
writer.Write((byte)SerializationRecordType.SystemClassWithMembersAndTypes);
writer.Write(1); // root record Id
writer.Write("ClassWithDecimalField"); // type name
writer.Write(1); // member count
writer.Write("memberName");
writer.Write((byte)BinaryType.Primitive);
writer.Write((byte)PrimitiveType.Decimal);
writer.Write(textRepresentation);
writer.Write((byte)SerializationRecordType.MessageEnd);

stream.Position = 0;

Assert.Throws<SerializationException>(() => NrbfDecoder.Decode(stream));
}

[Fact]
public void SurrogateCharacter()
{
using MemoryStream stream = new();
BinaryWriter writer = new(stream, Encoding.UTF8);

WriteSerializedStreamHeader(writer);
writer.Write((byte)SerializationRecordType.SystemClassWithMembersAndTypes);
writer.Write(1); // root record Id
writer.Write("ClassWithCharField"); // type name
writer.Write(1); // member count
writer.Write("memberName");
writer.Write((byte)BinaryType.Primitive);
writer.Write((byte)PrimitiveType.Char);
writer.Write((byte)0xC0); // a surrogate character
writer.Write((byte)SerializationRecordType.MessageEnd);

stream.Position = 0;

Assert.Throws<SerializationException>(() => NrbfDecoder.Decode(stream));
}
}
4 changes: 2 additions & 2 deletions src/libraries/System.Formats.Nrbf/tests/ReadTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ protected static BinaryFormatter CreateBinaryFormatter()
};
#pragma warning restore SYSLIB0011 // Type or member is obsolete

protected static void WriteSerializedStreamHeader(BinaryWriter writer, int major = 1, int minor = 0)
protected static void WriteSerializedStreamHeader(BinaryWriter writer, int major = 1, int minor = 0, int rootId = 1)
{
writer.Write((byte)SerializationRecordType.SerializedStreamHeader);
writer.Write(1); // root ID
writer.Write(rootId); // root ID
writer.Write(1); // header ID
writer.Write(major); // major version
writer.Write(minor); // minor version
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

<ItemGroup>
<Compile Include="..\src\System\Formats\Nrbf\BinaryArrayType.cs" Link="BinaryArrayType.cs" />
<Compile Include="..\src\System\Formats\Nrbf\BinaryType.cs" Link="BinaryType.cs" />
<Compile Include="..\src\System\Formats\Nrbf\PrimitiveType.cs" Link="PrimitiveType.cs" />
</ItemGroup>

<ItemGroup>
Expand Down

0 comments on commit e79426e

Please sign in to comment.