From 401191294b4f4fce422c6456a58391aca3c6083e Mon Sep 17 00:00:00 2001 From: Buyaa Namnan Date: Tue, 3 Sep 2024 15:02:33 -0700 Subject: [PATCH] Fix bug in validating unused bits (#106771) * Fix bug in validating unused bits * Fix another failure * Update src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Url/Base64UrlValidator.cs Co-authored-by: Stephen Toub --------- Co-authored-by: Stephen Toub --- .../tests/Base64/Base64ValidationUnitTests.cs | 23 ++++++++++ .../Base64Url/Base64UrlValidationUnitTests.cs | 12 +++++ .../Text/Base64Url/Base64UrlValidator.cs | 46 +++++++++---------- 3 files changed, 58 insertions(+), 23 deletions(-) diff --git a/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs b/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs index 62b978c60c5e1..f95399ae42091 100644 --- a/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs +++ b/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs @@ -9,6 +9,29 @@ namespace System.Buffers.Text.Tests { public class Base64ValidationUnitTests : Base64TestBase { + [Theory] + [InlineData("= ")] + [InlineData("= =")] + [InlineData("+ +=")] + [InlineData("A=")] + [InlineData("A==")] + [InlineData("44==")] + [InlineData(" A==")] + [InlineData("AAAAA ==")] + [InlineData("\tLLLL\t=\r")] + [InlineData("6066=")] + [InlineData("6066==")] + [InlineData("SM==")] + [InlineData("SM =")] + [InlineData("s\rEs\r\r==")] + public void BasicValidationEdgeCaseScenario(string base64UrlText) + { + Assert.False(Base64.IsValid(base64UrlText.AsSpan(), out int decodedLength)); + Assert.Equal(0, decodedLength); + Span dest = new byte[Base64.GetMaxDecodedFromUtf8Length(base64UrlText.Length)]; + Assert.Equal(OperationStatus.InvalidData, Base64.DecodeFromUtf8(base64UrlText.ToUtf8Span(), dest, out _, out _)); + } + [Fact] public void BasicValidationBytes() { diff --git a/src/libraries/System.Memory/tests/Base64Url/Base64UrlValidationUnitTests.cs b/src/libraries/System.Memory/tests/Base64Url/Base64UrlValidationUnitTests.cs index 0ee3abcd9153a..092d45d2370d1 100644 --- a/src/libraries/System.Memory/tests/Base64Url/Base64UrlValidationUnitTests.cs +++ b/src/libraries/System.Memory/tests/Base64Url/Base64UrlValidationUnitTests.cs @@ -10,6 +10,7 @@ namespace System.Buffers.Text.Tests public class Base64UrlValidationUnitTests : Base64TestBase { [Theory] + [InlineData("=")] [InlineData("==")] [InlineData("-%")] [InlineData("A=")] @@ -19,10 +20,17 @@ public class Base64UrlValidationUnitTests : Base64TestBase [InlineData("AAAAA ==")] [InlineData("\tLLLL\t=\r")] [InlineData("6066=")] + [InlineData("6066==")] + [InlineData("SM==")] + [InlineData("SM=")] + [InlineData("sEs==")] + [InlineData("s\rEs\r\r==")] public void BasicValidationEdgeCaseScenario(string base64UrlText) { Assert.False(Base64Url.IsValid(base64UrlText.AsSpan(), out int decodedLength)); Assert.Equal(0, decodedLength); + Span dest = new byte[Base64Url.GetMaxDecodedLength(base64UrlText.Length)]; + Assert.Equal(OperationStatus.InvalidData, Base64Url.DecodeFromChars(base64UrlText.AsSpan(), dest, out _, out _)); } [Fact] @@ -258,6 +266,10 @@ public void ValidateWithPaddingReturnsCorrectCountChars(string utf8WithByteToBeI Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored)); Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); Assert.Equal(expectedLength, decodedLength); + + Span dest = new byte[Base64Url.GetMaxDecodedLength(utf8WithByteToBeIgnored.Length)]; + Assert.Equal(OperationStatus.Done, Base64Url.DecodeFromChars(utf8WithByteToBeIgnored.AsSpan(), dest, out _, out decodedLength)); + Assert.Equal(expectedLength, decodedLength); } [Theory] diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Url/Base64UrlValidator.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Url/Base64UrlValidator.cs index bc95ab0f054cd..3cd8bf5c679f2 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Url/Base64UrlValidator.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Url/Base64UrlValidator.cs @@ -1,8 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics; using System.Runtime.CompilerServices; -using static System.Buffers.Text.Base64Helper; namespace System.Buffers.Text { @@ -94,33 +94,33 @@ public bool ValidateAndDecodeLength(char lastChar, int length, int paddingCount, [MethodImpl(MethodImplOptions.AggressiveInlining)] public bool ValidateAndDecodeLength(byte lastChar, int length, int paddingCount, out int decodedLength) { - // Padding is optional for Base64Url, so need to account remainder. If remainder is 1, then it's invalid. -#if NET - (uint whole, uint remainder) = uint.DivRem((uint)(length), 4); - if (remainder == 1 || (remainder > 1 && (remainder - paddingCount == 1 || paddingCount == remainder))) - { - decodedLength = 0; - return false; - } - - decodedLength = (int)((whole * 3) + (remainder > 0 ? remainder - 1 : 0) - paddingCount); -#else + // Padding is optional for Base64Url, so need to account remainder. int remainder = (int)((uint)length % 4); - if (remainder == 1 || (remainder > 1 && (remainder - paddingCount == 1 || paddingCount == remainder))) + + if (paddingCount != 0) { - decodedLength = 0; - return false; + length -= paddingCount; + remainder = (int)((uint)length % 4); + + // if there is a padding, there should be remainder and the sum of remainder and padding should not exceed 4 + if (remainder == 0 || remainder + paddingCount > 4) + { + decodedLength = 0; + return false; + } } - decodedLength = (length >> 2) * 3 + (remainder > 0 ? remainder - 1 : 0) - paddingCount; -#endif - int decoded = default(Base64DecoderByte).DecodingMap[lastChar]; - if (((remainder == 3 || paddingCount == 1) && (decoded & 0x03) != 0) || - ((remainder == 2 || paddingCount == 2) && (decoded & 0x0F) != 0)) + decodedLength = (length >> 2) * 3 + (remainder > 0 ? remainder - 1 : 0); + + if (remainder > 0) { - // unused lower bits are not 0, reject input - decodedLength = 0; - return false; + int decoded = default(Base64UrlDecoderByte).DecodingMap[lastChar]; + switch (remainder) + { + case 1: return false; // 1 byte is not decodable => invalid. + case 2: return ((decoded & 0x0F) == 0); // if unused lower 4 bits are set to 0 + case 3: return ((decoded & 0x03) == 0); // if unused lower 2 bits are set to 0 + } } return true;