-
Notifications
You must be signed in to change notification settings - Fork 4.7k
/
NrbfDecoderFuzzer.cs
126 lines (115 loc) · 6.45 KB
/
NrbfDecoderFuzzer.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Buffers;
using System.Formats.Nrbf;
using System.Runtime.Serialization;
using System.Text;
namespace DotnetFuzzing.Fuzzers
{
    /// <summary>
    /// Fuzz target for <see cref="NrbfDecoder"/>: every input is decoded through both a seekable and a
    /// non-seekable stream, and the decoded root record is sanity-checked according to its record type.
    /// </summary>
    internal sealed class NrbfDecoderFuzzer : IFuzzer
    {
        public string[] TargetAssemblies { get; } = ["System.Formats.Nrbf"];

        public string[] TargetCoreLibPrefixes => [];

        public string Dictionary => "nrbfdecoder.dict";

        /// <summary>
        /// Entry point called by the fuzzing harness with one input.
        /// </summary>
        /// <param name="bytes">The raw fuzzer-generated input.</param>
        public void FuzzTarget(ReadOnlySpan<byte> bytes)
        {
            // Run once with the poison page before the input and once after it, so that
            // reads past either end of the span can be detected (assumes PooledBoundedMemory
            // faults on access to the poison page — see its implementation).
            Test(bytes, PoisonPagePlacement.Before);
            Test(bytes, PoisonPagePlacement.After);
        }

        /// <summary>
        /// Places <paramref name="bytes"/> in bounded memory with the given poison-page placement and
        /// decodes it through a seekable and a non-seekable stream.
        /// </summary>
        /// <param name="bytes">The raw fuzzer-generated input.</param>
        /// <param name="poisonPagePlacement">Which side of the buffer the poison page is placed on.</param>
        private static void Test(ReadOnlySpan<byte> bytes, PoisonPagePlacement poisonPagePlacement)
        {
            using PooledBoundedMemory<byte> inputPoisoned = PooledBoundedMemory<byte>.Rent(bytes, poisonPagePlacement);

            // The bounded span is what StartsWithPayloadHeader inspects; each stream reads its own heap copy.
            using MemoryStream seekableStream = new(inputPoisoned.Memory.ToArray());
            Test(inputPoisoned.Span, seekableStream);

            // NrbfDecoder has a few code paths dedicated to non-seekable streams, so test them as well.
            using NonSeekableStream nonSeekableStream = new(inputPoisoned.Memory.ToArray());
            Test(inputPoisoned.Span, nonSeekableStream);
        }

        /// <summary>
        /// Decodes <paramref name="stream"/> and checks that the outcome agrees with
        /// <see cref="NrbfDecoder.StartsWithPayloadHeader(ReadOnlySpan{byte})"/> on <paramref name="testSpan"/>.
        /// If the header check passes, decoding may fail only with the expected exception types; on success,
        /// the root record is validated per its record type. If the header check fails, decoding must not succeed.
        /// </summary>
        /// <param name="testSpan">The same bytes as <paramref name="stream"/>, used for the header check.</param>
        /// <param name="stream">The stream to decode.</param>
        /// <exception cref="Exception">Thrown to signal a fuzzing failure (unexpected record type or unexpected success).</exception>
        private static void Test(ReadOnlySpan<byte> testSpan, Stream stream)
        {
            if (NrbfDecoder.StartsWithPayloadHeader(testSpan))
            {
                try
                {
                    // The record map is not inspected; only the root record is validated below.
                    SerializationRecord record = NrbfDecoder.Decode(stream, out _);
                    switch (record.RecordType)
                    {
                        case SerializationRecordType.ArraySingleObject:
                            SZArrayRecord<object?> arrayObj = (SZArrayRecord<object?>)record;
                            object?[] objArray = arrayObj.GetArray();
                            Assert.Equal(arrayObj.Length, objArray.Length);
                            Assert.Equal(1, arrayObj.Rank);
                            break;
                        case SerializationRecordType.ArraySingleString:
                            SZArrayRecord<string?> arrayString = (SZArrayRecord<string?>)record;
                            string?[] array = arrayString.GetArray();
                            Assert.Equal(arrayString.Length, array.Length);
                            Assert.Equal(1, arrayString.Rank);
                            Assert.Equal(true, arrayString.TypeNameMatches(typeof(string[])));
                            break;
                        case SerializationRecordType.ArraySinglePrimitive:
                        case SerializationRecordType.BinaryArray:
                            ArrayRecord arrayBinary = (ArrayRecord)record;
                            Assert.NotNull(arrayBinary.TypeName);
                            break;
                        case SerializationRecordType.BinaryObjectString:
                            _ = ((PrimitiveTypeRecord<string>)record).Value;
                            break;
                        case SerializationRecordType.ClassWithId:
                        case SerializationRecordType.ClassWithMembersAndTypes:
                        case SerializationRecordType.SystemClassWithMembersAndTypes:
                            ClassRecord classRecord = (ClassRecord)record;
                            Assert.NotNull(classRecord.TypeName);
                            // Every advertised member name must be reported as present.
                            foreach (string name in classRecord.MemberNames)
                            {
                                Assert.Equal(true, classRecord.HasMember(name));
                            }
                            break;
                        case SerializationRecordType.MemberPrimitiveTyped:
                            PrimitiveTypeRecord primitiveType = (PrimitiveTypeRecord)record;
                            Assert.NotNull(primitiveType.Value);
                            break;
                        case SerializationRecordType.MemberReference:
                            Assert.NotNull(record.TypeName);
                            break;
                        case SerializationRecordType.BinaryLibrary:
                            Assert.Equal(false, record.Id.Equals(default));
                            break;
                        case SerializationRecordType.ObjectNull:
                        case SerializationRecordType.ObjectNullMultiple:
                        case SerializationRecordType.ObjectNullMultiple256:
                            Assert.Equal(default, record.Id);
                            break;
                        case SerializationRecordType.MessageEnd:
                        case SerializationRecordType.SerializedStreamHeader:
                        // ClassWithMembers and SystemClassWithMembers are not listed: decoding them causes NotSupportedException.
                        default:
                            // Include the actual type so fuzzer crash reports are self-explanatory.
                            throw new Exception($"Unexpected RecordType: {record.RecordType}");
                    }
                }
                catch (SerializationException) { /* Reading from the stream encountered invalid NRBF data. */ }
                catch (NotSupportedException) { /* Reading from the stream encountered unsupported records. */ }
                catch (DecoderFallbackException) { /* Reading from the stream encountered an invalid UTF8 sequence. */ }
                catch (EndOfStreamException) { /* The end of the stream was reached before reading SerializationRecordType.MessageEnd record. */ }
                catch (IOException) { /* An I/O error occurred. */ }
            }
            else
            {
                try
                {
                    NrbfDecoder.Decode(stream);
                    throw new Exception("Decoding supposed to fail!");
                }
                catch (SerializationException) { /* Everything has to start with a header. */ }
                catch (NotSupportedException) { /* Reading from the stream encountered unsupported records. */ }
                catch (EndOfStreamException) { /* The end of the stream was reached before reading SerializationRecordType.MessageEnd record. */ }
            }
        }

        /// <summary>
        /// A <see cref="MemoryStream"/> that reports <see cref="Stream.CanSeek"/> as <see langword="false"/>.
        /// Only <c>CanSeek</c> is overridden — <c>Length</c>, <c>Position</c> and <c>Seek</c> still work —
        /// so this exercises the decoder's CanSeek-gated paths but would not detect unconditional seeking.
        /// </summary>
        private sealed class NonSeekableStream : MemoryStream
        {
            public NonSeekableStream(byte[] buffer) : base(buffer) { }

            public override bool CanSeek => false;
        }
    }
}