Skip to content

Commit

Permalink
[release/9.0] Fix bug in validating unused bits (#107321)
Browse files Browse the repository at this point in the history
* Fix bug in validating unused bits

* Fix another failure

* Update src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Url/Base64UrlValidator.cs

Co-authored-by: Stephen Toub <stoub@microsoft.com>

---------

Co-authored-by: Buyaa Namnan <bunamnan@microsoft.com>
Co-authored-by: Buyaa Namnan <buyankhishig.namnan@microsoft.com>
Co-authored-by: Stephen Toub <stoub@microsoft.com>
  • Loading branch information
4 people authored Sep 4, 2024
1 parent d02256f commit a0847e7
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,29 @@ namespace System.Buffers.Text.Tests
{
public class Base64ValidationUnitTests : Base64TestBase
{
[Theory]
[InlineData("= ")]
[InlineData("= =")]
[InlineData("+ +=")]
[InlineData("A=")]
[InlineData("A==")]
[InlineData("44==")]
[InlineData(" A==")]
[InlineData("AAAAA ==")]
[InlineData("\tLLLL\t=\r")]
[InlineData("6066=")]
[InlineData("6066==")]
[InlineData("SM==")]
[InlineData("SM =")]
[InlineData("s\rEs\r\r==")]
public void BasicValidationEdgeCaseScenario(string base64UrlText)
{
Assert.False(Base64.IsValid(base64UrlText.AsSpan(), out int decodedLength));
Assert.Equal(0, decodedLength);
Span<byte> dest = new byte[Base64.GetMaxDecodedFromUtf8Length(base64UrlText.Length)];
Assert.Equal(OperationStatus.InvalidData, Base64.DecodeFromUtf8(base64UrlText.ToUtf8Span(), dest, out _, out _));
}

[Fact]
public void BasicValidationBytes()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ namespace System.Buffers.Text.Tests
public class Base64UrlValidationUnitTests : Base64TestBase
{
[Theory]
[InlineData("=")]
[InlineData("==")]
[InlineData("-%")]
[InlineData("A=")]
Expand All @@ -19,10 +20,17 @@ public class Base64UrlValidationUnitTests : Base64TestBase
[InlineData("AAAAA ==")]
[InlineData("\tLLLL\t=\r")]
[InlineData("6066=")]
[InlineData("6066==")]
[InlineData("SM==")]
[InlineData("SM=")]
[InlineData("sEs==")]
[InlineData("s\rEs\r\r==")]
public void BasicValidationEdgeCaseScenario(string base64UrlText)
{
Assert.False(Base64Url.IsValid(base64UrlText.AsSpan(), out int decodedLength));
Assert.Equal(0, decodedLength);
Span<byte> dest = new byte[Base64Url.GetMaxDecodedLength(base64UrlText.Length)];
Assert.Equal(OperationStatus.InvalidData, Base64Url.DecodeFromChars(base64UrlText.AsSpan(), dest, out _, out _));
}

[Fact]
Expand Down Expand Up @@ -258,6 +266,10 @@ public void ValidateWithPaddingReturnsCorrectCountChars(string utf8WithByteToBeI
Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored));
Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength));
Assert.Equal(expectedLength, decodedLength);

Span<byte> dest = new byte[Base64Url.GetMaxDecodedLength(utf8WithByteToBeIgnored.Length)];
Assert.Equal(OperationStatus.Done, Base64Url.DecodeFromChars(utf8WithByteToBeIgnored.AsSpan(), dest, out _, out decodedLength));
Assert.Equal(expectedLength, decodedLength);
}

[Theory]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Runtime.CompilerServices;
using static System.Buffers.Text.Base64Helper;

namespace System.Buffers.Text
{
Expand Down Expand Up @@ -94,33 +94,33 @@ public bool ValidateAndDecodeLength(char lastChar, int length, int paddingCount,
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public bool ValidateAndDecodeLength(byte lastChar, int length, int paddingCount, out int decodedLength)
{
// Padding is optional for Base64Url, so need to account remainder. If remainder is 1, then it's invalid.
#if NET
(uint whole, uint remainder) = uint.DivRem((uint)(length), 4);
if (remainder == 1 || (remainder > 1 && (remainder - paddingCount == 1 || paddingCount == remainder)))
{
decodedLength = 0;
return false;
}

decodedLength = (int)((whole * 3) + (remainder > 0 ? remainder - 1 : 0) - paddingCount);
#else
// Padding is optional for Base64Url, so need to account remainder.
int remainder = (int)((uint)length % 4);
if (remainder == 1 || (remainder > 1 && (remainder - paddingCount == 1 || paddingCount == remainder)))

if (paddingCount != 0)
{
decodedLength = 0;
return false;
length -= paddingCount;
remainder = (int)((uint)length % 4);

// if there is a padding, there should be remainder and the sum of remainder and padding should not exceed 4
if (remainder == 0 || remainder + paddingCount > 4)
{
decodedLength = 0;
return false;
}
}

decodedLength = (length >> 2) * 3 + (remainder > 0 ? remainder - 1 : 0) - paddingCount;
#endif
int decoded = default(Base64DecoderByte).DecodingMap[lastChar];
if (((remainder == 3 || paddingCount == 1) && (decoded & 0x03) != 0) ||
((remainder == 2 || paddingCount == 2) && (decoded & 0x0F) != 0))
decodedLength = (length >> 2) * 3 + (remainder > 0 ? remainder - 1 : 0);

if (remainder > 0)
{
// unused lower bits are not 0, reject input
decodedLength = 0;
return false;
int decoded = default(Base64UrlDecoderByte).DecodingMap[lastChar];
switch (remainder)
{
case 1: return false; // 1 byte is not decodable => invalid.
case 2: return ((decoded & 0x0F) == 0); // if unused lower 4 bits are set to 0
case 3: return ((decoded & 0x03) == 0); // if unused lower 2 bits are set to 0
}
}

return true;
Expand Down

0 comments on commit a0847e7

Please sign in to comment.