diff --git a/src/libraries/Fuzzing/DotnetFuzzing/DotnetFuzzing.csproj b/src/libraries/Fuzzing/DotnetFuzzing/DotnetFuzzing.csproj index 86c7c058a8fff..82f54f412ace6 100644 --- a/src/libraries/Fuzzing/DotnetFuzzing/DotnetFuzzing.csproj +++ b/src/libraries/Fuzzing/DotnetFuzzing/DotnetFuzzing.csproj @@ -13,7 +13,6 @@ - @@ -31,4 +30,8 @@ + + + + diff --git a/src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/NrbfDecoderFuzzer.cs b/src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/NrbfDecoderFuzzer.cs index 15d7d147bc239..46da18e4b0fc8 100644 --- a/src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/NrbfDecoderFuzzer.cs +++ b/src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/NrbfDecoderFuzzer.cs @@ -18,10 +18,18 @@ internal sealed class NrbfDecoderFuzzer : IFuzzer public void FuzzTarget(ReadOnlySpan bytes) { - using PooledBoundedMemory inputPoisoned = PooledBoundedMemory.Rent(bytes, PoisonPagePlacement.After); - using MemoryStream stream = new MemoryStream(inputPoisoned.Memory.ToArray()); + using PooledBoundedMemory inputPoisonedAfter = PooledBoundedMemory.Rent(bytes, PoisonPagePlacement.After); + using PooledBoundedMemory inputPoisonedBefore = PooledBoundedMemory.Rent(bytes, PoisonPagePlacement.Before); + using MemoryStream streamAfter = new MemoryStream(inputPoisonedAfter.Memory.ToArray()); + using MemoryStream streamBefore = new MemoryStream(inputPoisonedBefore.Memory.ToArray()); - if (NrbfDecoder.StartsWithPayloadHeader(inputPoisoned.Span)) + Test(inputPoisonedAfter.Span, streamAfter); + Test(inputPoisonedBefore.Span, streamBefore); + } + + private static void Test(Span testSpan, MemoryStream stream) + { + if (NrbfDecoder.StartsWithPayloadHeader(testSpan)) { try { @@ -52,7 +60,6 @@ public void FuzzTarget(ReadOnlySpan bytes) case SerializationRecordType.ClassWithId: case SerializationRecordType.ClassWithMembersAndTypes: case SerializationRecordType.SystemClassWithMembersAndTypes: - { ClassRecord classRecord = (ClassRecord)record; Assert.NotNull(classRecord.TypeName); @@ -60,7 +67,7 @@ public void FuzzTarget(ReadOnlySpan bytes) { Assert.Equal(true, classRecord.HasMember(name)); } - } break; + break; case SerializationRecordType.MemberPrimitiveTyped: PrimitiveTypeRecord primitiveType = (PrimitiveTypeRecord)record; Assert.NotNull(primitiveType.Value); @@ -69,6 +76,8 @@ public void FuzzTarget(ReadOnlySpan bytes) Assert.NotNull(record.TypeName); break; case SerializationRecordType.BinaryLibrary: + Assert.Equal(false, record.Id.Equals(default)); + break; case SerializationRecordType.ObjectNull: case SerializationRecordType.ObjectNullMultiple: case SerializationRecordType.ObjectNullMultiple256: @@ -86,8 +95,9 @@ public void FuzzTarget(ReadOnlySpan bytes) catch (NotSupportedException) { /* Reading from the stream encountered unsupported records */ } catch (DecoderFallbackException) { /* Reading from the stream encountered an invalid UTF8 sequence. */ } catch (EndOfStreamException) { /* The end of the stream was reached before reading SerializationRecordType.MessageEnd record. */ } + catch (IOException) { /* An I/O error occurred. */ } } - else + else { try { @@ -97,10 +107,6 @@ public void FuzzTarget(ReadOnlySpan bytes) catch (SerializationException) { /* Everything has to start with a header */ } catch (NotSupportedException) { /* Reading from the stream encountered unsupported records */ } catch (EndOfStreamException) { /* The end of the stream was reached before reading SerializationRecordType.MessageEnd record. */ } - catch (DecoderFallbackException) { /* Reading from the stream encountered an invalid UTF8 sequence. */ } - // below exceptions are not expected - catch (FormatException) { /* Temporarily catch this until its fixed */ } - catch (IOException) { /* Temporarily catch this until its fixed */ } } } }