diff --git a/src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs b/src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs index bbd28fb092b01..b31393846c746 100644 --- a/src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs +++ b/src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs @@ -1,7 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Collections; using System.Collections.Generic; using System.Linq; using System.Text; @@ -273,6 +272,9 @@ public void BasicDecodingWithFinalBlockTrueKnownInputDone(string inputString, in [Theory] [InlineData("A", 0, 0)] + [InlineData("A===", 0, 0)] + [InlineData("A==", 0, 0)] + [InlineData("A=", 0, 0)] [InlineData("AQ", 0, 0)] [InlineData("AQI", 0, 0)] [InlineData("AQIDBA", 4, 3)] @@ -285,16 +287,18 @@ public void BasicDecodingWithFinalBlockTrueKnownInputInvalid(string inputString, Assert.Equal(OperationStatus.InvalidData, Base64.DecodeFromUtf8(source, decodedBytes, out int consumed, out int decodedByteCount)); Assert.Equal(expectedConsumed, consumed); Assert.Equal(expectedWritten, decodedByteCount); // expectedWritten == decodedBytes.Length - Assert.True(Base64TestHelper.VerifyDecodingCorrectness(expectedConsumed, decodedBytes.Length, source, decodedBytes)); + Assert.True(Base64TestHelper.VerifyDecodingCorrectness(expectedConsumed, expectedWritten, source, decodedBytes)); } [Theory] [InlineData("\u00ecz/T", 0, 0)] // scalar code-path [InlineData("z/Ta123\u00ec", 4, 3)] [InlineData("\u00ecz/TpH7sqEkerqMweH1uSw==", 0, 0)] // Vector128 code-path - [InlineData("z/TpH7sqEkerqMweH1uSw\u00ec==", 20, 15)] - [InlineData("\u00ecz/TpH7sqEkerqMweH1uSw1a5ebaAF9xa8B0ze1wet4epo==", 0, 0)] // Vector256 / AVX code-path + [InlineData("z/TpH7sqEkerqMweH1uSw\u5948==", 20, 15)] + [InlineData("\u5948/TpH7sqEkerqMweH1uSw1a5ebaAF9xa8B0ze1wet4epo==", 0, 0)] // Vector256 / AVX code-path [InlineData("z/TpH7sqEkerqMweH1uSw1a5ebaAF9xa8B0ze1wet4epo\u00ec==", 44, 33)] + [InlineData("\u5948z+T/H7sqEkerqMweH1uSw1a5ebaAF9xa8B0ze1wet4epo01234567890123456789012345678901234567890123456789==", 0, 0)] // Vector512 / Avx512Vbmi code-path + [InlineData("z/T+H7sqEkerqMweH1uSw1a5ebaAF9xa8B0ze1wet4epo01234567890123456789012345678901234567890123456789\u5948==", 92, 69)] public void BasicDecodingNonAsciiInputInvalid(string inputString, int expectedConsumed, int expectedWritten) { Span source = Encoding.UTF8.GetBytes(inputString); @@ -749,19 +753,5 @@ public void BasicDecodingWithExtraWhitespaceShouldBeCountedInConsumedBytes(strin Assert.Equal(expectedWritten, decodedByteCount); Assert.True(Base64TestHelper.VerifyDecodingCorrectness(expectedConsumed, expectedWritten, source, decodedBytes)); } - - public static IEnumerable BasicDecodingWithExtraWhitespaceShouldBeCountedInConsumedBytes_MemberData() - { - var r = new Random(42); - for (int i = 0; i < 5; i++) - { - yield return new object[] { "AQ==" + new string(r.GetItems(" \n\t\r", i)), 4 + i, 1 }; - } - - foreach (string s in new[] { "MTIz", "M TIz", "MT Iz", "MTI z", "MTIz ", "M TI z", "M T I Z " }) - { - yield return new object[] { s + s + s + s, s.Length * 4, 12 }; - } - } } } diff --git a/src/libraries/System.Memory/tests/Base64/Base64TestBase.cs b/src/libraries/System.Memory/tests/Base64/Base64TestBase.cs index 882db3026722e..828442602c7f2 100644 --- a/src/libraries/System.Memory/tests/Base64/Base64TestBase.cs +++ b/src/libraries/System.Memory/tests/Base64/Base64TestBase.cs @@ -107,5 +107,19 @@ public static IEnumerable StringsOnlyWithCharsToBeIgnored() string GetRepeatedChar(char charToInsert, int numberOfTimesToInsert) => new string(charToInsert, numberOfTimesToInsert); } + + public static IEnumerable BasicDecodingWithExtraWhitespaceShouldBeCountedInConsumedBytes_MemberData() + { + var r = new Random(42); + for (int i = 0; i < 5; i++) + { + yield return new object[] { "AQ==" + new string(r.GetItems(" \n\t\r", i)), 4 + i, 1 }; + } + + foreach (string s in new[] { "MTIz", "M TIz", "MT Iz", "MTI z", "MTIz ", "M TI z", "M T I Z " }) + { + yield return new object[] { s + s + s + s, s.Length * 4, 12 }; + } + } } } diff --git a/src/libraries/System.Memory/tests/Base64/Base64TestHelper.cs b/src/libraries/System.Memory/tests/Base64/Base64TestHelper.cs index 1ccc8e0cb4289..42dca70bd6660 100644 --- a/src/libraries/System.Memory/tests/Base64/Base64TestHelper.cs +++ b/src/libraries/System.Memory/tests/Base64/Base64TestHelper.cs @@ -24,12 +24,23 @@ public static class Base64TestHelper 52, 53, 54, 55, 56, 57, 43, 47 //4..9, +, / }; + public static readonly byte[] s_urlEncodingMap = { + 65, 66, 67, 68, 69, 70, 71, 72, //A..H + 73, 74, 75, 76, 77, 78, 79, 80, //I..P + 81, 82, 83, 84, 85, 86, 87, 88, //Q..X + 89, 90, 97, 98, 99, 100, 101, 102, //Y..Z, a..f + 103, 104, 105, 106, 107, 108, 109, 110, //g..n + 111, 112, 113, 114, 115, 116, 117, 118, //o..v + 119, 120, 121, 122, 48, 49, 50, 51, //w..z, 0..3 + 52, 53, 54, 55, 56, 57, 45, 95 //4..9, -, _ + }; + // Pre-computing this table using a custom string(s_characters) and GenerateDecodingMapAndVerify (found in tests) public static readonly sbyte[] s_decodingMap = { -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 62, -1, -1, -1, 63, //62 is placed at index 43 (for +), 63 at index 47 (for /) - 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, -1, -1, -1, -1, -1, -1, //52-61 are placed at index 48-57 (for 0-9), 64 at index 61 (for =) + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, -1, -1, -1, -1, -1, -1, //52-61 are placed at index 48-57 (for 0-9) -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, -1, -1, -1, -1, -1, //0-25 are placed at index 65-90 (for A-Z) -1, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, @@ -44,9 +55,29 @@ public static class Base64TestHelper -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, }; + public static readonly sbyte[] s_urlDecodingMap = { + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 62, -1, -1, //62 is placed at index 45 (for -), 63 at index 95 (for _) + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, -1, -1, -1, -1, -1, -1, //52-61 are placed at index 48-57 (for 0-9) + -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, -1, -1, -1, -1, 63, //0-25 are placed at index 65-90 (for A-Z) + -1, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, + 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, -1, -1, -1, -1, -1, //26-51 are placed at index 97-122 (for a-z) + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Bytes over 122 ('z') are invalid and cannot be decoded + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Hence, padding the map with 255, which indicates invalid input + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + }; + public static bool IsByteToBeIgnored(byte charByte) => charByte is (byte)' ' or (byte)'\t' or (byte)'\r' or (byte)'\n'; public const byte EncodingPad = (byte)'='; // '=', for padding + public const byte UrlEncodingPad = (byte)'%'; // '%', for url padding public const sbyte InvalidByte = -1; // Designating -1 for invalid bytes in the decoding map public static byte[] InvalidBytes @@ -60,6 +91,17 @@ public static byte[] InvalidBytes } } + public static byte[] UrlInvalidBytes + { + get + { + int[] indices = s_urlDecodingMap.FindAllIndexOf(InvalidByte); + // Workaround for indices.Cast().ToArray() since it throws + // InvalidCastException: Unable to cast object of type 'System.Int32' to type 'System.Byte' + return indices.Select(i => (byte)i).ToArray(); + } + } + internal static void InitializeBytes(Span bytes, int seed = 100) { var rnd = new Random(seed); @@ -79,6 +121,26 @@ internal static void InitializeDecodableBytes(Span bytes, int seed = 100) } } + internal static void InitializeUrlDecodableChars(Span bytes, int seed = 100) + { + var rnd = new Random(seed); + for (int i = 0; i < bytes.Length; i++) + { + int index = (byte)rnd.Next(0, s_urlEncodingMap.Length); + bytes[i] = (char)s_urlEncodingMap[index]; + } + } + + internal static void InitializeUrlDecodableBytes(Span bytes, int seed = 100) + { + var rnd = new Random(seed); + for (int i = 0; i < bytes.Length; i++) + { + int index = (byte)rnd.Next(0, s_urlEncodingMap.Length); + bytes[i] = s_urlEncodingMap[index]; + } + } + [Fact] public static void GenerateEncodingMapAndVerify() { @@ -112,16 +174,34 @@ public static int[] FindAllIndexOf(this IEnumerable values, T valueToFind) public static bool VerifyEncodingCorrectness(int expectedConsumed, int expectedWritten, Span source, Span encodedBytes) { - string expectedText = Convert.ToBase64String(source.Slice(0, expectedConsumed).ToArray()); - string encodedText = Encoding.ASCII.GetString(encodedBytes.Slice(0, expectedWritten).ToArray()); + string expectedText = Convert.ToBase64String(source.Slice(0, expectedConsumed)); + string encodedText = Encoding.ASCII.GetString(encodedBytes.Slice(0, expectedWritten)); + return expectedText.Equals(encodedText); + } + + public static bool VerifyUrlEncodingCorrectness(int expectedConsumed, int expectedWritten, Span source, Span encodedBytes) + { + string expectedText = Convert.ToBase64String(source.Slice(0, expectedConsumed)) + .Replace('+', '-').Replace('/', '_').TrimEnd('='); + string encodedText = Encoding.ASCII.GetString(encodedBytes.Slice(0, expectedWritten)); return expectedText.Equals(encodedText); } public static bool VerifyDecodingCorrectness(int expectedConsumed, int expectedWritten, Span source, Span decodedBytes) { - string sourceString = Encoding.ASCII.GetString(source.Slice(0, expectedConsumed).ToArray()); + string sourceString = Encoding.ASCII.GetString(source.Slice(0, expectedConsumed)); byte[] expectedBytes = Convert.FromBase64String(sourceString); return expectedBytes.AsSpan().SequenceEqual(decodedBytes.Slice(0, expectedWritten)); } + + public static bool VerifyUrlDecodingCorrectness(int expectedConsumed, int expectedWritten, Span source, Span decodedBytes) + { + string sourceString = Encoding.ASCII.GetString(source.Slice(0, expectedConsumed)); + string padded = sourceString.Length % 4 == 0 ? sourceString : + sourceString.PadRight(sourceString.Length + (4 - sourceString.Length % 4), '='); + string base64 = padded.Replace('_', '/').Replace('-', '+').Replace('%', '='); + byte[] expectedBytes = Convert.FromBase64String(base64); + return expectedBytes.AsSpan().SequenceEqual(decodedBytes.Slice(0, expectedWritten)); + } } } diff --git a/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs b/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs index c7f164ad9b7f5..62b978c60c5e1 100644 --- a/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs +++ b/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs @@ -69,6 +69,7 @@ public void BasicValidationInvalidInputLengthBytes() } while (numBytes % 4 == 0); // ensure we have a invalid length Span source = new byte[numBytes]; + Base64TestHelper.InitializeDecodableBytes(source, numBytes); Assert.False(Base64.IsValid(source)); Assert.False(Base64.IsValid(source, out int decodedLength)); @@ -88,10 +89,16 @@ public void BasicValidationInvalidInputLengthChars() numBytes = rnd.Next(100, 1000 * 1000); } while (numBytes % 4 == 0); // ensure we have a invalid length - Span source = new char[numBytes]; + Span source = new byte[numBytes]; + Base64TestHelper.InitializeDecodableBytes(source, numBytes); + Span chars = source + .ToArray() + .Select(Convert.ToChar) + .ToArray() + .AsSpan(); - Assert.False(Base64.IsValid(source)); - Assert.False(Base64.IsValid(source, out int decodedLength)); + Assert.False(Base64.IsValid(chars)); + Assert.False(Base64.IsValid(chars, out int decodedLength)); Assert.Equal(0, decodedLength); } } @@ -267,7 +274,7 @@ public void InvalidSizeBytes(string utf8WithByteToBeIgnored) [InlineData("Y")] public void InvalidSizeChars(string utf8WithByteToBeIgnored) { - byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); + ReadOnlySpan utf8BytesWithByteToBeIgnored = utf8WithByteToBeIgnored; Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored)); Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); @@ -329,10 +336,10 @@ public void InvalidBase64Bytes(string utf8WithByteToBeIgnored) [InlineData(" a ")] public void InvalidBase64Chars(string utf8WithByteToBeIgnored) { - byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); + ReadOnlySpan utf8CharsWithCharToBeIgnored = utf8WithByteToBeIgnored; - Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored)); - Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.False(Base64.IsValid(utf8CharsWithCharToBeIgnored)); + Assert.False(Base64.IsValid(utf8CharsWithCharToBeIgnored, out int decodedLength)); Assert.Equal(0, decodedLength); } } diff --git a/src/libraries/System.Memory/tests/Base64Url/Base64UrlDecoderUnitTests.cs b/src/libraries/System.Memory/tests/Base64Url/Base64UrlDecoderUnitTests.cs new file mode 100644 index 0000000000000..23a16f6bb07a3 --- /dev/null +++ b/src/libraries/System.Memory/tests/Base64Url/Base64UrlDecoderUnitTests.cs @@ -0,0 +1,839 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Linq; +using System.Text; +using Xunit; + +namespace System.Buffers.Text.Tests +{ + public class Base64UrlDecoderUnitTests : Base64TestBase + { + [Fact] + public void BasicDecoding() + { + var rnd = new Random(42); + for (int i = 0; i < 10; i++) + { + int numBytes; + do + { + numBytes = rnd.Next(100, 1000 * 1000); + } while (numBytes % 4 == 1); // ensure we have a valid length + + Span source = new byte[numBytes]; + Base64TestHelper.InitializeUrlDecodableBytes(source, numBytes); + + Span decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)]; + Assert.Equal(OperationStatus.Done, Base64Url.DecodeFromUtf8(source, decodedBytes, out int consumed, out int decodedByteCount)); + Assert.Equal(source.Length, consumed); + Assert.Equal(decodedBytes.Length, decodedByteCount); + Assert.True(Base64TestHelper.VerifyUrlDecodingCorrectness(source.Length, decodedBytes.Length, source, decodedBytes)); + } + } + + [Fact] + public void BasicDecodingByteArrayReturnOverload() + { + var rnd = new Random(42); + for (int i = 0; i < 5; i++) + { + int numBytes; + do + { + numBytes = rnd.Next(100, 1000 * 1000); + } while (numBytes % 4 == 1); // ensure we have a valid length + + Span source = new byte[numBytes]; + Base64TestHelper.InitializeUrlDecodableBytes(source, numBytes); + + Span decodedBytes = Base64Url.DecodeFromUtf8(source); + Assert.Equal(decodedBytes.Length, Base64Url.GetMaxDecodedLength(source.Length)); + Assert.True(Base64TestHelper.VerifyUrlDecodingCorrectness(source.Length, decodedBytes.Length, source, decodedBytes)); + } + } + + [Fact] + public void BasicDecodingInvalidInputLength() + { + var rnd = new Random(42); + for (int i = 0; i < 10; i++) + { + int numBytes; + do + { + numBytes = rnd.Next(100, 1000 * 1000); + } while (numBytes % 4 != 1); // ensure we have a invalid length + + Span source = new byte[numBytes]; + Base64TestHelper.InitializeUrlDecodableBytes(source, numBytes); + + Span decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)]; + int expectedConsumed = numBytes / 4 * 4; // decode input up to the closest multiple of 4 + int expectedDecoded = expectedConsumed / 4 * 3; + + Assert.Equal(OperationStatus.InvalidData, Base64Url.DecodeFromUtf8(source, decodedBytes, out int consumed, out int decodedByteCount)); + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(expectedDecoded, decodedByteCount); + Assert.True(Base64TestHelper.VerifyUrlDecodingCorrectness(expectedConsumed, expectedDecoded, source, decodedBytes)); + } + } + + [Fact] + public void BasicDecodingInvalidInputWithOneByteData() + { + // Only 1 byte of data is invalid, 2 - 3 bytes of data are valid as padding is optional + ReadOnlySpan source = stackalloc byte[] { (byte)'A' }; + Span decodedBytes = stackalloc byte[128]; + + Assert.Equal(OperationStatus.InvalidData, Base64Url.DecodeFromUtf8(source, decodedBytes, out int consumed, out int decodedByteCount)); + Assert.Equal(0, consumed); + Assert.Equal(0, decodedByteCount); + } + + [Fact] + public void BasicDecodingWithFinalBlockFalse() + { + var rnd = new Random(42); + for (int i = 0; i < 10; i++) + { + int numBytes; + do + { + numBytes = rnd.Next(100, 1000 * 1000); + } while (numBytes % 4 != 0); // ensure we have a complete length + + Span source = new byte[numBytes]; + Base64TestHelper.InitializeUrlDecodableBytes(source, numBytes); + + Span decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)]; + int expectedConsumed = source.Length / 4 * 4; // only consume closest multiple of four since isFinalBlock is false + + Assert.Equal(OperationStatus.Done, Base64Url.DecodeFromUtf8(source, decodedBytes, out int consumed, out int decodedByteCount, isFinalBlock: false)); + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(decodedBytes.Length, decodedByteCount); + Assert.True(Base64TestHelper.VerifyUrlDecodingCorrectness(expectedConsumed, decodedBytes.Length, source, decodedBytes)); + } + } + + [Fact] + public void BasicDecodingWithFinalBlockFalseInvalidInputLength() + { + var rnd = new Random(42); + for (int i = 0; i < 10; i++) + { + int numBytes; + do + { + numBytes = rnd.Next(100, 1000 * 1000); + } while (numBytes % 4 == 0); // ensure we have a incomplete length + + Span source = new byte[numBytes]; + Base64TestHelper.InitializeUrlDecodableBytes(source, numBytes); + + Span decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)]; + int expectedConsumed = source.Length / 4 * 4; // only consume closest multiple of four since isFinalBlock is false + int expectedDecoded = expectedConsumed / 4 * 3; + + Assert.Equal(OperationStatus.NeedMoreData, Base64Url.DecodeFromUtf8(source, decodedBytes, out int consumed, out int decodedByteCount, isFinalBlock: false)); + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(expectedDecoded, decodedByteCount); + Assert.True(Base64TestHelper.VerifyUrlDecodingCorrectness(expectedConsumed, decodedByteCount, source, decodedBytes)); + } + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void DecodeEmptySpan(bool isFinalBlock) + { + Span source = Span.Empty; + Span decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)]; + + Assert.Equal(OperationStatus.Done, Base64Url.DecodeFromUtf8(source, decodedBytes, out int consumed, out int decodedByteCount, isFinalBlock)); + Assert.Equal(0, consumed); + Assert.Equal(0, decodedByteCount); + } + + [Fact] + public void DecodeGuid() + { + Span source = new byte[22]; // For Base64Url padding ignored + Span providedBytes = Guid.NewGuid().ToByteArray(); + Base64Url.EncodeToUtf8(providedBytes, source); + + Span decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)]; + Assert.Equal(16, Base64Url.DecodeFromUtf8(source, decodedBytes)); + Assert.True(providedBytes.SequenceEqual(decodedBytes)); + } + + [Fact] + public void DecodingOutputTooSmall() + { + for (int numBytes = 5; numBytes < 20; numBytes++) + { + Span source = new byte[numBytes]; + Base64TestHelper.InitializeUrlDecodableBytes(source, numBytes); + + Span decodedBytes = new byte[3]; + int consumed, written; + if (numBytes >= 6) + { + Assert.True(OperationStatus.DestinationTooSmall == + Base64Url.DecodeFromUtf8(source, decodedBytes, out consumed, out written), "Number of Input Bytes: " + numBytes); + } + else + { + Assert.True(OperationStatus.InvalidData == + Base64Url.DecodeFromUtf8(source, decodedBytes, out consumed, out written), "Number of Input Bytes: " + numBytes); + } + int expectedConsumed = 4; + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(decodedBytes.Length, written); + Assert.True(Base64TestHelper.VerifyUrlDecodingCorrectness(expectedConsumed, decodedBytes.Length, source, decodedBytes)); + } + + // Output too small even with padding characters in the input + { + Span source = new byte[12]; + Base64TestHelper.InitializeUrlDecodableBytes(source); + source[10] = Base64TestHelper.EncodingPad; + source[11] = Base64TestHelper.EncodingPad; + + Span decodedBytes = new byte[6]; + Assert.Equal(OperationStatus.DestinationTooSmall, Base64Url.DecodeFromUtf8(source, decodedBytes, out int consumed, out int written)); + int expectedConsumed = 8; + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(decodedBytes.Length, written); + Assert.True(Base64TestHelper.VerifyUrlDecodingCorrectness(expectedConsumed, decodedBytes.Length, source, decodedBytes)); + } + + { + Span source = new byte[12]; + Base64TestHelper.InitializeUrlDecodableBytes(source); + source[11] = Base64TestHelper.EncodingPad; + + Span decodedBytes = new byte[7]; + Assert.Equal(OperationStatus.DestinationTooSmall, Base64Url.DecodeFromUtf8(source, decodedBytes, out int consumed, out int written)); + int expectedConsumed = 8; + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(6, written); + Assert.True(Base64TestHelper.VerifyUrlDecodingCorrectness(expectedConsumed, 6, source, decodedBytes)); + } + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void DecodingOutputTooSmallWithFinalBlockTrueFalse(bool isFinalBlock) + { + for (int numBytes = 8; numBytes < 20; numBytes++) + { + Span source = new byte[numBytes]; + Base64TestHelper.InitializeUrlDecodableBytes(source, numBytes); + + Span decodedBytes = new byte[4]; + int consumed, written; + Assert.True(OperationStatus.DestinationTooSmall == + Base64Url.DecodeFromUtf8(source, decodedBytes, out consumed, out written, isFinalBlock: isFinalBlock), "Number of Input Bytes: " + numBytes); + int expectedConsumed = 4; + int expectedWritten = 3; + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(expectedWritten, written); + Assert.True(Base64TestHelper.VerifyUrlDecodingCorrectness(expectedConsumed, expectedWritten, source, decodedBytes)); + } + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void DecodingOutputTooSmallRetry(bool isFinalBlock) + { + Span source = new byte[1000]; + Base64TestHelper.InitializeUrlDecodableBytes(source); + + int outputSize = 240; + int requiredSize = Base64Url.GetMaxDecodedLength(source.Length); + + Span decodedBytes = new byte[outputSize]; + Assert.Equal(OperationStatus.DestinationTooSmall, Base64Url.DecodeFromUtf8(source, decodedBytes, out int consumed, out int decodedByteCount, isFinalBlock)); + int expectedConsumed = decodedBytes.Length / 3 * 4; + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(decodedBytes.Length, decodedByteCount); + Assert.True(Base64TestHelper.VerifyUrlDecodingCorrectness(expectedConsumed, decodedBytes.Length, source, decodedBytes)); + + decodedBytes = new byte[requiredSize - outputSize]; + source = source.Slice(consumed); + Assert.Equal(OperationStatus.Done, Base64Url.DecodeFromUtf8(source, decodedBytes, out consumed, out decodedByteCount, isFinalBlock)); + expectedConsumed = decodedBytes.Length / 3 * 4; + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(decodedBytes.Length, decodedByteCount); + Assert.True(Base64TestHelper.VerifyUrlDecodingCorrectness(expectedConsumed, decodedBytes.Length, source, decodedBytes)); + } + + [Theory] + [InlineData("AQ==", 1)] + [InlineData("AQI=", 2)] + [InlineData("AQID", 3)] + [InlineData("AQIDBA%%", 4)] + [InlineData("AQIDBAU=", 5)] + [InlineData("AQIDBAUG", 6)] + public void BasicDecodingWithFinalBlockTrueKnownInputDone(string inputString, int expectedWritten) + { + Span source = Encoding.ASCII.GetBytes(inputString); + Span decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)]; + + Assert.Equal(expectedWritten, Base64Url.DecodeFromUtf8(source, decodedBytes)); + Assert.True(Base64TestHelper.VerifyUrlDecodingCorrectness(inputString.Length, expectedWritten, source, decodedBytes)); + } + + [Theory] + [InlineData("A", 0, 0, OperationStatus.InvalidData)] + [InlineData("A===", 0, 0, OperationStatus.InvalidData)] + [InlineData("A==", 0, 0, OperationStatus.InvalidData)] + [InlineData("A=", 0, 0, OperationStatus.InvalidData)] + [InlineData("AQ", 2, 1, OperationStatus.Done)] // Padding is optional + [InlineData("AQI", 3, 2, OperationStatus.Done)] + [InlineData("AQIDBA", 6, 4, OperationStatus.Done)] + [InlineData("AQIDBAU", 7, 5, OperationStatus.Done)] + public void BasicDecodingWithFinalBlockTrueInputWithoutPaddingOrInvalidData(string inputString, int expectedConsumed, int expectedWritten, OperationStatus expectedStatus) + { + Span source = Encoding.ASCII.GetBytes(inputString); + Span decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)]; + + Assert.Equal(expectedStatus, Base64Url.DecodeFromUtf8(source, decodedBytes, out int consumed, out int decodedByteCount)); + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(expectedWritten, decodedByteCount); // expectedWritten == decodedBytes.Length + Assert.True(Base64TestHelper.VerifyUrlDecodingCorrectness(expectedConsumed, expectedWritten, source, decodedBytes)); + } + + [Theory] + [InlineData("A", 0, false)] + [InlineData("A===", 0, false)] + [InlineData("A==", 0, false)] + [InlineData("A=", 0, false)] + [InlineData("AQ", 1, true)] // Padding is optional + [InlineData("AQI", 2, true)] + [InlineData("AQID", 3, true)] + [InlineData("AQIDB", 3, false)] + [InlineData("AQIDBA", 4, true)] + [InlineData("AQIDBAU", 5, true)] + [InlineData("AQ==", 1, true)] + [InlineData("AQI%", 2, true)] + [InlineData("AQIDBA%%", 4, true)] + [InlineData("z_T-H7sqEkerqMweH1uSw1a5ebaAF9xa8B0ze1wet4epo\u5948==", 33, false)] + [InlineData("\u5948z_T-H7sqEkerqMweH1uSw1a5ebaAF9xa8B0ze1wet4epo01234567890123456789012345678901234567890123456789==", 0, false)] + public void TryDecodeFromUtf8VariousInput(string inputString, int expectedWritten, bool succeeds) + { + byte[] source = Encoding.ASCII.GetBytes(inputString); + byte[] decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)]; + + if (succeeds) + { + Assert.True(Base64Url.TryDecodeFromUtf8(source, decodedBytes, out int bytesWritten)); + Assert.Equal(expectedWritten, bytesWritten); + Assert.True(Base64TestHelper.VerifyUrlDecodingCorrectness(inputString.Length, expectedWritten, source, decodedBytes)); + } + else + { + Assert.Throws(() => Base64Url.TryDecodeFromUtf8(source, decodedBytes, out _)); + } + } + + [Theory] + [InlineData("\u5948cz_T", 0, 0)] // scalar code-path + [InlineData("z_Ta123\u5948", 4, 3)] + [InlineData("\u5948z_T-H7sqEkerqMweH1uSw==", 0, 0)] // Vector128 code-path + [InlineData("z_T-H7sqEkerqMweH1uSw\u5948==", 20, 15)] + [InlineData("\u5948z_T-H7sqEkerqMweH1uSw1a5ebaAF9xa8B0ze1wet4epo==", 0, 0)] // Vector256 / AVX code-path + [InlineData("z_T-H7sqEkerqMweH1uSw1a5ebaAF9xa8B0ze1wet4epo\u5948==", 44, 33)] + [InlineData("\u5948z_T-H7sqEkerqMweH1uSw1a5ebaAF9xa8B0ze1wet4epo01234567890123456789012345678901234567890123456789==", 0, 0)] // Vector512 / Avx512Vbmi code-path + [InlineData("z_T-H7sqEkerqMweH1uSw1a5ebaAF9xa8B0ze1wet4epo01234567890123456789012345678901234567890123456789\u5948==", 92, 69)] + public void BasicDecodingNonAsciiInputInvalid(string inputString, int expectedConsumed, int expectedWritten) + { + Span source = Encoding.UTF8.GetBytes(inputString); + Span decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)]; + + Assert.Equal(OperationStatus.InvalidData, Base64Url.DecodeFromUtf8(source, decodedBytes, out int consumed, out int decodedByteCount)); + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(expectedWritten, decodedByteCount); + } + + [Theory] + [InlineData("AQID", 3)] + [InlineData("AQIDBAUG", 6)] + public void BasicDecodingWithFinalBlockFalseKnownInputDone(string inputString, int expectedWritten) + { + Span source = Encoding.ASCII.GetBytes(inputString); + Span decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)]; + + int expectedConsumed = inputString.Length; + Assert.Equal(OperationStatus.Done, Base64Url.DecodeFromUtf8(source, decodedBytes, out int consumed, out int decodedByteCount, isFinalBlock: false)); + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(expectedWritten, decodedByteCount); // expectedWritten == decodedBytes.Length + Assert.True(Base64TestHelper.VerifyDecodingCorrectness(expectedConsumed, decodedBytes.Length, source, decodedBytes)); + } + + [Theory] + [InlineData("A", 0, 0)] + [InlineData("AQ", 0, 0)] // when FinalBlock: false incomplete bytes ignored + [InlineData("AQI", 0, 0)] + [InlineData("AQIDB", 4, 3)] + [InlineData("AQIDBA", 4, 3)] + [InlineData("AQIDBAU", 4, 3)] + public void BasicDecodingWithFinalBlockFalseKnownInputNeedMoreData(string inputString, int expectedConsumed, int expectedWritten) + { + Span source = Encoding.ASCII.GetBytes(inputString); + Span decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)]; + + Assert.Equal(OperationStatus.NeedMoreData, Base64Url.DecodeFromUtf8(source, decodedBytes, out int consumed, out int decodedByteCount, isFinalBlock: false)); + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(expectedWritten, decodedByteCount); // expectedWritten == decodedBytes.Length + Assert.True(Base64TestHelper.VerifyUrlDecodingCorrectness(expectedConsumed, decodedByteCount, source, decodedBytes)); + } + + [Theory] + [InlineData("AQ==", 0, 0)] + [InlineData("AQI%", 0, 0)] + [InlineData("AQIDBA==", 4, 3)] + [InlineData("AQIDBAU=", 4, 3)] + public void BasicDecodingWithFinalBlockFalseKnownInputInvalid(string inputString, int expectedConsumed, int expectedWritten) + { + Span source = Encoding.ASCII.GetBytes(inputString); + Span decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)]; + + Assert.Equal(OperationStatus.InvalidData, Base64Url.DecodeFromUtf8(source, decodedBytes, out int consumed, out int decodedByteCount, isFinalBlock: false)); + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(expectedWritten, decodedByteCount); + Assert.True(Base64TestHelper.VerifyDecodingCorrectness(expectedConsumed, expectedWritten, source, decodedBytes)); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void DecodingInvalidBytes(bool isFinalBlock) + { + // Invalid Bytes: + // 0-44 + // 46=47 + // 58-64 + // 91-94, 96 + // 123-255 + byte[] invalidBytes = Base64TestHelper.UrlInvalidBytes; + Assert.Equal(byte.MaxValue + 1 - 64, invalidBytes.Length); // 192 + + for (int j = 0; j < 8; j++) + { + Span source = "2222PPPP"u8.ToArray(); // valid input + Span decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)]; + + for (int i = 0; i < invalidBytes.Length; i++) + { + // Don't test padding (byte 61 i.e. '=' or '%'), which is tested in DecodingInvalidBytesPadding + // Don't test chars to be ignored (spaces: 9, 10, 13, 32 i.e. '\n', '\t', '\r', ' ') + if (invalidBytes[i] == Base64TestHelper.EncodingPad || + invalidBytes[i] == Base64TestHelper.UrlEncodingPad || + Base64TestHelper.IsByteToBeIgnored(invalidBytes[i])) + { + continue; + } + + // replace one byte with an invalid input + source[j] = invalidBytes[i]; + + Assert.Equal(OperationStatus.InvalidData, Base64Url.DecodeFromUtf8(source, decodedBytes, out int consumed, out int decodedByteCount, isFinalBlock)); + + if (j < 4) + { + Assert.Equal(0, consumed); + Assert.Equal(0, decodedByteCount); + } + else + { + Assert.Equal(4, consumed); + Assert.Equal(3, decodedByteCount); + Assert.True(Base64TestHelper.VerifyDecodingCorrectness(4, 3, source, decodedBytes)); + } + } + } + + // When isFinalBlock = true input that is not a multiple of 4 is invalid for Base64, but valid for Base64Url + if (isFinalBlock) + { + Span source = "2222PPP"u8.ToArray(); // incomplete input + Span decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)]; + Assert.Equal(5, Base64Url.DecodeFromUtf8(source, decodedBytes)); + Assert.True(Base64TestHelper.VerifyUrlDecodingCorrectness(7, 5, source, decodedBytes)); + } + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void DecodingInvalidBytesPadding(bool isFinalBlock) + { + // Only last 2 bytes can be padding, all other occurrence of padding is invalid + for (int j = 0; j < 7; j++) + { + Span source = "2222PPPP"u8.ToArray(); // valid input + Span decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)]; + source[j] = Base64TestHelper.EncodingPad; + Assert.Equal(OperationStatus.InvalidData, Base64Url.DecodeFromUtf8(source, decodedBytes, out int consumed, out int decodedByteCount, isFinalBlock)); + + if (j < 4) + { + Assert.Equal(0, consumed); + Assert.Equal(0, decodedByteCount); + } + else + { + Assert.Equal(4, consumed); + Assert.Equal(3, decodedByteCount); + Assert.True(Base64TestHelper.VerifyDecodingCorrectness(4, 3, source, decodedBytes)); + } + } + + // Invalid input with valid padding + { + Span source = new byte[] { 50, 50, 50, 50, 80, 42, 42, 42 }; + Span decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)]; + source[6] = Base64TestHelper.EncodingPad; + source[7] = Base64TestHelper.EncodingPad; // invalid input - "2222P*==" + Assert.Equal(OperationStatus.InvalidData, Base64Url.DecodeFromUtf8(source, decodedBytes, out int consumed, out int decodedByteCount, isFinalBlock)); + + Assert.Equal(4, consumed); + Assert.Equal(3, decodedByteCount); + Assert.True(Base64TestHelper.VerifyDecodingCorrectness(4, 3, source, decodedBytes)); + + source = new byte[] { 50, 50, 50, 50, 80, 42, 42, 42 }; + decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)]; + source[7] = Base64TestHelper.EncodingPad; // invalid input - "2222PP**=" + Assert.Equal(OperationStatus.InvalidData, Base64Url.DecodeFromUtf8(source, decodedBytes, out consumed, out decodedByteCount, isFinalBlock)); + + Assert.Equal(4, consumed); + Assert.Equal(3, decodedByteCount); + Assert.True(Base64TestHelper.VerifyDecodingCorrectness(4, 3, source, decodedBytes)); + } + + // The last byte or the last 2 bytes being the padding character is valid, if isFinalBlock = true + { + Span source = new byte[] { 50, 50, 50, 50, 80, 80, 80, 80 }; + Span decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)]; + source[6] = Base64TestHelper.EncodingPad; + source[7] = Base64TestHelper.EncodingPad; // valid input - "2222PP==" + + OperationStatus expectedStatus = isFinalBlock ? OperationStatus.Done : OperationStatus.InvalidData; + int expectedConsumed = isFinalBlock ? source.Length : 4; + int expectedWritten = isFinalBlock ? 4 : 3; + + Assert.Equal(expectedStatus, Base64Url.DecodeFromUtf8(source, decodedBytes, out int consumed, out int decodedByteCount, isFinalBlock)); + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(expectedWritten, decodedByteCount); + Assert.True(Base64TestHelper.VerifyUrlDecodingCorrectness(expectedConsumed, expectedWritten, source, decodedBytes)); + + source = new byte[] { 50, 50, 50, 50, 80, 80, 80, 80 }; + decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)]; + source[7] = Base64TestHelper.UrlEncodingPad; // valid input - "2222PPP=" + + expectedConsumed = isFinalBlock ? source.Length : 4; + expectedWritten = isFinalBlock ? 5 : 3; + Assert.Equal(expectedStatus, Base64Url.DecodeFromUtf8(source, decodedBytes, out consumed, out decodedByteCount, isFinalBlock)); + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(expectedWritten, decodedByteCount); + Assert.True(Base64TestHelper.VerifyUrlDecodingCorrectness(expectedConsumed, expectedWritten, source, decodedBytes)); + } + } + + [Fact] + public void GetMaxDecodedLength() + { + Span sourceEmpty = Span.Empty; + Assert.Equal(0, Base64Url.GetMaxDecodedLength(0)); + + // int.MaxValue - (int.MaxValue % 4) => 2147483644, largest multiple of 4 less than int.MaxValue + int[] input = { 0, 4, 8, 12, 16, 20, 2000000000, 2147483640, 2147483644 }; + int[] expected = { 0, 3, 6, 9, 12, 15, 1500000000, 1610612730, 1610612733 }; + + for (int i = 0; i < input.Length; i++) + { + Assert.Equal(expected[i], Base64Url.GetMaxDecodedLength(input[i])); + } + + // Lengths that are not a multiple of 4. + int[] lengthsNotMultipleOfFour = { 1, 2, 3, 5, 6, 7, 9, 10, 11, 13, 14, 15, 1001, 1002, 1003, 2147483645, 2147483646, 2147483647 }; + int[] expectedOutput = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 750, 751, 752, 1610612733, 1610612734, 1610612735 }; + for (int i = 0; i < lengthsNotMultipleOfFour.Length; i++) + { + Assert.Equal(expectedOutput[i], Base64Url.GetMaxDecodedLength(lengthsNotMultipleOfFour[i])); + } + + // negative input + Assert.Throws(() => Base64Url.GetMaxDecodedLength(-1)); + Assert.Throws(() => Base64Url.GetMaxDecodedLength(int.MinValue)); + } + + private static bool VerifyUrlDecodingCorrectness(string sourceString, Span decodedBytes) + { + string padded = sourceString.Length % 4 == 0 ? sourceString : + sourceString.PadRight(sourceString.Length + (4 - sourceString.Length % 4), '='); + byte[] expectedBytes = Convert.FromBase64String(padded.Replace('_', '/').Replace('-', '+').Replace('%', '=')); + return expectedBytes.AsSpan().SequenceEqual(decodedBytes); + } + + [Fact] + public void DecodeInPlace() + { + const int numberOfBytes = 15; + + for (int numberOfBytesToTest = 0; numberOfBytesToTest <= numberOfBytes; numberOfBytesToTest += 4) + { + Span testBytes = new byte[numberOfBytes]; + Base64TestHelper.InitializeUrlDecodableBytes(testBytes); + string sourceString = Encoding.ASCII.GetString(testBytes.Slice(0, numberOfBytesToTest).ToArray()); + int bytesWritten = Base64Url.DecodeFromUtf8InPlace(testBytes.Slice(0, numberOfBytesToTest)); + + Assert.Equal(Base64Url.GetMaxDecodedLength(numberOfBytesToTest), bytesWritten); + Assert.True(VerifyUrlDecodingCorrectness(sourceString, testBytes.Slice(0, bytesWritten))); + } + } + + [Fact] + public void EncodeAndDecodeInPlace() + { + byte[] testBytes = new byte[256]; + for (int i = 0; i < 256; i++) + { + testBytes[i] = (byte)i; + } + + for (int value = 0; value < 256; value++) + { + Span sourceBytes = testBytes.AsSpan(0, value + 1); + Span buffer = new byte[Base64Url.GetEncodedLength(sourceBytes.Length)]; + + Assert.Equal(OperationStatus.Done, Base64Url.EncodeToUtf8(sourceBytes, buffer, out int consumed, out int written)); + Assert.True(Base64TestHelper.VerifyUrlEncodingCorrectness(consumed, written, sourceBytes, buffer)); + + int bytesWritten = Base64Url.DecodeFromUtf8InPlace(buffer); + + Assert.Equal(sourceBytes.Length, bytesWritten); + Assert.True(sourceBytes.SequenceEqual(buffer.Slice(0, bytesWritten))); + } + } + + [Fact] + public void DecodeInPlaceInvalidBytesThrowsFormatException() + { + byte[] invalidBytes = Base64TestHelper.UrlInvalidBytes; + + for (int j = 0; j < 8; j++) + { + for (int i = 0; i < invalidBytes.Length; i++) + { + byte[] buffer = "2222PPPP"u8.ToArray(); // valid input + + // Don't test padding (byte 61 i.e. '='), which is tested in DecodeInPlaceInvalidBytesPadding + // Don't test chars to be ignored (spaces: 9, 10, 13, 32 i.e. '\n', '\t', '\r', ' ') + if (invalidBytes[i] == Base64TestHelper.EncodingPad || + invalidBytes[i] == Base64TestHelper.UrlEncodingPad || + Base64TestHelper.IsByteToBeIgnored(invalidBytes[i])) + { + continue; + } + + // replace one byte with an invalid input + buffer[j] = invalidBytes[i]; + + Assert.Throws(() => Base64Url.DecodeFromUtf8InPlace(buffer)); + } + } + + // Input that is not a multiple of 4 is valid for remainder 2-3, but invalid for 1 + { + byte[] buffer = "2222P"u8.ToArray(); // incomplete input + Assert.Throws(() => Base64Url.DecodeFromUtf8InPlace(buffer)); + } + } + + [Fact] + public void DecodeInPlaceInvalidBytesPaddingThrowsFormatException() + { + // Only last 2 bytes can be padding, all other occurrence of padding is invalid + for (int j = 0; j < 7; j++) + { + byte[] buffer = "2222PPPP"u8.ToArray(); // valid input + buffer[j] = Base64TestHelper.EncodingPad; + + Assert.Throws(() => Base64Url.DecodeFromUtf8InPlace(buffer)); + } + + // Invalid input with valid padding + { + byte[] buffer = new byte[] { 50, 50, 50, 50, 80, 42, 42, 42 }; + buffer[6] = Base64TestHelper.EncodingPad; + buffer[7] = Base64TestHelper.EncodingPad; // invalid input - "2222P*==" + + Assert.Throws(() => Base64Url.DecodeFromUtf8InPlace(buffer)); + } + + { + byte[] buffer = new byte[] { 50, 50, 50, 50, 80, 42, 42, 42 }; + buffer[7] = Base64TestHelper.EncodingPad; // invalid input - "2222P**=" + + Assert.Throws(() => Base64Url.DecodeFromUtf8InPlace(buffer)); + } + + // The last byte or the last 2 bytes being the padding character is valid + { + Span buffer = new byte[] { 50, 50, 50, 50, 80, 80, 80, 80 }; + buffer[6] = Base64TestHelper.UrlEncodingPad; + buffer[7] = Base64TestHelper.EncodingPad; // valid input - "2222PP==" + string sourceString = Encoding.ASCII.GetString(buffer.ToArray()); + int bytesWritten = Base64Url.DecodeFromUtf8InPlace(buffer); + + Assert.Equal(4, bytesWritten); + Assert.True(VerifyUrlDecodingCorrectness(sourceString, buffer.Slice(0, bytesWritten))); + } + + { + Span buffer = new byte[] { 50, 50, 50, 50, 80, 80, 80, 80 }; + buffer[7] = Base64TestHelper.EncodingPad; // valid input - "2222PPP=" + string sourceString = Encoding.ASCII.GetString(buffer.ToArray()); + int bytesWritten = Base64Url.DecodeFromUtf8InPlace(buffer); + + Assert.Equal(5, bytesWritten); + Assert.True(VerifyUrlDecodingCorrectness(sourceString, buffer.Slice(0, bytesWritten))); + } + + // The last byte or the last 2 bytes being the padding character is valid + { + Span buffer = new byte[] { 50, 50, 50, 50, 80, 80 }; // valid input without padding "2222PP" + + string sourceString = Encoding.ASCII.GetString(buffer.ToArray()); + int bytesWritten = Base64Url.DecodeFromUtf8InPlace(buffer); + + Assert.Equal(4, bytesWritten); + Assert.True(VerifyUrlDecodingCorrectness(sourceString, buffer.Slice(0, bytesWritten))); + } + } + + [Theory] + [MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))] + public void BasicDecodingIgnoresCharsToBeIgnoredAsConvertToBase64Does(string utf8WithCharsToBeIgnored, byte[] expectedBytes) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithCharsToBeIgnored); + byte[] resultBytes = new byte[5]; + OperationStatus result = Base64Url.DecodeFromUtf8(utf8BytesWithByteToBeIgnored, resultBytes, out int bytesConsumed, out int bytesWritten); + + // Control value from Convert.FromBase64String + byte[] stringBytes = Convert.FromBase64String(utf8WithCharsToBeIgnored); + + Assert.Equal(OperationStatus.Done, result); + Assert.Equal(utf8WithCharsToBeIgnored.Length, bytesConsumed); + Assert.Equal(expectedBytes.Length, bytesWritten); + Assert.True(expectedBytes.SequenceEqual(resultBytes)); + Assert.True(stringBytes.SequenceEqual(resultBytes)); + } + + [Theory] + [MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))] + public void DecodeInPlaceIgnoresCharsToBeIgnoredAsConvertToBase64Does(string utf8WithCharsToBeIgnored, byte[] expectedBytes) + { + Span utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithCharsToBeIgnored); + int bytesWritten = Base64Url.DecodeFromUtf8InPlace(utf8BytesWithByteToBeIgnored); + Span bytesOverwritten = utf8BytesWithByteToBeIgnored.Slice(0, bytesWritten); + byte[] resultBytesArray = bytesOverwritten.ToArray(); + + // Control value from Convert.FromBase64String + byte[] stringBytes = Convert.FromBase64String(utf8WithCharsToBeIgnored); + + Assert.Equal(expectedBytes.Length, bytesWritten); + Assert.True(expectedBytes.SequenceEqual(resultBytesArray)); + Assert.True(stringBytes.SequenceEqual(resultBytesArray)); + } + + [Theory] + [MemberData(nameof(StringsOnlyWithCharsToBeIgnored))] + public void BasicDecodingWithOnlyCharsToBeIgnored(string utf8WithCharsToBeIgnored) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithCharsToBeIgnored); + byte[] resultBytes = new byte[5]; + OperationStatus result = Base64Url.DecodeFromUtf8(utf8BytesWithByteToBeIgnored, resultBytes, out int bytesConsumed, out int bytesWritten); + + Assert.Equal(OperationStatus.Done, result); + Assert.Equal(0, bytesWritten); + } + + [Theory] + [MemberData(nameof(StringsOnlyWithCharsToBeIgnored))] + public void DecodingInPlaceWithOnlyCharsToBeIgnored(string utf8WithCharsToBeIgnored) + { + Span utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithCharsToBeIgnored); + int bytesWritten = Base64Url.DecodeFromUtf8InPlace(utf8BytesWithByteToBeIgnored); + + Assert.Equal(0, bytesWritten); + } + + [Theory] + [InlineData(new byte[] { 0xa, 0xa, 0x2d, 0x2d }, 251)] + [InlineData(new byte[] { 0xa, 0x5f, 0xa, 0x2d }, 255)] + [InlineData(new byte[] { 0x5f, 0x5f, 0xa, 0xa }, 255)] + [InlineData(new byte[] { 0x70, 0xa, 0x61, 0xa }, 165)] + [InlineData(new byte[] { 0xa, 0x70, 0xa, 0x61, 0xa }, 165)] + [InlineData(new byte[] { 0x70, 0xa, 0x61, 0xa, 0x3d, 0x3d }, 165)] + public void DecodingLessThan4BytesWithWhiteSpaces(byte[] utf8Bytes, byte decoded) + { + Assert.True(Base64Url.IsValid(utf8Bytes, out int decodedLength)); + Assert.Equal(1, decodedLength); + Span decodedSpan = new byte[decodedLength]; + OperationStatus status = Base64Url.DecodeFromUtf8(utf8Bytes, decodedSpan, out int bytesRead, out int bytesDecoded); + Assert.Equal(OperationStatus.Done, status); + Assert.Equal(utf8Bytes.Length, bytesRead); + Assert.Equal(decodedLength, bytesDecoded); + Assert.Equal(decoded, decodedSpan[0]); + decodedSpan.Clear(); + Assert.True(Base64Url.TryDecodeFromUtf8(utf8Bytes, decodedSpan, out bytesDecoded)); + Assert.Equal(decodedLength, bytesDecoded); + Assert.Equal(decoded, decodedSpan[0]); + + bytesDecoded = Base64Url.DecodeFromUtf8InPlace(utf8Bytes); + Assert.Equal(decodedLength, bytesDecoded); + Assert.Equal(decoded, utf8Bytes[0]); + } + + [Theory] + [InlineData(new byte[] { 0x4a, 0x74, 0xa, 0x4a, 0x4a, 0x74, 0xa, 0x4a }, new byte[] { 38, 210, 73, 180 })] + [InlineData(new byte[] { 0xa, 0x2d, 0x56, 0xa, 0xa, 0xa, 0x2d, 0x4a, 0x4a, 0x4a, }, new byte[] { 249, 95, 137, 36 })] + public void DecodingNotMultipleOf4WithWhiteSpace(byte[] utf8Bytes, byte[] decoded) + { + Assert.True(Base64Url.IsValid(utf8Bytes, out int decodedLength)); + Assert.Equal(4, decodedLength); + Span decodedSpan = new byte[decodedLength]; + OperationStatus status = Base64Url.DecodeFromUtf8(utf8Bytes, decodedSpan, out int bytesRead, out int bytesDecoded); + Assert.Equal(OperationStatus.Done, status); + Assert.Equal(utf8Bytes.Length, bytesRead); + Assert.Equal(decodedLength, bytesDecoded); + Assert.Equal(decoded, decodedSpan); + decodedSpan.Clear(); + Assert.True(Base64Url.TryDecodeFromUtf8(utf8Bytes, decodedSpan, out bytesDecoded)); + Assert.Equal(decodedLength, bytesDecoded); + Assert.Equal(decoded, decodedSpan); + bytesDecoded = Base64Url.DecodeFromUtf8InPlace(utf8Bytes); + Assert.Equal(decodedLength, bytesDecoded); + Assert.Equal(decoded, utf8Bytes.AsSpan().Slice(0, bytesDecoded)); + } + + [Theory] + [MemberData(nameof(BasicDecodingWithExtraWhitespaceShouldBeCountedInConsumedBytes_MemberData))] + public void BasicDecodingWithExtraWhitespaceShouldBeCountedInConsumedBytes(string inputString, int expectedConsumed, int expectedWritten) + { + Span source = Encoding.ASCII.GetBytes(inputString); + Span decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)]; + + Assert.Equal(OperationStatus.Done, Base64Url.DecodeFromUtf8(source, decodedBytes, out int consumed, out int decodedByteCount)); + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(expectedWritten, decodedByteCount); + Assert.True(Base64TestHelper.VerifyDecodingCorrectness(expectedConsumed, expectedWritten, source, decodedBytes)); + } + } +} diff --git a/src/libraries/System.Memory/tests/Base64Url/Base64UrlEncoderUnitTests.cs b/src/libraries/System.Memory/tests/Base64Url/Base64UrlEncoderUnitTests.cs new file mode 100644 index 0000000000000..bfabcc40d51b0 --- /dev/null +++ b/src/libraries/System.Memory/tests/Base64Url/Base64UrlEncoderUnitTests.cs @@ -0,0 +1,342 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.SpanTests; +using System.Text; +using Xunit; + +namespace System.Buffers.Text.Tests +{ + public class Base64UrlEncoderUnitTests + { + [Fact] + public void BasicEncodingAndDecoding() + { + var bytes = new byte[byte.MaxValue + 1]; + for (int i = 0; i < byte.MaxValue + 1; i++) + { + bytes[i] = (byte)i; + } + + for (int value = 0; value < 256; value++) + { + Span sourceBytes = bytes.AsSpan(0, value + 1); + Span encodedBytes = new byte[Base64Url.GetEncodedLength(sourceBytes.Length)]; + Assert.Equal(OperationStatus.Done, Base64Url.EncodeToUtf8(sourceBytes, encodedBytes, out int consumed, out int encodedBytesCount)); + Assert.Equal(sourceBytes.Length, consumed); + Assert.Equal(encodedBytes.Length, encodedBytesCount); + Assert.True(Base64TestHelper.VerifyUrlEncodingCorrectness(sourceBytes.Length, encodedBytes.Length, sourceBytes, encodedBytes)); + + int decodedLength = Base64Url.GetMaxDecodedLength(encodedBytes.Length); + Assert.True(sourceBytes.Length <= decodedLength); + Span decodedBytes = new byte[decodedLength]; + Assert.Equal(OperationStatus.Done, Base64Url.DecodeFromUtf8(encodedBytes, decodedBytes, out consumed, out int decodedByteCount)); + Assert.Equal(encodedBytes.Length, consumed); + Assert.Equal(sourceBytes.Length, decodedByteCount); + Assert.True(sourceBytes.SequenceEqual(decodedBytes.Slice(0, decodedByteCount))); + } + } + + [Fact] + public void BasicEncoding() + { + var rnd = new Random(42); + for (int i = 0; i < 10; i++) + { + int numBytes = rnd.Next(100, 1000 * 1000); + Span source = new byte[numBytes]; + Base64TestHelper.InitializeBytes(source, numBytes); + + Span encodedBytes = new byte[Base64Url.GetEncodedLength(source.Length)]; + OperationStatus result = Base64Url.EncodeToUtf8(source, encodedBytes, out int consumed, out int encodedBytesCount); + Assert.Equal(OperationStatus.Done, result); + Assert.Equal(source.Length, consumed); + Assert.Equal(encodedBytes.Length, encodedBytesCount); + Assert.True(Base64TestHelper.VerifyUrlEncodingCorrectness(source.Length, encodedBytes.Length, source, encodedBytes)); + } + } + + [Fact] + public void BasicEncodingWithFinalBlockFalse() + { + var rnd = new Random(42); + for (int i = 0; i < 10; i++) + { + int numBytes = rnd.Next(100, 1000 * 1000); + Span source = new byte[numBytes]; + Base64TestHelper.InitializeBytes(source, numBytes); + + Span encodedBytes = new byte[Base64Url.GetEncodedLength(source.Length)]; + int expectedConsumed = source.Length / 3 * 3; // only consume closest multiple of three since isFinalBlock is false + int expectedWritten = source.Length / 3 * 4; + + // The constant random seed guarantees that both states are tested. + OperationStatus expectedStatus = numBytes % 3 == 0 ? OperationStatus.Done : OperationStatus.NeedMoreData; + Assert.Equal(expectedStatus, Base64Url.EncodeToUtf8(source, encodedBytes, out int consumed, out int encodedBytesCount, isFinalBlock: false)); + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(expectedWritten, encodedBytesCount); + Assert.True(Base64TestHelper.VerifyUrlEncodingCorrectness(expectedConsumed, expectedWritten, source, encodedBytes)); + } + } + + [Theory] + [InlineData(1, "AQ")] + [InlineData(2, "AQI")] + [InlineData(3, "AQID")] + [InlineData(4, "AQIDBA")] + [InlineData(5, "AQIDBAU")] + [InlineData(6, "AQIDBAUG")] + [InlineData(7, "AQIDBAUGBw")] + public void BasicEncodingWithFinalBlockTrueKnownInput(int numBytes, string expectedText) + { + int expectedConsumed = numBytes; + int expectedWritten = expectedText.Length; + + Span source = new byte[numBytes]; + for (int i = 0; i < numBytes; i++) + { + source[i] = (byte)(i + 1); + } + Span encodedBytes = new byte[Base64Url.GetEncodedLength(source.Length)]; + + Assert.Equal(OperationStatus.Done, Base64Url.EncodeToUtf8(source, encodedBytes, out int consumed, out int encodedBytesCount)); + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(expectedWritten, encodedBytesCount); + + string encodedText = Encoding.ASCII.GetString(encodedBytes.Slice(0, expectedWritten).ToArray()); + Assert.Equal(expectedText, encodedText); + } + + [Theory] + [InlineData(1, "", 0, 0)] + [InlineData(2, "", 0, 0)] + [InlineData(3, "AQID", 3, 4)] + [InlineData(4, "AQID", 3, 4)] + [InlineData(5, "AQID", 3, 4)] + [InlineData(6, "AQIDBAUG", 6, 8)] + [InlineData(7, "AQIDBAUG", 6, 8)] + public void BasicEncodingWithFinalBlockFalseKnownInput(int numBytes, string expectedText, int expectedConsumed, int expectedWritten) + { + Span source = new byte[numBytes]; + for (int i = 0; i < numBytes; i++) + { + source[i] = (byte)(i + 1); + } + Span encodedBytes = new byte[Base64Url.GetEncodedLength(source.Length)]; + + OperationStatus expectedStatus = numBytes % 3 == 0 ? OperationStatus.Done : OperationStatus.NeedMoreData; + Assert.Equal(expectedStatus, Base64Url.EncodeToUtf8(source, encodedBytes, out int consumed, out int encodedBytesCount, isFinalBlock: false)); + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(expectedWritten, encodedBytesCount); + + string encodedText = Encoding.ASCII.GetString(encodedBytes.Slice(0, expectedWritten).ToArray()); + Assert.Equal(expectedText, encodedText); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void EncodeEmptySpan(bool isFinalBlock) + { + Span source = Span.Empty; + Span encodedBytes = new byte[Base64Url.GetEncodedLength(source.Length)]; + + Assert.Equal(OperationStatus.Done, Base64Url.EncodeToUtf8(source, encodedBytes, out int consumed, out int encodedBytesCount, isFinalBlock)); + Assert.Equal(0, consumed); + Assert.Equal(0, encodedBytesCount); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void EncodingOutputTooSmall(bool isFinalBlock) + { + for (int numBytes = 4; numBytes < 20; numBytes++) + { + Span source = new byte[numBytes]; + Base64TestHelper.InitializeBytes(source, numBytes); + + Span encodedBytes = new byte[4]; + Assert.Equal(OperationStatus.DestinationTooSmall, Base64Url.EncodeToUtf8(source, encodedBytes, out int consumed, out int written, isFinalBlock)); + int expectedConsumed = 3; + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(encodedBytes.Length, written); + Assert.True(Base64TestHelper.VerifyUrlEncodingCorrectness(expectedConsumed, encodedBytes.Length, source, encodedBytes)); + } + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void EncodingOutputTooSmallRetry(bool isFinalBlock) + { + Span source = new byte[750]; + Base64TestHelper.InitializeBytes(source); + + int outputSize = 320; + int requiredSize = Base64Url.GetEncodedLength(source.Length); + + Span encodedBytes = new byte[outputSize]; + Assert.Equal(OperationStatus.DestinationTooSmall, Base64Url.EncodeToUtf8(source, encodedBytes, out int consumed, out int written, isFinalBlock)); + int expectedConsumed = encodedBytes.Length / 4 * 3; + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(encodedBytes.Length, written); + Assert.True(Base64TestHelper.VerifyUrlEncodingCorrectness(expectedConsumed, encodedBytes.Length, source, encodedBytes)); + + encodedBytes = new byte[requiredSize - outputSize]; + source = source.Slice(consumed); + Assert.Equal(OperationStatus.Done, Base64Url.EncodeToUtf8(source, encodedBytes, out consumed, out written, isFinalBlock)); + expectedConsumed = encodedBytes.Length / 4 * 3; + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(encodedBytes.Length, written); + Assert.True(Base64TestHelper.VerifyUrlEncodingCorrectness(expectedConsumed, encodedBytes.Length, source, encodedBytes)); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + [OuterLoop] + public void EncodeTooLargeSpan(bool isFinalBlock) + { + if (!Environment.Is64BitProcess) + return; + + bool allocatedFirst = false; + bool allocatedSecond = false; + IntPtr memBlockFirst = IntPtr.Zero; + IntPtr memBlockSecond = IntPtr.Zero; + + // int.MaxValue - (int.MaxValue % 4) => 2147483644, largest multiple of 4 less than int.MaxValue + // CLR default limit of 2 gigabytes (GB). + // 1610612734, larger than MaximumEncodeLength, requires output buffer of size 2147483648 (which is > int.MaxValue) + const int sourceCount = (int.MaxValue >> 2) * 3 + 1; + const int encodedCount = 2000000000; + + try + { + allocatedFirst = AllocationHelper.TryAllocNative((IntPtr)sourceCount, out memBlockFirst); + allocatedSecond = AllocationHelper.TryAllocNative((IntPtr)encodedCount, out memBlockSecond); + if (allocatedFirst && allocatedSecond) + { + unsafe + { + var source = new Span(memBlockFirst.ToPointer(), sourceCount); + var encodedBytes = new Span(memBlockSecond.ToPointer(), encodedCount); + + Assert.Equal(OperationStatus.DestinationTooSmall, Base64Url.EncodeToUtf8(source, encodedBytes, out int consumed, out int encodedBytesCount, isFinalBlock)); + Assert.Equal((encodedBytes.Length >> 2) * 3, consumed); // encoding 1500000000 bytes fits into buffer of 2000000000 bytes + Assert.Equal(encodedBytes.Length, encodedBytesCount); + } + } + } + finally + { + if (allocatedFirst) + AllocationHelper.ReleaseNative(ref memBlockFirst); + if (allocatedSecond) + AllocationHelper.ReleaseNative(ref memBlockSecond); + } + } + + [Fact] + public void GetEncodedLength() + { + // (int.MaxValue - 4)/(4/3) => 1610612733, otherwise integer overflow + int[] input = { 0, 1, 2, 3, 4, 5, 6, 1610612728, 1610612729, 1610612730, 1610612731, 1610612732, 1610612733 }; + int[] expected = { 0, 2, 3, 4, 6, 7, 8, 2147483638, 2147483639, 2147483640, 2147483642, 2147483643, 2147483644 }; + for (int i = 0; i < input.Length; i++) + { + Assert.Equal(expected[i], Base64Url.GetEncodedLength(input[i])); + } + + // integer overflow + Assert.Throws(() => Base64Url.GetEncodedLength(1610612734)); + Assert.Throws(() => Base64Url.GetEncodedLength(int.MaxValue)); + + // negative input + Assert.Throws(() => Base64Url.GetEncodedLength(-1)); + Assert.Throws(() => Base64Url.GetEncodedLength(int.MinValue)); + } + + [Fact] + public void TryEncodeInPlace() + { + const int numberOfBytes = 15; + Span testBytes = new byte[numberOfBytes / 3 * 4]; // slack since encoding inflates the data + Base64TestHelper.InitializeBytes(testBytes); + + for (int numberOfBytesToTest = 0; numberOfBytesToTest <= numberOfBytes; numberOfBytesToTest++) + { + var expectedText = Convert.ToBase64String(testBytes.Slice(0, numberOfBytesToTest).ToArray()) + .Replace('+', '-').Replace('/', '_').TrimEnd('='); + + Assert.True(Base64Url.TryEncodeToUtf8InPlace(testBytes, numberOfBytesToTest, out int bytesWritten)); + Assert.Equal(Base64Url.GetEncodedLength(numberOfBytesToTest), bytesWritten); + + var encodedText = Encoding.ASCII.GetString(testBytes.Slice(0, bytesWritten).ToArray()); + Assert.Equal(expectedText, encodedText); + } + } + + [Fact] + public void TryEncodeInPlaceOutputTooSmall() + { + byte[] testBytes = { 1, 2, 3 }; + + Assert.False(Base64Url.TryEncodeToUtf8InPlace(testBytes, testBytes.Length, out int bytesWritten)); + Assert.Equal(0, bytesWritten); + } + + [Fact] + public void TryEncodeToUtf8() + { + const int numberOfBytes = 15; + Span testBytes = new byte[numberOfBytes / 3 * 4]; // slack since encoding inflates the data + Base64TestHelper.InitializeBytes(testBytes); + + for (int numberOfBytesToTest = 0; numberOfBytesToTest <= numberOfBytes; numberOfBytesToTest++) + { + ReadOnlySpan source = testBytes.Slice(0, numberOfBytesToTest); + Span destination = new byte[Base64Url.GetEncodedLength(numberOfBytesToTest)]; + Assert.True(Base64Url.TryEncodeToUtf8(source, destination, out int bytesWritten)); + Assert.Equal(destination.Length, bytesWritten); + Assert.True(source.SequenceEqual(Base64Url.DecodeFromUtf8(destination).AsSpan())); + } + } + + [Theory] + [InlineData(1, "AQ")] + [InlineData(2, "AQI")] + [InlineData(3, "AQID")] + [InlineData(4, "AQIDBA")] + [InlineData(5, "AQIDBAU")] + [InlineData(6, "AQIDBAUG")] + [InlineData(7, "AQIDBAUGBw")] + public void TryEncodeToUtf8EncodeUpToDestinationSize(int numBytes, string expectedText) + { + int expectedWritten = expectedText.Length; + + Span source = new byte[numBytes]; + for (int i = 0; i < numBytes; i++) + { + source[i] = (byte)(i + 1); + } + Span destination = new byte[6]; + + if (numBytes < 5) + { + Assert.True(Base64Url.TryEncodeToUtf8(source, destination, out int bytesWritten)); + Assert.Equal(expectedWritten, bytesWritten); + string encodedText = Encoding.ASCII.GetString(destination.Slice(0, expectedWritten).ToArray()); + Assert.Equal(expectedText, encodedText); + } + else + { + Assert.False(Base64Url.TryEncodeToUtf8(source, destination, out int bytesWritten)); + Assert.Equal(4, bytesWritten); + string encodedText = Encoding.ASCII.GetString(destination.Slice(0, 4).ToArray()); + Assert.Equal(expectedText.Substring(0, 4), encodedText); + } + } + } +} diff --git a/src/libraries/System.Memory/tests/Base64Url/Base64UrlUnicodeAPIsUnitTests.cs b/src/libraries/System.Memory/tests/Base64Url/Base64UrlUnicodeAPIsUnitTests.cs new file mode 100644 index 0000000000000..adbf17eae1004 --- /dev/null +++ b/src/libraries/System.Memory/tests/Base64Url/Base64UrlUnicodeAPIsUnitTests.cs @@ -0,0 +1,622 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Xunit; + +namespace System.Buffers.Text.Tests +{ + public class Base64UrlUnicodeAPIsUnitTests + { + [Theory] + [InlineData("", 0)] + [InlineData("t", 2)] + [InlineData("te", 3)] + [InlineData("tes", 4)] + [InlineData("test", 6)] + [InlineData("test/", 7)] + [InlineData("test/+", 8)] + public static void DecodeEncodeToFromCharsStringRoundTrip(string str, int expectedWritten) + { + byte[] inputBytes = Encoding.UTF8.GetBytes(str); + Span resultChars = new char[Base64Url.GetEncodedLength(inputBytes.Length)]; + OperationStatus operationStatus = Base64Url.EncodeToChars(inputBytes, resultChars, out int bytesConsumed, out int charsWritten); + Assert.Equal(OperationStatus.Done, operationStatus); + Assert.Equal(str.Length, bytesConsumed); + Assert.Equal(expectedWritten, charsWritten); + string result = Base64Url.EncodeToString(inputBytes); + Assert.Equal(result, resultChars); + Assert.Equal(expectedWritten, Base64Url.EncodeToChars(inputBytes, resultChars)); + Assert.True(Base64Url.TryEncodeToChars(inputBytes, resultChars, out charsWritten)); + Assert.Equal(expectedWritten, charsWritten); + Assert.Equal(result, resultChars); + + Span decodedBytes = new byte[Base64Url.GetMaxDecodedLength(resultChars.Length)]; + operationStatus = Base64Url.DecodeFromChars(resultChars, decodedBytes, out bytesConsumed, out int bytesWritten); + Assert.Equal(OperationStatus.Done, operationStatus); + Assert.Equal(resultChars.Length, bytesConsumed); + Assert.Equal(str.Length, bytesWritten); + Assert.Equal(inputBytes, decodedBytes); + Assert.Equal(str.Length, Base64Url.DecodeFromChars(resultChars, decodedBytes)); + Assert.True(Base64Url.TryDecodeFromChars(resultChars, decodedBytes, out bytesConsumed)); + Assert.Equal(str.Length, bytesConsumed); + Assert.Equal(inputBytes, decodedBytes); + Assert.Equal(str, Encoding.UTF8.GetString(decodedBytes)); + } + + [Fact] + public void EncodingWithLargeSpan() + { + var rnd = new Random(42); + for (int i = 0; i < 5; i++) + { + int numBytes = rnd.Next(100, 1000 * 1000); + Span source = new byte[numBytes]; + Base64TestHelper.InitializeBytes(source, numBytes); + + Span encodedBytes = new char[Base64Url.GetEncodedLength(source.Length)]; + OperationStatus result = Base64Url.EncodeToChars(source, encodedBytes, out int consumed, out int encodedBytesCount); + Assert.Equal(OperationStatus.Done, result); + Assert.Equal(source.Length, consumed); + Assert.Equal(encodedBytes.Length, encodedBytesCount); + string expectedText = Convert.ToBase64String(source).Replace('+', '-').Replace('/', '_').TrimEnd('='); + Assert.Equal(expectedText, encodedBytes); + } + } + + [Fact] + public void DecodeWithLargeSpan() + { + var rnd = new Random(42); + for (int i = 0; i < 5; i++) + { + int numBytes; + do + { + numBytes = rnd.Next(100, 1000 * 1000); + } while (numBytes % 4 == 1); // ensure we have a valid length + + Span source = new char[numBytes]; + Base64TestHelper.InitializeUrlDecodableChars(source, numBytes); + + Span decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)]; + Assert.Equal(OperationStatus.Done, Base64Url.DecodeFromChars(source, decodedBytes, out int consumed, out int decodedByteCount)); + Assert.Equal(source.Length, consumed); + Assert.Equal(decodedBytes.Length, decodedByteCount); + + string sourceString = source.ToString(); + string padded = sourceString.Length % 4 == 0 ? sourceString : + sourceString.PadRight(sourceString.Length + (4 - sourceString.Length % 4), '='); + string base64 = padded.Replace("_", "/").Replace("-", "+"); + byte[] expectedBytes = Convert.FromBase64String(base64); + Assert.True(expectedBytes.AsSpan().SequenceEqual(decodedBytes.Slice(0, decodedByteCount))); + } + } + + [Fact] + public void RoundTripWithLargeSpan() + { + var rnd = new Random(42); + for (int i = 0; i < 5; i++) + { + int numBytes = rnd.Next(100, 1000 * 1000); + Span source = new byte[numBytes]; + Base64TestHelper.InitializeBytes(source, numBytes); + + int expectedLength = Base64Url.GetEncodedLength(source.Length); + char[] encodedBytes = Base64Url.EncodeToChars(source); + Assert.Equal(expectedLength, encodedBytes.Length); + Assert.Equal(new String(encodedBytes), Base64Url.EncodeToString(source)); + + byte[] decoded = Base64Url.DecodeFromChars(encodedBytes); + Assert.Equal(source, decoded); + } + } + + + public static IEnumerable EncodeToStringTests_TestData() + { + yield return new object[] { Enumerable.Range(0, 0).Select(i => (byte)i).ToArray(), "" }; + yield return new object[] { Enumerable.Range(0, 1).Select(i => (byte)i).ToArray(), "AA" }; + yield return new object[] { Enumerable.Range(0, 2).Select(i => (byte)i).ToArray(), "AAE" }; + yield return new object[] { Enumerable.Range(0, 3).Select(i => (byte)i).ToArray(), "AAEC" }; + yield return new object[] { Enumerable.Range(0, 4).Select(i => (byte)i).ToArray(), "AAECAw" }; + yield return new object[] { Enumerable.Range(0, 5).Select(i => (byte)i).ToArray(), "AAECAwQ" }; + yield return new object[] { Enumerable.Range(0, 6).Select(i => (byte)i).ToArray(), "AAECAwQF" }; + yield return new object[] { Enumerable.Range(0, 7).Select(i => (byte)i).ToArray(), "AAECAwQFBg" }; + yield return new object[] { Enumerable.Range(0, 8).Select(i => (byte)i).ToArray(), "AAECAwQFBgc" }; + yield return new object[] { Enumerable.Range(0, 9).Select(i => (byte)i).ToArray(), "AAECAwQFBgcI" }; + yield return new object[] { Enumerable.Range(0, 10).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQ" }; + yield return new object[] { Enumerable.Range(0, 11).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQo" }; + yield return new object[] { Enumerable.Range(0, 12).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoL" }; + yield return new object[] { Enumerable.Range(0, 13).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA" }; + yield return new object[] { Enumerable.Range(0, 14).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0" }; + yield return new object[] { Enumerable.Range(0, 15).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0O" }; + yield return new object[] { Enumerable.Range(0, 16).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODw" }; + yield return new object[] { Enumerable.Range(0, 17).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxA" }; + yield return new object[] { Enumerable.Range(0, 18).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAR" }; + yield return new object[] { Enumerable.Range(0, 19).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREg" }; + yield return new object[] { Enumerable.Range(0, 20).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhM" }; + yield return new object[] { Enumerable.Range(0, 21).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMU" }; + yield return new object[] { Enumerable.Range(0, 22).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFQ" }; + yield return new object[] { Enumerable.Range(0, 23).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRY" }; + yield return new object[] { Enumerable.Range(0, 24).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYX" }; + yield return new object[] { Enumerable.Range(0, 25).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGA" }; + yield return new object[] { Enumerable.Range(0, 26).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBk" }; + yield return new object[] { Enumerable.Range(0, 27).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBka" }; + yield return new object[] { Enumerable.Range(0, 28).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGw" }; + yield return new object[] { Enumerable.Range(0, 29).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxw" }; + yield return new object[] { Enumerable.Range(0, 30).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwd" }; + yield return new object[] { Enumerable.Range(0, 31).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHg" }; + yield return new object[] { Enumerable.Range(0, 32).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8" }; + yield return new object[] { Enumerable.Range(0, 33).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8g" }; + yield return new object[] { Enumerable.Range(0, 34).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gIQ" }; + yield return new object[] { Enumerable.Range(0, 35).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISI" }; + yield return new object[] { Enumerable.Range(0, 36).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIj" }; + yield return new object[] { Enumerable.Range(0, 37).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJA" }; + yield return new object[] { Enumerable.Range(0, 38).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCU" }; + yield return new object[] { Enumerable.Range(0, 39).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUm" }; + yield return new object[] { Enumerable.Range(0, 40).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJw" }; + yield return new object[] { Enumerable.Range(0, 41).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJyg" }; + yield return new object[] { Enumerable.Range(0, 42).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygp" }; + yield return new object[] { Enumerable.Range(0, 43).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKg" }; + yield return new object[] { Enumerable.Range(0, 44).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKis" }; + yield return new object[] { Enumerable.Range(0, 45).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKiss" }; + yield return new object[] { Enumerable.Range(0, 46).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLQ" }; + yield return new object[] { Enumerable.Range(0, 47).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4" }; + yield return new object[] { Enumerable.Range(0, 48).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4v" }; + yield return new object[] { Enumerable.Range(0, 49).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMA" }; + yield return new object[] { Enumerable.Range(0, 50).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDE" }; + yield return new object[] { Enumerable.Range(0, 51).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEy" }; + yield return new object[] { Enumerable.Range(0, 52).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMw" }; + yield return new object[] { Enumerable.Range(0, 53).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ" }; + yield return new object[] { Enumerable.Range(0, 54).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1" }; + yield return new object[] { Enumerable.Range(0, 55).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Ng" }; + yield return new object[] { Enumerable.Range(0, 56).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc" }; + yield return new object[] { Enumerable.Range(0, 57).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc4" }; + yield return new object[] { Enumerable.Range(0, 58).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc4OQ" }; + yield return new object[] { Enumerable.Range(0, 59).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc4OTo" }; + yield return new object[] { Enumerable.Range(0, 60).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc4OTo7" }; + yield return new object[] { Enumerable.Range(0, 61).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc4OTo7PA" }; + yield return new object[] { Enumerable.Range(0, 62).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc4OTo7PD0" }; + yield return new object[] { Enumerable.Range(0, 63).Select(i => (byte)i).ToArray(), "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc4OTo7PD0-" }; + yield return new object[] { Encoding.Unicode.GetBytes("aaaabbbbccccdddddddeeeeeaaaabbbbccccdddddddeeeeeaaaabbbbccccdd"), "YQBhAGEAYQBiAGIAYgBiAGMAYwBjAGMAZABkAGQAZABkAGQAZABlAGUAZQBlAGUAYQBhAGEAYQBiAGIAYgBiAGMAYwBjAGMAZABkAGQAZABkAGQAZABlAGUAZQBlAGUAYQBhAGEAYQBiAGIAYgBiAGMAYwBjAGMAZABkAA" }; + yield return new object[] { Encoding.Unicode.GetBytes("vbnmbbbbccccdddddddeeeeeaaaabbbbccccdddddddeeeeeaaaabbbbccccddx"), "dgBiAG4AbQBiAGIAYgBiAGMAYwBjAGMAZABkAGQAZABkAGQAZABlAGUAZQBlAGUAYQBhAGEAYQBiAGIAYgBiAGMAYwBjAGMAZABkAGQAZABkAGQAZABlAGUAZQBlAGUAYQBhAGEAYQBiAGIAYgBiAGMAYwBjAGMAZABkAHgA" }; + yield return new object[] { Encoding.Unicode.GetBytes("rrrrbbbbccccdddddddeeeeeaaaabbbbccccdddddddeeeeeaaaabbbbccccdd\0"), "cgByAHIAcgBiAGIAYgBiAGMAYwBjAGMAZABkAGQAZABkAGQAZABlAGUAZQBlAGUAYQBhAGEAYQBiAGIAYgBiAGMAYwBjAGMAZABkAGQAZABkAGQAZABlAGUAZQBlAGUAYQBhAGEAYQBiAGIAYgBiAGMAYwBjAGMAZABkAAAA" }; + yield return new object[] { Encoding.Unicode.GetBytes("uuuubbbbccccdddddddeeeeeaaaabbbbccccdddddddeeeeeaaaabbbbccccdd\0feffe"), "dQB1AHUAdQBiAGIAYgBiAGMAYwBjAGMAZABkAGQAZABkAGQAZABlAGUAZQBlAGUAYQBhAGEAYQBiAGIAYgBiAGMAYwBjAGMAZABkAGQAZABkAGQAZABlAGUAZQBlAGUAYQBhAGEAYQBiAGIAYgBiAGMAYwBjAGMAZABkAAAAZgBlAGYAZgBlAA" }; + yield return new object[] { Encoding.Unicode.GetBytes("kkkkkbbbbccccdddddddeeeeeaaaabbbbccccdddddddeeeeeaaaabbbbccccddx\u043F\u0440\u0438\u0432\u0435\u0442\u043C\u0438\u0440\u4F60\u597D\u4E16\u754C"), "awBrAGsAawBrAGIAYgBiAGIAYwBjAGMAYwBkAGQAZABkAGQAZABkAGUAZQBlAGUAZQBhAGEAYQBhAGIAYgBiAGIAYwBjAGMAYwBkAGQAZABkAGQAZABkAGUAZQBlAGUAZQBhAGEAYQBhAGIAYgBiAGIAYwBjAGMAYwBkAGQAeAA_BEAEOAQyBDUEQgQ8BDgEQARgT31ZFk5MdQ" }; + yield return new object[] { Encoding.Unicode.GetBytes(",,,,bbbbccccdddddddeeeeeaaaabbbbccccdddddddeeeeeaaaabbbbccccddx\u043F\u0440\u0438\u0432\u0435\u0442\u043C\u0438\u0440\u4F60\u597D\u4E16\u754Cddddeeeeea"), "LAAsACwALABiAGIAYgBiAGMAYwBjAGMAZABkAGQAZABkAGQAZABlAGUAZQBlAGUAYQBhAGEAYQBiAGIAYgBiAGMAYwBjAGMAZABkAGQAZABkAGQAZABlAGUAZQBlAGUAYQBhAGEAYQBiAGIAYgBiAGMAYwBjAGMAZABkAHgAPwRABDgEMgQ1BEIEPAQ4BEAEYE99WRZOTHVkAGQAZABkAGUAZQBlAGUAZQBhAA" }; + yield return new object[] { Encoding.Unicode.GetBytes("____bbbbccccdddddddeeeeeaaaabbbbccccdddddddeeeeeaaaabbbbccccddaaaabbbbccccdddddddeeeeeaaaabbbbccccdcccd"), "XwBfAF8AXwBiAGIAYgBiAGMAYwBjAGMAZABkAGQAZABkAGQAZABlAGUAZQBlAGUAYQBhAGEAYQBiAGIAYgBiAGMAYwBjAGMAZABkAGQAZABkAGQAZABlAGUAZQBlAGUAYQBhAGEAYQBiAGIAYgBiAGMAYwBjAGMAZABkAGEAYQBhAGEAYgBiAGIAYgBjAGMAYwBjAGQAZABkAGQAZABkAGQAZQBlAGUAZQBlAGEAYQBhAGEAYgBiAGIAYgBjAGMAYwBjAGQAYwBjAGMAZAA" }; + yield return new object[] { Encoding.Unicode.GetBytes(" bbbbccccdddddddeeeeeaaaabbbbccccdddddddeeeeeaaaabbbbccccddaaaabbbbccccdddddddeeeeeaaaabbbbccccdddddddeeeeeaaaabbbbccccd"), "IAAgACAAIABiAGIAYgBiAGMAYwBjAGMAZABkAGQAZABkAGQAZABlAGUAZQBlAGUAYQBhAGEAYQBiAGIAYgBiAGMAYwBjAGMAZABkAGQAZABkAGQAZABlAGUAZQBlAGUAYQBhAGEAYQBiAGIAYgBiAGMAYwBjAGMAZABkAGEAYQBhAGEAYgBiAGIAYgBjAGMAYwBjAGQAZABkAGQAZABkAGQAZQBlAGUAZQBlAGEAYQBhAGEAYgBiAGIAYgBjAGMAYwBjAGQAZABkAGQAZABkAGQAZQBlAGUAZQBlAGEAYQBhAGEAYgBiAGIAYgBjAGMAYwBjAGQA" }; + yield return new object[] { Encoding.Unicode.GetBytes("\0\0bbbbccccdddddddeeeeeaaaabbbbccccdddddddeeeeeaaaabbbbccccddaaaabbbbccccdddddddeeeeeaaaabbbbccccdddddddeeeeeaaaabbbbccccddx"), "AAAAAGIAYgBiAGIAYwBjAGMAYwBkAGQAZABkAGQAZABkAGUAZQBlAGUAZQBhAGEAYQBhAGIAYgBiAGIAYwBjAGMAYwBkAGQAZABkAGQAZABkAGUAZQBlAGUAZQBhAGEAYQBhAGIAYgBiAGIAYwBjAGMAYwBkAGQAYQBhAGEAYQBiAGIAYgBiAGMAYwBjAGMAZABkAGQAZABkAGQAZABlAGUAZQBlAGUAYQBhAGEAYQBiAGIAYgBiAGMAYwBjAGMAZABkAGQAZABkAGQAZABlAGUAZQBlAGUAYQBhAGEAYQBiAGIAYgBiAGMAYwBjAGMAZABkAHgA" }; + yield return new object[] { Encoding.Unicode.GetBytes("eeeebbbbccccdddddddeeeeeaaaabbbbccccdddddddeeeeeaaaabbbbccccdgggdaaaabbbbccccdddddddeeeeeaaaabbbbccccdddddddeeeeeaaaabbbbccccddx"), "ZQBlAGUAZQBiAGIAYgBiAGMAYwBjAGMAZABkAGQAZABkAGQAZABlAGUAZQBlAGUAYQBhAGEAYQBiAGIAYgBiAGMAYwBjAGMAZABkAGQAZABkAGQAZABlAGUAZQBlAGUAYQBhAGEAYQBiAGIAYgBiAGMAYwBjAGMAZABnAGcAZwBkAGEAYQBhAGEAYgBiAGIAYgBjAGMAYwBjAGQAZABkAGQAZABkAGQAZQBlAGUAZQBlAGEAYQBhAGEAYgBiAGIAYgBjAGMAYwBjAGQAZABkAGQAZABkAGQAZQBlAGUAZQBlAGEAYQBhAGEAYgBiAGIAYgBjAGMAYwBjAGQAZAB4AA" }; + } + + [Theory] + [InlineData("\u5948cz_T", 0, 0)] // scalar code-path + [InlineData("z_Ta123\u5948", 4, 3)] + [InlineData("\u5948z_T-H7sqEkerqMweH1uSw==", 0, 0)] // Vector128 code-path + [InlineData("z_T-H7sqEkerqMweH1uSw\u5948==", 20, 15)] + [InlineData("\u5948z_T-H7sqEkerqMweH1uSw1a5ebaAF9xa8B0ze1wet4epo==", 0, 0)] // Vector256 / AVX code-path + [InlineData("z_T-H7sqEkerqMweH1uSw1a5ebaAF9xa8B0ze1wet4epo\u5948==", 44, 33)] + [InlineData("\u5948z_T-H7sqEkerqMweH1uSw1a5ebaAF9xa8B0ze1wet4epo01234567890123456789012345678901234567890123456789==", 0, 0)] // Vector512 / Avx512Vbmi code-path + [InlineData("z_T-H7sqEkerqMweH1uSw1a5ebaAF9xa8B0ze1wet4epo01234567890123456789012345678901234567890123456789\u5948==", 92, 69)] + public void BasicDecodingNonAsciiInputInvalid(string inputString, int expectedConsumed, int expectedWritten) + { + Span source = inputString.ToArray(); + Span decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)]; + + Assert.Equal(OperationStatus.InvalidData, Base64Url.DecodeFromChars(source, decodedBytes, out int consumed, out int decodedByteCount)); + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(expectedWritten, decodedByteCount); + } + + + [Theory] + [MemberData(nameof(EncodeToStringTests_TestData))] + public static void EncodeToStringTests(byte[] inputBytes, string expectedBase64) + { + Assert.Equal(expectedBase64, Base64Url.EncodeToString(inputBytes)); + Span chars = new char[Base64Url.GetEncodedLength(inputBytes.Length)]; + Assert.Equal(OperationStatus.Done, Base64Url.EncodeToChars(inputBytes, chars, out int _, out int charsWritten)); + Assert.Equal(expectedBase64, chars.Slice(0, charsWritten)); + } + + [Fact] + public void EncodingOutputTooSmall() + { + for (int numBytes = 4; numBytes < 20; numBytes++) + { + byte[] source = new byte[numBytes]; + Base64TestHelper.InitializeBytes(source, numBytes); + int expectedConsumed = 3; + char[] encodedBytes = new char[4]; + + Assert.Equal(OperationStatus.DestinationTooSmall, Base64Url.EncodeToChars(source, encodedBytes, out int consumed, out int written)); + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(encodedBytes.Length, written); + Assert.True(source.AsSpan().Slice(0, consumed).SequenceEqual(Base64Url.DecodeFromChars(encodedBytes))); + + Assert.Throws("destination", () => Base64Url.EncodeToChars(source, encodedBytes)); + } + } + + [Fact] + public static void Roundtrip() + { + string input = "test"; + Verify(input, result => + { + Assert.Equal(3, result.Length); + + uint triplet = (uint)((result[0] << 16) | (result[1] << 8) | result[2]); + Assert.Equal(45, triplet >> 18); // 't' + Assert.Equal(30, (triplet << 14) >> 26); // 'e' + Assert.Equal(44, (triplet << 20) >> 26); // 's' + Assert.Equal(45, (triplet << 26) >> 26); // 't' + + Assert.Equal(input, Base64Url.EncodeToString(result)); + }); + } + + [Fact] + public static void PartialRoundtripWithoutPadding() + { + string input = "ab"; + Verify(input, result => + { + Assert.Equal(1, result.Length); + + string roundtrippedString = Base64Url.EncodeToString(result); + Assert.NotEqual(input, roundtrippedString); + Assert.Equal(input[0], roundtrippedString[0]); + }); + } + + [Fact] + public static void PartialRoundtripWithPadding2() + { + string input = "ab=="; + Verify(input, result => + { + Assert.Equal(1, result.Length); + + string roundtrippedString = Base64Url.EncodeToString(result); + Assert.NotEqual(input, roundtrippedString); + Assert.Equal(input[0], roundtrippedString[0]); + }); + } + + [Fact] + public static void PartialRoundtripWithPadding1() + { + string input = "789="; + Verify(input, result => + { + Assert.Equal(2, result.Length); + + string roundtrippedString = Base64Url.EncodeToString(result); + Assert.NotEqual(input, roundtrippedString); + Assert.Equal(input[0], roundtrippedString[0]); + Assert.Equal(input[1], roundtrippedString[1]); + }); + } + + [Fact] + public static void ParseWithWhitespace() + { + Verify("abc= \t \r\n ="); + } + + [Fact] + public static void RoundtripWithWhitespace2() + { + string input = "abc= \t\n\t\r "; + VerifyRoundtrip(input, "abc"); + } + + [Fact] + public static void RoundtripWithWhitespace3() + { + string input = " \r\n\t abc = \t\n\t\r "; + VerifyRoundtrip(input, "abc"); + } + + [Fact] + public static void RoundtripWithWhitespace4() + { + string expected = "test"; + string input = expected.Insert(1, new string(' ', 17)).PadLeft(31, ' ').PadRight(12, ' '); + VerifyRoundtrip(input, expected, expectedLengthBytes: 3); + } + + [Fact] + public static void RoundtripLargeString() + { + string input = new string('a', 10000); + VerifyRoundtrip(input, input); + } + + [Fact] + public static void InvalidInput() + { + // Input must not contain invalid characters + VerifyInvalidInput("2+34"); + VerifyInvalidInput("23/4"); + + // Input must not contain 3 or more padding characters in a row + VerifyInvalidInput("a==="); + VerifyInvalidInput("abc====="); + VerifyInvalidInput("a===\r \t \n"); + + // Input must not contain padding characters in the middle of the string + VerifyInvalidInput("No=n"); + VerifyInvalidInput("abcd====abcd"); + + // Input must not contain extra trailing padding characters + VerifyInvalidInput("="); + VerifyInvalidInput("abc==="); + } + + [Fact] + public static void ExtraPaddingCharacter() + { + VerifyInvalidInput("abcdxyz=" + "="); + } + + [Fact] + public static void InvalidCharactersInInput() + { + ushort[] invalidChars = { 30122, 62608, 13917, 19498, 2473, 40845, 35988, 2281, 51246, 36372 }; + + foreach (char ch in invalidChars) + { + var builder = new StringBuilder("abc"); + builder.Insert(1, ch); + VerifyInvalidInput(builder.ToString()); + } + } + + private static void VerifyRoundtrip(string input, string expected = null, int? expectedLengthBytes = null) + { + if (expected == null) + { + expected = input; + } + + Verify(input, result => + { + if (expectedLengthBytes.HasValue) + { + Assert.Equal(expectedLengthBytes.Value, result.Length); + } + Assert.Equal(expected, Base64Url.EncodeToString(result)); + }); + } + + private static void VerifyInvalidInput(string input) + { + char[] inputChars = input.ToCharArray(); + + Assert.Throws(() => Base64Url.DecodeFromChars(input)); + } + + private static void Verify(string input, Action action = null) + { + if (action != null) + { + action(Base64Url.DecodeFromChars(input)); + } + } + + [Fact] + public static void Base64_AllMethodsRoundtripConsistently() + { + var r = new Random(42); + for (int length = 0; length < 128; length++) + { + var original = new byte[length]; + r.NextBytes(original); + + string encodedString = Base64Url.EncodeToString(original); + + char[] encodedArray = new char[encodedString.Length]; + Assert.Equal(OperationStatus.Done, Base64Url.EncodeToChars(original, encodedArray, out _, out int charsWritten)); + Assert.Equal(encodedArray.Length, charsWritten); + AssertExtensions.SequenceEqual(encodedString, encodedArray); + + char[] encodedSpan = new char[encodedString.Length]; + Assert.True(Base64Url.TryEncodeToChars(original, encodedSpan, out charsWritten)); + Assert.Equal(encodedSpan.Length, charsWritten); + AssertExtensions.SequenceEqual(encodedString, encodedSpan); + + AssertExtensions.SequenceEqual(original, Base64Url.DecodeFromChars(encodedString)); + Span decodedBytes = new byte[original.Length]; + int decoded = Base64Url.DecodeFromChars(encodedArray, decodedBytes); + Assert.Equal(original.Length, decoded); + AssertExtensions.SequenceEqual(original, decodedBytes); + + byte[] actualBytes = new byte[original.Length]; + Assert.True(Base64Url.TryDecodeFromChars(encodedSpan, actualBytes, out int bytesWritten)); + Assert.Equal(original.Length, bytesWritten); + AssertExtensions.SequenceEqual(original, actualBytes); + } + } + + [Theory] + [MemberData(nameof(Base64TestData))] + public static void TryDecodeFromChars(string encodedAsString, byte[] expected) + { + char[] encoded = encodedAsString.ToCharArray(); + if (expected == null) + { + byte[] actual = new byte[Base64Url.GetMaxDecodedLength(encodedAsString.Length)]; + Assert.Throws(() => Base64Url.TryDecodeFromChars(encoded, actual, out _)); + } + else + { + // Destination buffer size enough + { + Span actual = new byte[Base64Url.GetMaxDecodedLength(encodedAsString.Length)]; + Assert.True(Base64Url.TryDecodeFromChars(encoded, actual, out int bytesWritten)); + Assert.Equal(expected, actual.Slice(0, bytesWritten)); + Assert.Equal(expected.Length, bytesWritten); + } + + // Buffer too short + if (expected.Length != 0) + { + byte[] actual = new byte[expected.Length - 1]; + Assert.False(Base64Url.TryDecodeFromChars(encoded, actual, out int bytesWritten)); + Assert.Equal(0, bytesWritten); + } + } + } + + public static IEnumerable Base64TestData + { + get + { + foreach ((string bse64UrlString, byte[] expectedArray) tuple in Base64TestDataSeed) + { + yield return new object[] { tuple.bse64UrlString, tuple.expectedArray }; + yield return new object[] { InsertSpaces(tuple.bse64UrlString, 1), tuple.expectedArray }; + yield return new object[] { InsertSpaces(tuple.bse64UrlString, 4), tuple.expectedArray }; + } + } + } + + public static IEnumerable<(string, byte[])> Base64TestDataSeed + { + get + { + // Empty + yield return ("", Array.Empty()); + + // All whitespace characters. + yield return (" \t\r\n", Array.Empty()); + + // Invalid Input length + yield return ("A", null); + + // Cannot continue past end pad + yield return ("AAA=BBBB", null); + yield return ("AA==BBBB", null); + + // Cannot have more than two end pads + yield return ("A===", null); + yield return ("====", null); + + // Verify negative entries of charmap. + for (int i = 0; i < 256; i++) + { + char c = (char)i; + if (!IsValidBase64Char(c)) + { + string text = new string(c, 1) + "AAA"; + yield return (text, null); + } + } + + // Verify >255 character handling. + string largerThanByte = new string((char)256, 1); + yield return (largerThanByte + "AAA", null); + yield return ("A" + largerThanByte + "AA", null); + yield return ("AA" + largerThanByte + "A", null); + yield return ("AAA" + largerThanByte, null); + yield return ("AAAA" + largerThanByte + "AAA", null); + yield return ("AAAA" + "A" + largerThanByte + "AA", null); + yield return ("AAAA" + "AA" + largerThanByte + "A", null); + yield return ("AAAA" + "AAA" + largerThanByte, null); + + // Verify positive entries of charmap. + yield return ("-A==", new byte[] { 0xf8 }); + yield return ("_A=", new byte[] { 0xfc }); + yield return ("0A==", new byte[] { 0xd0 }); + yield return ("1A==", new byte[] { 0xd4 }); + yield return ("2A==", new byte[] { 0xd8 }); + yield return ("3A==", new byte[] { 0xdc }); + yield return ("4A", new byte[] { 0xe0 }); + yield return ("5A=", new byte[] { 0xe4 }); + yield return ("6A==", new byte[] { 0xe8 }); + yield return ("7A==", new byte[] { 0xec }); + yield return ("8A", new byte[] { 0xf0 }); + yield return ("9A=", new byte[] { 0xf4 }); + yield return ("AA=", new byte[] { 0x00 }); + yield return ("BA", new byte[] { 0x04 }); + yield return ("CA", new byte[] { 0x08 }); + yield return ("DA==", new byte[] { 0x0c }); + yield return ("EA==", new byte[] { 0x10 }); + yield return ("FA==", new byte[] { 0x14 }); + yield return ("GA==", new byte[] { 0x18 }); + yield return ("HA==", new byte[] { 0x1c }); + yield return ("IA==", new byte[] { 0x20 }); + yield return ("JA==", new byte[] { 0x24 }); + yield return ("KA==", new byte[] { 0x28 }); + yield return ("LA==", new byte[] { 0x2c }); + yield return ("MA==", new byte[] { 0x30 }); + yield return ("NA==", new byte[] { 0x34 }); + yield return ("OA==", new byte[] { 0x38 }); + yield return ("PA==", new byte[] { 0x3c }); + yield return ("QA==", new byte[] { 0x40 }); + yield return ("RA==", new byte[] { 0x44 }); + yield return ("SA==", new byte[] { 0x48 }); + yield return ("TA==", new byte[] { 0x4c }); + yield return ("UA==", new byte[] { 0x50 }); + yield return ("VA==", new byte[] { 0x54 }); + yield return ("WA==", new byte[] { 0x58 }); + yield return ("XA==", new byte[] { 0x5c }); + yield return ("YA==", new byte[] { 0x60 }); + yield return ("ZA==", new byte[] { 0x64 }); + yield return ("aA==", new byte[] { 0x68 }); + yield return ("bA==", new byte[] { 0x6c }); + yield return ("cA==", new byte[] { 0x70 }); + yield return ("dA==", new byte[] { 0x74 }); + yield return ("eA==", new byte[] { 0x78 }); + yield return ("fA==", new byte[] { 0x7c }); + yield return ("gA==", new byte[] { 0x80 }); + yield return ("hA==", new byte[] { 0x84 }); + yield return ("iA==", new byte[] { 0x88 }); + yield return ("jA==", new byte[] { 0x8c }); + yield return ("kA==", new byte[] { 0x90 }); + yield return ("lA==", new byte[] { 0x94 }); + yield return ("mA==", new byte[] { 0x98 }); + yield return ("nA==", new byte[] { 0x9c }); + yield return ("oA==", new byte[] { 0xa0 }); + yield return ("pA==", new byte[] { 0xa4 }); + yield return ("qA==", new byte[] { 0xa8 }); + yield return ("rA==", new byte[] { 0xac }); + yield return ("sA==", new byte[] { 0xb0 }); + yield return ("tA==", new byte[] { 0xb4 }); + yield return ("uA==", new byte[] { 0xb8 }); + yield return ("vA==", new byte[] { 0xbc }); + yield return ("wA==", new byte[] { 0xc0 }); + yield return ("xA==", new byte[] { 0xc4 }); + yield return ("yA==", new byte[] { 0xc8 }); + yield return ("zA==", new byte[] { 0xcc }); + } + } + + private static string InsertSpaces(string text, int period) + { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < text.Length; i++) + { + if ((i % period) == 0) + { + sb.Append(" "); + } + sb.Append(text[i]); + } + sb.Append(" "); + return sb.ToString(); + } + + private static bool IsValidBase64Char(char c) + { + return char.IsAsciiLetterOrDigit(c) || c is '-' or '_' || char.IsWhiteSpace(c); + } + } +} diff --git a/src/libraries/System.Memory/tests/Base64Url/Base64UrlValidationUnitTests.cs b/src/libraries/System.Memory/tests/Base64Url/Base64UrlValidationUnitTests.cs new file mode 100644 index 0000000000000..b2cb650fa6e3f --- /dev/null +++ b/src/libraries/System.Memory/tests/Base64Url/Base64UrlValidationUnitTests.cs @@ -0,0 +1,351 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Linq; +using System.Text; +using Xunit; + +namespace System.Buffers.Text.Tests +{ + public class Base64UrlValidationUnitTests : Base64TestBase + { + [Theory] + [InlineData("==")] + [InlineData("-%")] + [InlineData("A=")] + [InlineData("A==")] + [InlineData("4%%")] + [InlineData(" A==")] + [InlineData("AAAAA ==")] + [InlineData("\tLLLL\t=\r")] + [InlineData("6066=")] + public void BasicValidationEdgeCaseScenario(string base64UrlText) + { + Assert.False(Base64Url.IsValid(base64UrlText, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + + [Fact] + public void BasicValidationBytes() + { + var rnd = new Random(42); + for (int i = 0; i < 10; i++) + { + int numBytes; + do + { + numBytes = rnd.Next(100, 1000 * 1000); + } while (numBytes % 4 == 1); // ensure we have a valid length + + Span source = new byte[numBytes]; + Base64TestHelper.InitializeUrlDecodableBytes(source, numBytes); + + Assert.True(Base64Url.IsValid(source)); + Assert.True(Base64Url.IsValid(source, out int decodedLength)); + Assert.True(decodedLength > 0); + } + } + + [Fact] + public void BasicValidationChars() + { + var rnd = new Random(42); + for (int i = 0; i < 10; i++) + { + int numBytes; + do + { + numBytes = rnd.Next(100, 1000 * 1000); + } while (numBytes % 4 == 1); // ensure we have a valid length + + Span source = new byte[numBytes]; + Base64TestHelper.InitializeUrlDecodableBytes(source, numBytes); + Span chars = source + .ToArray() + .Select(Convert.ToChar) + .ToArray() + .AsSpan(); + + Assert.True(Base64Url.IsValid(chars)); + Assert.True(Base64Url.IsValid(chars, out int decodedLength)); + Assert.True(decodedLength > 0); + } + } + + [Fact] + public void BasicValidationInvalidInputLengthBytes() + { + var rnd = new Random(42); + for (int i = 0; i < 10; i++) + { + int numBytes; + do + { + numBytes = rnd.Next(100, 1000 * 1000); + } while (numBytes % 4 != 1); // only remainder of 1 is invalid length + + Span source = new byte[numBytes]; + Base64TestHelper.InitializeUrlDecodableBytes(source, numBytes); + Assert.False(Base64Url.IsValid(source)); + Assert.False(Base64Url.IsValid(source, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + } + + [Fact] + public void BasicValidationInvalidInputLengthChars() + { + var rnd = new Random(42); + for (int i = 0; i < 10; i++) + { + int numBytes; + do + { + numBytes = rnd.Next(100, 1000 * 1000); + } while (numBytes % 4 != 1); // ensure we have a invalid length + + Span source = new char[numBytes]; + Base64TestHelper.InitializeUrlDecodableChars(source, numBytes); + + Assert.False(Base64Url.IsValid(source)); + Assert.False(Base64Url.IsValid(source, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + } + + [Fact] + public void ValidateEmptySpanBytes() + { + Span source = Span.Empty; + + Assert.True(Base64Url.IsValid(source)); + Assert.True(Base64Url.IsValid(source, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + + [Fact] + public void ValidateEmptySpanChars() + { + Span source = Span.Empty; + + Assert.True(Base64Url.IsValid(source)); + Assert.True(Base64Url.IsValid(source, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + + [Fact] + public void ValidateGuidBytes() + { + Span source = new byte[22]; + Span decodedBytes = Guid.NewGuid().ToByteArray(); + Base64Url.EncodeToUtf8(decodedBytes, source, out int _, out int _); + + Assert.True(Base64Url.IsValid(source)); + Assert.True(Base64Url.IsValid(source, out int decodedLength)); + Assert.True(decodedLength > 0); + } + + [Fact] + public void ValidateGuidChars() + { + Span source = new byte[22]; + Span decodedBytes = Guid.NewGuid().ToByteArray(); + Base64Url.EncodeToUtf8(decodedBytes, source, out int _, out int _); + Span chars = source + .ToArray() + .Select(Convert.ToChar) + .ToArray() + .AsSpan(); + + Assert.True(Base64Url.IsValid(chars)); + Assert.True(Base64Url.IsValid(chars, out int decodedLength)); + Assert.True(decodedLength > 0); + } + + [Theory] + [MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))] + public void ValidateBytesIgnoresCharsToBeIgnoredBytes(string utf8WithByteToBeIgnored, byte[] expectedBytes) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); + + Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(expectedBytes.Length, decodedLength); + } + + [Theory] + [MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))] + public void ValidateBytesIgnoresCharsToBeIgnoredChars(string utf8WithByteToBeIgnored, byte[] expectedBytes) + { + ReadOnlySpan utf8BytesWithByteToBeIgnored = utf8WithByteToBeIgnored.ToArray(); + + Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(expectedBytes.Length, decodedLength); + } + + [Theory] + [MemberData(nameof(StringsOnlyWithCharsToBeIgnored))] + public void ValidateWithOnlyCharsToBeIgnoredBytes(string utf8WithByteToBeIgnored) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); + + Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + + [Theory] + [MemberData(nameof(StringsOnlyWithCharsToBeIgnored))] + public void ValidateWithOnlyCharsToBeIgnoredChars(string utf8WithByteToBeIgnored) + { + ReadOnlySpan utf8BytesWithByteToBeIgnored = utf8WithByteToBeIgnored.ToArray(); + + Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + + [Theory] + [InlineData("YQ==", 1)] + [InlineData("YWI=", 2)] + [InlineData("YWJj", 3)] + [InlineData(" YWI=", 2)] + [InlineData("Y WI=", 2)] + [InlineData("YW I=", 2)] + [InlineData("YWI =", 2)] + [InlineData("YWI= ", 2)] + [InlineData(" YQ==", 1)] + [InlineData("Y Q==", 1)] + [InlineData("YQ ==", 1)] + [InlineData("YQ= =", 1)] + [InlineData("YQ== ", 1)] + [InlineData("YQ%%", 1)] + [InlineData("YWI%", 2)] + [InlineData("YW% ", 1)] + public void ValidateWithPaddingReturnsCorrectCountBytes(string utf8WithByteToBeIgnored, int expectedLength) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); + + Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(expectedLength, decodedLength); + } + + [Theory] + [InlineData("YQ==", 1)] + [InlineData("YWI=", 2)] + [InlineData("YWJj", 3)] + [InlineData(" YWI=", 2)] + [InlineData("Y WI=", 2)] + [InlineData("YW I=", 2)] + [InlineData("YWI =", 2)] + [InlineData("YWI= ", 2)] + [InlineData(" YQ==", 1)] + [InlineData("Y Q==", 1)] + [InlineData("YQ ==", 1)] + [InlineData("YQ= =", 1)] + [InlineData("YQ== ", 1)] + [InlineData("YQ%%", 1)] + [InlineData("YWI%", 2)] + [InlineData("YW% ", 1)] + public void ValidateWithPaddingReturnsCorrectCountChars(string utf8WithByteToBeIgnored, int expectedLength) + { + ReadOnlySpan utf8BytesWithByteToBeIgnored = utf8WithByteToBeIgnored.ToArray(); + + Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64Url.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(expectedLength, decodedLength); + } + + [Theory] + [InlineData("YWJ", true, 2)] + [InlineData("YW", true, 1)] + [InlineData("Y", false, 0)] + public void SmallSizeBytes(string utf8Text, bool isValid, int expectedDecodedLength) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8Text); + + Assert.Equal(isValid, Base64Url.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.Equal(isValid, Base64Url.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(expectedDecodedLength, decodedLength); + } + + [Theory] + [InlineData("YWJ", true, 2)] + [InlineData("YW", true, 1)] + [InlineData("Y", false, 0)] + public void SmallSizeChars(string utf8Text, bool isValid, int expectedDecodedLength) + { + ReadOnlySpan utf8BytesWithByteToBeIgnored = utf8Text; + + Assert.Equal(isValid, Base64Url.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.Equal(isValid, Base64Url.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(expectedDecodedLength, decodedLength); + } + + [Theory] + [InlineData("YQ===")] + [InlineData("YQ=a=")] + [InlineData("YWI=a")] + [InlineData(" aYWI=a")] + [InlineData("a YWI=a")] + [InlineData("aY WI=a")] + [InlineData("aYW I=a")] + [InlineData("aYWI =a")] + [InlineData("aYWI= a")] + [InlineData("a YQ==a")] + [InlineData("aY Q==a")] + [InlineData("aYQ ==a")] + [InlineData("aYQ= =a")] + [InlineData("aYQ== a")] + [InlineData("aYQ==a ")] + [InlineData("YQ+a")] // plus invalid + [InlineData("/Qab")] // slash invalid + public void InvalidBase64UrlBytes(string utf8WithByteToBeIgnored) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); + + Assert.False(Base64Url.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.False(Base64Url.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + + [Theory] + [InlineData("YQ===")] + [InlineData("YQ=a=")] + [InlineData("YWI=a")] + [InlineData("a YWI=a")] + [InlineData("aY WI=a")] + [InlineData("aYW I=a")] + [InlineData("aYWI =a")] + [InlineData("aYWI= a")] + [InlineData("a YQ==a")] + [InlineData("aY Q==a")] + [InlineData("aYQ ==a")] + [InlineData("aYQ= =a")] + [InlineData("aYQ== a")] + [InlineData("aYQ==a ")] + [InlineData("a")] + [InlineData(" a")] + [InlineData(" a")] + [InlineData(" a")] + [InlineData(" a")] + [InlineData("a ")] + [InlineData("a ")] + [InlineData("a ")] + [InlineData("a ")] + [InlineData(" a ")] + [InlineData(" a ")] + [InlineData(" a ")] + [InlineData(" a ")] + public void InvalidBase64UrlChars(string utf8WithByteToBeIgnored) + { + ReadOnlySpan utf8CharsWithCharToBeIgnored = utf8WithByteToBeIgnored; + + Assert.False(Base64Url.IsValid(utf8CharsWithCharToBeIgnored)); + Assert.False(Base64Url.IsValid(utf8CharsWithCharToBeIgnored, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + } +} diff --git a/src/libraries/System.Memory/tests/System.Memory.Tests.csproj b/src/libraries/System.Memory/tests/System.Memory.Tests.csproj index 843d5e1b479ce..0dfd33b13837d 100644 --- a/src/libraries/System.Memory/tests/System.Memory.Tests.csproj +++ b/src/libraries/System.Memory/tests/System.Memory.Tests.csproj @@ -13,6 +13,10 @@ + + + + diff --git a/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems b/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems index ea91a44b9fcec..654ecbaf438fc 100644 --- a/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems +++ b/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems @@ -128,6 +128,9 @@ + + + diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64.cs index 4b1f597e5e8b5..c362220d30804 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64.cs @@ -3,13 +3,14 @@ using System.Diagnostics; using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; namespace System.Buffers.Text { public static partial class Base64 { [Conditional("DEBUG")] - private static unsafe void AssertRead(byte* src, byte* srcStart, int srcLength) + internal static unsafe void AssertRead(byte* src, byte* srcStart, int srcLength) { int vectorElements = Unsafe.SizeOf(); byte* readEnd = src + vectorElements; @@ -23,7 +24,7 @@ private static unsafe void AssertRead(byte* src, byte* srcStart, int sr } [Conditional("DEBUG")] - private static unsafe void AssertWrite(byte* dest, byte* destStart, int destLength) + internal static unsafe void AssertWrite(byte* dest, byte* destStart, int destLength) { int vectorElements = Unsafe.SizeOf(); byte* writeEnd = dest + vectorElements; @@ -35,5 +36,105 @@ private static unsafe void AssertWrite(byte* dest, byte* destStart, int Debug.Fail($"Write for {typeof(TVector)} is not within safe bounds. destIndex: {destIndex}, destLength: {destLength}"); } } + + [Conditional("DEBUG")] + internal static unsafe void AssertRead(ushort* src, ushort* srcStart, int srcLength) + { + int vectorElements = Unsafe.SizeOf(); + ushort* readEnd = src + vectorElements; + ushort* srcEnd = srcStart + srcLength; + + if (readEnd > srcEnd) + { + int srcIndex = (int)(src - srcStart); + Debug.Fail($"Read for {typeof(TVector)} is not within safe bounds. srcIndex: {srcIndex}, srcLength: {srcLength}"); + } + } + + [Conditional("DEBUG")] + internal static unsafe void AssertWrite(ushort* dest, ushort* destStart, int destLength) + { + int vectorElements = Unsafe.SizeOf(); + ushort* writeEnd = dest + vectorElements; + ushort* destEnd = destStart + destLength; + + if (writeEnd > destEnd) + { + int destIndex = (int)(dest - destStart); + Debug.Fail($"Write for {typeof(TVector)} is not within safe bounds. destIndex: {destIndex}, destLength: {destLength}"); + } + } + + internal interface IBase64Encoder where T : unmanaged + { + static abstract ReadOnlySpan EncodingMap { get; } + static abstract sbyte Avx2LutChar62 { get; } + static abstract sbyte Avx2LutChar63 { get; } + static abstract ReadOnlySpan AdvSimdLut4 { get; } + static abstract uint Ssse3AdvSimdLutE3 { get; } + static abstract int GetMaxSrcLength(int srcLength, int destLength); + static abstract int GetMaxEncodedLength(int srcLength); + static abstract uint GetInPlaceDestinationLength(int encodedLength, int leftOver); + static abstract unsafe void EncodeOneOptionallyPadTwo(byte* oneByte, T* dest, ref byte encodingMap); + static abstract unsafe void EncodeTwoOptionallyPadOne(byte* oneByte, T* dest, ref byte encodingMap); + static abstract unsafe void EncodeThreeAndWrite(byte* threeBytes, T* destination, ref byte encodingMap); + static abstract int IncrementPadTwo { get; } + static abstract int IncrementPadOne { get; } + static abstract unsafe void StoreVector512ToDestination(T* dest, T* destStart, int destLength, Vector512 str); + static abstract unsafe void StoreVector256ToDestination(T* dest, T* destStart, int destLength, Vector256 str); + static abstract unsafe void StoreVector128ToDestination(T* dest, T* destStart, int destLength, Vector128 str); + static abstract unsafe void StoreArmVector128x4ToDestination(T* dest, T* destStart, int destLength, Vector128 res1, + Vector128 res2, Vector128 res3, Vector128 res4); + } + + internal interface IBase64Decoder where T : unmanaged + { + static abstract ReadOnlySpan DecodingMap { get; } + static abstract ReadOnlySpan VbmiLookup0 { get; } + static abstract ReadOnlySpan VbmiLookup1 { get; } + static abstract ReadOnlySpan Avx2LutHigh { get; } + static abstract ReadOnlySpan Avx2LutLow { get; } + static abstract ReadOnlySpan Avx2LutShift { get; } + static abstract byte MaskSlashOrUnderscore { get; } + static abstract ReadOnlySpan Vector128LutHigh { get; } + static abstract ReadOnlySpan Vector128LutLow { get; } + static abstract ReadOnlySpan Vector128LutShift { get; } + static abstract ReadOnlySpan AdvSimdLutOne3 { get; } + static abstract uint AdvSimdLutTwo3Uint1 { get; } + static abstract int SrcLength(bool isFinalBlock, int sourceLength); + static abstract int GetMaxDecodedLength(int sourceLength); + static abstract bool IsInvalidLength(int bufferLength); + static abstract bool IsValidPadding(uint padChar); + static abstract bool TryDecode128Core( + Vector128 str, + Vector128 hiNibbles, + Vector128 maskSlashOrUnderscore, + Vector128 mask8F, + Vector128 lutLow, + Vector128 lutHigh, + Vector128 lutShift, + Vector128 shiftForUnderscore, + out Vector128 result); + static abstract bool TryDecode256Core( + Vector256 str, + Vector256 hiNibbles, + Vector256 maskSlashOrUnderscore, + Vector256 lutLow, + Vector256 lutHigh, + Vector256 lutShift, + Vector256 shiftForUnderscore, + out Vector256 result); + static abstract unsafe int DecodeFourElements(T* source, ref sbyte decodingMap); + static abstract unsafe int DecodeRemaining(T* srcEnd, ref sbyte decodingMap, long remaining, out uint t2, out uint t3); + static abstract int IndexOfAnyExceptWhiteSpace(ReadOnlySpan span); + static abstract OperationStatus DecodeWithWhiteSpaceBlockwiseWrapper(ReadOnlySpan source, + Span bytes, ref int bytesConsumed, ref int bytesWritten, bool isFinalBlock = true) + where TTBase64Decoder : IBase64Decoder; + static abstract unsafe bool TryLoadVector512(T* src, T* srcStart, int sourceLength, out Vector512 str); + static abstract unsafe bool TryLoadAvxVector256(T* src, T* srcStart, int sourceLength, out Vector256 str); + static abstract unsafe bool TryLoadVector128(T* src, T* srcStart, int sourceLength, out Vector128 str); + static abstract unsafe bool TryLoadArmVector128x4(T* src, T* srcStart, int sourceLength, + out Vector128 str1, out Vector128 str2, out Vector128 str3, out Vector128 str4); + } } } diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs index 777736e4a4deb..9d3f35b3a0038 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs @@ -35,24 +35,27 @@ public static partial class Base64 /// or if the input is incomplete (i.e. not a multiple of 4) and is . /// public static OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Span bytes, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) => - DecodeFromUtf8(utf8, bytes, out bytesConsumed, out bytesWritten, isFinalBlock, ignoreWhiteSpace: true); + DecodeFrom(utf8, bytes, out bytesConsumed, out bytesWritten, isFinalBlock, ignoreWhiteSpace: true); - private static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Span bytes, out int bytesConsumed, out int bytesWritten, bool isFinalBlock, bool ignoreWhiteSpace) + internal static unsafe OperationStatus DecodeFrom(ReadOnlySpan source, Span bytes, + out int bytesConsumed, out int bytesWritten, bool isFinalBlock, bool ignoreWhiteSpace) + where TBase64Decoder : IBase64Decoder + where T : unmanaged { - if (utf8.IsEmpty) + if (source.IsEmpty) { bytesConsumed = 0; bytesWritten = 0; return OperationStatus.Done; } - fixed (byte* srcBytes = &MemoryMarshal.GetReference(utf8)) + fixed (T* srcBytes = &MemoryMarshal.GetReference(source)) fixed (byte* destBytes = &MemoryMarshal.GetReference(bytes)) { - int srcLength = utf8.Length & ~0x3; // only decode input up to the closest multiple of 4. + int srcLength = TBase64Decoder.SrcLength(isFinalBlock, source.Length); int destLength = bytes.Length; int maxSrcLength = srcLength; - int decodedLength = GetMaxDecodedFromUtf8Length(srcLength); + int decodedLength = TBase64Decoder.GetMaxDecodedLength(srcLength); // max. 2 padding chars if (destLength < decodedLength - 2) @@ -61,17 +64,17 @@ private static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Sp maxSrcLength = destLength / 3 * 4; } - byte* src = srcBytes; + T* src = srcBytes; byte* dest = destBytes; - byte* srcEnd = srcBytes + (uint)srcLength; - byte* srcMax = srcBytes + (uint)maxSrcLength; + T* srcEnd = srcBytes + (uint)srcLength; + T* srcMax = srcBytes + (uint)maxSrcLength; if (maxSrcLength >= 24) { - byte* end = srcMax - 88; + T* end = srcMax - 88; if (Vector512.IsHardwareAccelerated && Avx512Vbmi.IsSupported && (end >= src)) { - Avx512Decode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); + Avx512Decode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); if (src == srcEnd) { @@ -82,7 +85,7 @@ private static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Sp end = srcMax - 45; if (Avx2.IsSupported && (end >= src)) { - Avx2Decode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); + Avx2Decode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); if (src == srcEnd) { @@ -93,7 +96,7 @@ private static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Sp end = srcMax - 66; if (AdvSimd.Arm64.IsSupported && (end >= src)) { - AdvSimdDecode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); + AdvSimdDecode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); if (src == srcEnd) { @@ -104,7 +107,7 @@ private static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Sp end = srcMax - 24; if ((Ssse3.IsSupported || AdvSimd.Arm64.IsSupported) && BitConverter.IsLittleEndian && (end >= src)) { - Vector128Decode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); + Vector128Decode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); if (src == srcEnd) { @@ -126,15 +129,20 @@ private static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Sp // This should never overflow since destLength here is less than int.MaxValue / 4 * 3 (i.e. 1610612733) // Therefore, (destLength / 3) * 4 will always be less than 2147483641 Debug.Assert(destLength < (int.MaxValue / 4 * 3)); - maxSrcLength = (destLength / 3) * 4; + (maxSrcLength, int remainder) = int.DivRem(destLength, 3); + maxSrcLength *= 4; + if (isFinalBlock && remainder > 0) + { + srcLength &= ~0x3; // In case of Base64UrlDecoder source can be not a multiple of 4, round down to multiple of 4 + } } - ref sbyte decodingMap = ref MemoryMarshal.GetReference(DecodingMap); + ref sbyte decodingMap = ref MemoryMarshal.GetReference(TBase64Decoder.DecodingMap); srcMax = srcBytes + maxSrcLength; while (src < srcMax) { - int result = Decode(src, ref decodingMap); + int result = TBase64Decoder.DecodeFourElements(src, ref decodingMap); if (result < 0) { @@ -151,8 +159,6 @@ private static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Sp goto DestinationTooSmallExit; } - // If input is less than 4 bytes, srcLength == sourceIndex == 0 - // If input is not a multiple of 4, sourceIndex == srcLength != 0 if (src == srcEnd) { if (isFinalBlock) @@ -160,7 +166,7 @@ private static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Sp goto InvalidDataExit; } - if (src == srcBytes + utf8.Length) + if (src == srcBytes + source.Length) { goto DoneExit; } @@ -169,24 +175,15 @@ private static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Sp } // if isFinalBlock is false, we will never reach this point - - // Handle last four bytes. There are 0, 1, 2 padding chars. - uint t0 = srcEnd[-4]; - uint t1 = srcEnd[-3]; - uint t2 = srcEnd[-2]; - uint t3 = srcEnd[-1]; - - int i0 = Unsafe.Add(ref decodingMap, (IntPtr)t0); - int i1 = Unsafe.Add(ref decodingMap, (IntPtr)t1); - - i0 <<= 18; - i1 <<= 12; - - i0 |= i1; + // Handle remaining bytes, for Base64 its always 4 bytes, for Base64Url up to 8 bytes left. + // If more than 4 bytes remained it will end up in DestinationTooSmallExit or InvalidDataExit (might succeed after whitespace removed) + long remaining = srcEnd - src; + Debug.Assert(typeof(TBase64Decoder) == typeof(Base64DecoderByte) ? remaining == 4 : remaining < 8); + int i0 = TBase64Decoder.DecodeRemaining(srcEnd, ref decodingMap, remaining, out uint t2, out uint t3); byte* destMax = destBytes + (uint)destLength; - if (t3 != EncodingPad) + if (!TBase64Decoder.IsValidPadding(t3)) { int i2 = Unsafe.Add(ref decodingMap, (IntPtr)t2); int i3 = Unsafe.Add(ref decodingMap, (IntPtr)t3); @@ -207,8 +204,9 @@ private static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Sp WriteThreeLowOrderBytes(dest, i0); dest += 3; + src += 4; } - else if (t2 != EncodingPad) + else if (!TBase64Decoder.IsValidPadding(t2)) { int i2 = Unsafe.Add(ref decodingMap, (IntPtr)t2); @@ -228,6 +226,7 @@ private static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Sp dest[0] = (byte)(i0 >> 16); dest[1] = (byte)(i0 >> 8); dest += 2; + src += remaining; } else { @@ -242,11 +241,10 @@ private static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Sp dest[0] = (byte)(i0 >> 16); dest += 1; + src += remaining; } - src += 4; - - if (srcLength != utf8.Length) + if (srcLength != source.Length) { goto InvalidDataExit; } @@ -257,7 +255,7 @@ private static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Sp return OperationStatus.Done; DestinationTooSmallExit: - if (srcLength != utf8.Length && isFinalBlock) + if (srcLength != source.Length && isFinalBlock) { goto InvalidDataExit; // if input is not a multiple of 4, and there is no more data, return invalid data instead } @@ -275,24 +273,24 @@ private static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Sp bytesConsumed = (int)(src - srcBytes); bytesWritten = (int)(dest - destBytes); return ignoreWhiteSpace ? - InvalidDataFallback(utf8, bytes, ref bytesConsumed, ref bytesWritten, isFinalBlock) : + InvalidDataFallback(source, bytes, ref bytesConsumed, ref bytesWritten, isFinalBlock) : OperationStatus.InvalidData; } - static OperationStatus InvalidDataFallback(ReadOnlySpan utf8, Span bytes, ref int bytesConsumed, ref int bytesWritten, bool isFinalBlock) + static OperationStatus InvalidDataFallback(ReadOnlySpan source, Span bytes, ref int bytesConsumed, ref int bytesWritten, bool isFinalBlock) { - utf8 = utf8.Slice(bytesConsumed); + source = source.Slice(bytesConsumed); bytes = bytes.Slice(bytesWritten); OperationStatus status; do { - int localConsumed = IndexOfAnyExceptWhiteSpace(utf8); + int localConsumed = TBase64Decoder.IndexOfAnyExceptWhiteSpace(source); if (localConsumed < 0) { // The remainder of the input is all whitespace. Mark it all as having been consumed, // and mark the operation as being done. - bytesConsumed += utf8.Length; + bytesConsumed += source.Length; status = OperationStatus.Done; break; } @@ -305,15 +303,15 @@ static OperationStatus InvalidDataFallback(ReadOnlySpan utf8, Span b // Fall back to block-wise decoding. This is very slow, but it's also very non-standard // formatting of the input; whitespace is typically only found between blocks, such as // when Convert.ToBase64String inserts a line break every 76 output characters. - return DecodeWithWhiteSpaceBlockwise(utf8, bytes, ref bytesConsumed, ref bytesWritten, isFinalBlock); + return TBase64Decoder.DecodeWithWhiteSpaceBlockwiseWrapper(source, bytes, ref bytesConsumed, ref bytesWritten, isFinalBlock); } // Skip over the starting whitespace and continue. bytesConsumed += localConsumed; - utf8 = utf8.Slice(localConsumed); + source = source.Slice(localConsumed); // Try again after consumed whitespace - status = DecodeFromUtf8(utf8, bytes, out localConsumed, out int localWritten, isFinalBlock, ignoreWhiteSpace: false); + status = DecodeFrom(source, bytes, out localConsumed, out int localWritten, isFinalBlock, ignoreWhiteSpace: false); bytesConsumed += localConsumed; bytesWritten += localWritten; if (status is not OperationStatus.InvalidData) @@ -321,10 +319,10 @@ static OperationStatus InvalidDataFallback(ReadOnlySpan utf8, Span b break; } - utf8 = utf8.Slice(localConsumed); + source = source.Slice(localConsumed); bytes = bytes.Slice(localWritten); } - while (!utf8.IsEmpty); + while (!source.IsEmpty); return status; } @@ -339,10 +337,7 @@ static OperationStatus InvalidDataFallback(ReadOnlySpan utf8, Span b [MethodImpl(MethodImplOptions.AggressiveInlining)] public static int GetMaxDecodedFromUtf8Length(int length) { - if (length < 0) - { - ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.length); - } + ArgumentOutOfRangeException.ThrowIfNegative(length); return (length >> 2) * 3; } @@ -363,9 +358,10 @@ public static int GetMaxDecodedFromUtf8Length(int length) /// hence can only be called once with all the data in the buffer. /// public static OperationStatus DecodeFromUtf8InPlace(Span buffer, out int bytesWritten) => - DecodeFromUtf8InPlace(buffer, out bytesWritten, ignoreWhiteSpace: true); + DecodeFromUtf8InPlace(buffer, out bytesWritten, ignoreWhiteSpace: true); - private static unsafe OperationStatus DecodeFromUtf8InPlace(Span buffer, out int bytesWritten, bool ignoreWhiteSpace) + internal static unsafe OperationStatus DecodeFromUtf8InPlace(Span buffer, out int bytesWritten, bool ignoreWhiteSpace) + where TBase64Decoder : IBase64Decoder { if (buffer.IsEmpty) { @@ -379,31 +375,57 @@ private static unsafe OperationStatus DecodeFromUtf8InPlace(Span buffer, o uint sourceIndex = 0; uint destIndex = 0; - // only decode input if it is a multiple of 4 - if (bufferLength % 4 != 0) + if (TBase64Decoder.IsInvalidLength(buffer.Length)) { goto InvalidExit; } - ref sbyte decodingMap = ref MemoryMarshal.GetReference(DecodingMap); + ref sbyte decodingMap = ref MemoryMarshal.GetReference(TBase64Decoder.DecodingMap); - while (sourceIndex < bufferLength - 4) + if (bufferLength > 4) { - int result = Decode(bufferBytes + sourceIndex, ref decodingMap); - if (result < 0) + while (sourceIndex < bufferLength - 4) { - goto InvalidExit; - } + int result = Base64DecoderByte.DecodeFourElements(bufferBytes + sourceIndex, ref decodingMap); + if (result < 0) + { + goto InvalidExit; + } - WriteThreeLowOrderBytes(bufferBytes + destIndex, result); - destIndex += 3; - sourceIndex += 4; + WriteThreeLowOrderBytes(bufferBytes + destIndex, result); + destIndex += 3; + sourceIndex += 4; + } } - uint t0 = bufferBytes[bufferLength - 4]; - uint t1 = bufferBytes[bufferLength - 3]; - uint t2 = bufferBytes[bufferLength - 2]; - uint t3 = bufferBytes[bufferLength - 1]; + uint t0; + uint t1; + uint t2; + uint t3; + + switch (bufferLength - sourceIndex) + { + case 2: + t0 = bufferBytes[bufferLength - 2]; + t1 = bufferBytes[bufferLength - 1]; + t2 = EncodingPad; + t3 = EncodingPad; + break; + case 3: + t0 = bufferBytes[bufferLength - 3]; + t1 = bufferBytes[bufferLength - 2]; + t2 = bufferBytes[bufferLength - 1]; + t3 = EncodingPad; + break; + case 4: + t0 = bufferBytes[bufferLength - 4]; + t1 = bufferBytes[bufferLength - 3]; + t2 = bufferBytes[bufferLength - 2]; + t3 = bufferBytes[bufferLength - 1]; + break; + default: + goto InvalidExit; + } int i0 = Unsafe.Add(ref decodingMap, t0); int i1 = Unsafe.Add(ref decodingMap, t1); @@ -413,7 +435,7 @@ private static unsafe OperationStatus DecodeFromUtf8InPlace(Span buffer, o i0 |= i1; - if (t3 != EncodingPad) + if (!TBase64Decoder.IsValidPadding(t3)) { int i2 = Unsafe.Add(ref decodingMap, t2); int i3 = Unsafe.Add(ref decodingMap, t3); @@ -431,7 +453,7 @@ private static unsafe OperationStatus DecodeFromUtf8InPlace(Span buffer, o WriteThreeLowOrderBytes(bufferBytes + destIndex, i0); destIndex += 3; } - else if (t2 != EncodingPad) + else if (!TBase64Decoder.IsValidPadding(t2)) { int i2 = Unsafe.Add(ref decodingMap, t2); @@ -465,37 +487,38 @@ private static unsafe OperationStatus DecodeFromUtf8InPlace(Span buffer, o InvalidExit: bytesWritten = (int)destIndex; return ignoreWhiteSpace ? - DecodeWithWhiteSpaceFromUtf8InPlace(buffer, ref bytesWritten, sourceIndex) : // The input may have whitespace, attempt to decode while ignoring whitespace. + DecodeWithWhiteSpaceFromUtf8InPlace(buffer, ref bytesWritten, sourceIndex) : // The input may have whitespace, attempt to decode while ignoring whitespace. OperationStatus.InvalidData; } } - private static OperationStatus DecodeWithWhiteSpaceBlockwise(ReadOnlySpan utf8, Span bytes, ref int bytesConsumed, ref int bytesWritten, bool isFinalBlock = true) + internal static OperationStatus DecodeWithWhiteSpaceBlockwise(ReadOnlySpan source, Span bytes, ref int bytesConsumed, ref int bytesWritten, bool isFinalBlock = true) + where TBase64Decoder : IBase64Decoder { const int BlockSize = 4; Span buffer = stackalloc byte[BlockSize]; OperationStatus status = OperationStatus.Done; - while (!utf8.IsEmpty) + while (!source.IsEmpty) { int encodedIdx = 0; int bufferIdx = 0; int skipped = 0; - for (; encodedIdx < utf8.Length && (uint)bufferIdx < (uint)buffer.Length; ++encodedIdx) + for (; encodedIdx < source.Length && (uint)bufferIdx < (uint)buffer.Length; ++encodedIdx) { - if (IsWhiteSpace(utf8[encodedIdx])) + if (IsWhiteSpace(source[encodedIdx])) { skipped++; } else { - buffer[bufferIdx] = utf8[encodedIdx]; + buffer[bufferIdx] = source[encodedIdx]; bufferIdx++; } } - utf8 = utf8.Slice(encodedIdx); + source = source.Slice(encodedIdx); bytesConsumed += skipped; if (bufferIdx == 0) @@ -503,13 +526,23 @@ private static OperationStatus DecodeWithWhiteSpaceBlockwise(ReadOnlySpan continue; } - bool hasAnotherBlock = utf8.Length >= BlockSize && bufferIdx == BlockSize; + bool hasAnotherBlock; + + if (typeof(TBase64Decoder) == typeof(Base64DecoderByte)) + { + hasAnotherBlock = source.Length >= BlockSize; + } + else + { + hasAnotherBlock = source.Length > 1; + } + bool localIsFinalBlock = !hasAnotherBlock; // If this block contains padding and there's another block, then only whitespace may follow for being valid. if (hasAnotherBlock) { - int paddingCount = GetPaddingCount(ref buffer[^1]); + int paddingCount = GetPaddingCount(ref buffer[^1]); if (paddingCount > 0) { hasAnotherBlock = false; @@ -522,7 +555,7 @@ private static OperationStatus DecodeWithWhiteSpaceBlockwise(ReadOnlySpan localIsFinalBlock = false; } - status = DecodeFromUtf8(buffer.Slice(0, bufferIdx), bytes, out int localConsumed, out int localWritten, localIsFinalBlock, ignoreWhiteSpace: false); + status = DecodeFrom(buffer.Slice(0, bufferIdx), bytes, out int localConsumed, out int localWritten, localIsFinalBlock, ignoreWhiteSpace: false); bytesConsumed += localConsumed; bytesWritten += localWritten; @@ -534,9 +567,9 @@ private static OperationStatus DecodeWithWhiteSpaceBlockwise(ReadOnlySpan // The remaining data must all be whitespace in order to be valid. if (!hasAnotherBlock) { - for (int i = 0; i < utf8.Length; ++i) + for (int i = 0; i < source.Length; ++i) { - if (!IsWhiteSpace(utf8[i])) + if (!IsWhiteSpace(source[i])) { // Revert previous dest increment, since an invalid state followed. bytesConsumed -= localConsumed; @@ -558,19 +591,28 @@ private static OperationStatus DecodeWithWhiteSpaceBlockwise(ReadOnlySpan } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static int GetPaddingCount(ref byte ptrToLastElement) + private static int GetPaddingCount(ref byte ptrToLastElement) + where TBase64Decoder : IBase64Decoder { int padding = 0; - if (ptrToLastElement == EncodingPad) padding++; - if (Unsafe.Subtract(ref ptrToLastElement, 1) == EncodingPad) padding++; + if (TBase64Decoder.IsValidPadding(ptrToLastElement)) + { + padding++; + } + + if (TBase64Decoder.IsValidPadding(Unsafe.Subtract(ref ptrToLastElement, 1))) + { + padding++; + } return padding; } - private static OperationStatus DecodeWithWhiteSpaceFromUtf8InPlace(Span utf8, ref int destIndex, uint sourceIndex) + private static OperationStatus DecodeWithWhiteSpaceFromUtf8InPlace(Span source, ref int destIndex, uint sourceIndex) + where TBase64Decoder : IBase64Decoder { - const int BlockSize = 4; + int BlockSize = Math.Min(source.Length - (int)sourceIndex, 4); Span buffer = stackalloc byte[BlockSize]; OperationStatus status = OperationStatus.Done; @@ -578,20 +620,15 @@ private static OperationStatus DecodeWithWhiteSpaceFromUtf8InPlace(Span ut bool hasPaddingBeenProcessed = false; int localBytesWritten = 0; - while (sourceIndex < (uint)utf8.Length) + while (sourceIndex < (uint)source.Length) { int bufferIdx = 0; - while (bufferIdx < BlockSize) + while (bufferIdx < BlockSize && sourceIndex < (uint)source.Length) { - if (sourceIndex >= (uint)utf8.Length) // TODO https://github.com/dotnet/runtime/issues/83349: move into the while condition once fixed - { - break; - } - - if (!IsWhiteSpace(utf8[(int)sourceIndex])) + if (!IsWhiteSpace(source[(int)sourceIndex])) { - buffer[bufferIdx] = utf8[(int)sourceIndex]; + buffer[bufferIdx] = source[(int)sourceIndex]; bufferIdx++; } @@ -605,8 +642,20 @@ private static OperationStatus DecodeWithWhiteSpaceFromUtf8InPlace(Span ut if (bufferIdx != 4) { - status = OperationStatus.InvalidData; - break; + // Base64 require 4 bytes, for Base64Url it can be less than 4 bytes but not 1 byte. + if (typeof(TBase64Decoder) == typeof(Base64DecoderByte) || bufferIdx == 1) + { + status = OperationStatus.InvalidData; + break; + } + else // For Base64Url fill empty slots in last block with padding + { + while (bufferIdx < BlockSize) // Can happen only for last block + { + Debug.Assert(source.Length == sourceIndex); + buffer[bufferIdx++] = (byte)EncodingPad; + } + } } if (hasPaddingBeenProcessed) @@ -618,7 +667,7 @@ private static OperationStatus DecodeWithWhiteSpaceFromUtf8InPlace(Span ut break; } - status = DecodeFromUtf8InPlace(buffer, out localBytesWritten, ignoreWhiteSpace: false); + status = DecodeFromUtf8InPlace(buffer, out localBytesWritten, ignoreWhiteSpace: false); localDestIndex += localBytesWritten; hasPaddingBeenProcessed = localBytesWritten < 3; @@ -630,7 +679,7 @@ private static OperationStatus DecodeWithWhiteSpaceFromUtf8InPlace(Span ut // Write result to source span in place. for (int i = 0; i < localBytesWritten; i++) { - utf8[localDestIndex - localBytesWritten + i] = buffer[i]; + source[localDestIndex - localBytesWritten + i] = buffer[i]; } } @@ -641,7 +690,9 @@ private static OperationStatus DecodeWithWhiteSpaceFromUtf8InPlace(Span ut [MethodImpl(MethodImplOptions.AggressiveInlining)] [CompExactlyDependsOn(typeof(Avx512BW))] [CompExactlyDependsOn(typeof(Avx512Vbmi))] - private static unsafe void Avx512Decode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) + private static unsafe void Avx512Decode(ref T* srcBytes, ref byte* destBytes, T* srcEnd, int sourceLength, int destLength, T* srcStart, byte* destStart) + where TBase64Decoder : IBase64Decoder + where T : unmanaged { // Reference for VBMI implementation : https://github.com/WojciechMula/base64simd/tree/master/decode // If we have AVX512 support, pick off 64 bytes at a time for as long as we can, @@ -649,20 +700,12 @@ private static unsafe void Avx512Decode(ref byte* srcBytes, ref byte* destBytes, // string. Also, because we write 16 zeroes at the end of the output, ensure // that there are at least 22 valid bytes of input data remaining to close the // gap. 64 + 2 + 22 = 88 bytes. - byte* src = srcBytes; + T* src = srcBytes; byte* dest = destBytes; // The JIT won't hoist these "constants", so help it - Vector512 vbmiLookup0 = Vector512.Create( - 0x80808080, 0x80808080, 0x80808080, 0x80808080, - 0x80808080, 0x80808080, 0x80808080, 0x80808080, - 0x80808080, 0x80808080, 0x3e808080, 0x3f808080, - 0x37363534, 0x3b3a3938, 0x80803d3c, 0x80808080).AsSByte(); - Vector512 vbmiLookup1 = Vector512.Create( - 0x02010080, 0x06050403, 0x0a090807, 0x0e0d0c0b, - 0x1211100f, 0x16151413, 0x80191817, 0x80808080, - 0x1c1b1a80, 0x201f1e1d, 0x24232221, 0x28272625, - 0x2c2b2a29, 0x302f2e2d, 0x80333231, 0x80808080).AsSByte(); + Vector512 vbmiLookup0 = Vector512.Create(TBase64Decoder.VbmiLookup0).AsSByte(); + Vector512 vbmiLookup1 = Vector512.Create(TBase64Decoder.VbmiLookup1).AsSByte(); Vector512 vbmiPackedLanesControl = Vector512.Create( 0x06000102, 0x090a0405, 0x0c0d0e08, 0x16101112, 0x191a1415, 0x1c1d1e18, 0x26202122, 0x292a2425, @@ -673,11 +716,13 @@ private static unsafe void Avx512Decode(ref byte* srcBytes, ref byte* destBytes, Vector512 mergeConstant1 = Vector512.Create(0x00011000).AsInt16(); // This algorithm requires AVX512VBMI support. - // Vbmi was first introduced in CannonLake and is avaialable from IceLake on. + // Vbmi was first introduced in CannonLake and is available from IceLake on. do { - AssertRead>(src, srcStart, sourceLength); - Vector512 str = Vector512.Load(src).AsSByte(); + if (!TBase64Decoder.TryLoadVector512(src, srcStart, sourceLength, out Vector512 str)) + { + break; + } // Step 1: Translate encoded Base64 input to their original indices // This step also checks for invalid inputs and exits. @@ -712,7 +757,9 @@ private static unsafe void Avx512Decode(ref byte* srcBytes, ref byte* destBytes, [MethodImpl(MethodImplOptions.AggressiveInlining)] [CompExactlyDependsOn(typeof(Avx2))] - private static unsafe void Avx2Decode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) + private static unsafe void Avx2Decode(ref T* srcBytes, ref byte* destBytes, T* srcEnd, int sourceLength, int destLength, T* srcStart, byte* destStart) + where TBase64Decoder : IBase64Decoder + where T : unmanaged { // If we have AVX2 support, pick off 32 bytes at a time for as long as we can, // but make sure that we quit before seeing any == markers at the end of the @@ -723,35 +770,11 @@ private static unsafe void Avx2Decode(ref byte* srcBytes, ref byte* destBytes, b // See SSSE3-version below for an explanation of how the code works. // The JIT won't hoist these "constants", so help it - Vector256 lutHi = Vector256.Create( - 0x10, 0x10, 0x01, 0x02, - 0x04, 0x08, 0x04, 0x08, - 0x10, 0x10, 0x10, 0x10, - 0x10, 0x10, 0x10, 0x10, - 0x10, 0x10, 0x01, 0x02, - 0x04, 0x08, 0x04, 0x08, - 0x10, 0x10, 0x10, 0x10, - 0x10, 0x10, 0x10, 0x10); - - Vector256 lutLo = Vector256.Create( - 0x15, 0x11, 0x11, 0x11, - 0x11, 0x11, 0x11, 0x11, - 0x11, 0x11, 0x13, 0x1A, - 0x1B, 0x1B, 0x1B, 0x1A, - 0x15, 0x11, 0x11, 0x11, - 0x11, 0x11, 0x11, 0x11, - 0x11, 0x11, 0x13, 0x1A, - 0x1B, 0x1B, 0x1B, 0x1A); - - Vector256 lutShift = Vector256.Create( - 0, 16, 19, 4, - -65, -65, -71, -71, - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 16, 19, 4, - -65, -65, -71, -71, - 0, 0, 0, 0, - 0, 0, 0, 0); + Vector256 lutHi = Vector256.Create(TBase64Decoder.Avx2LutHigh); + + Vector256 lutLo = Vector256.Create(TBase64Decoder.Avx2LutLow); + + Vector256 lutShift = Vector256.Create(TBase64Decoder.Avx2LutShift); Vector256 packBytesInLaneMask = Vector256.Create( 2, 1, 0, 6, @@ -773,33 +796,29 @@ private static unsafe void Avx2Decode(ref byte* srcBytes, ref byte* destBytes, b -1, -1, -1, -1, -1, -1, -1, -1).AsInt32(); - Vector256 mask2F = Vector256.Create((sbyte)'/'); + Vector256 maskSlashOrUnderscore = Vector256.Create((sbyte)TBase64Decoder.MaskSlashOrUnderscore); + Vector256 shiftForUnderscore = Vector256.Create((sbyte)33); Vector256 mergeConstant0 = Vector256.Create(0x01400140).AsSByte(); Vector256 mergeConstant1 = Vector256.Create(0x00011000).AsInt16(); - byte* src = srcBytes; + T* src = srcBytes; byte* dest = destBytes; //while (remaining >= 45) do { - AssertRead>(src, srcStart, sourceLength); - Vector256 str = Avx.LoadVector256(src).AsSByte(); + if (!TBase64Decoder.TryLoadAvxVector256(src, srcStart, sourceLength, out Vector256 str)) + { + break; + } - Vector256 hiNibbles = Avx2.And(Avx2.ShiftRightLogical(str.AsInt32(), 4).AsSByte(), mask2F); - Vector256 loNibbles = Avx2.And(str, mask2F); - Vector256 hi = Avx2.Shuffle(lutHi, hiNibbles); - Vector256 lo = Avx2.Shuffle(lutLo, loNibbles); + Vector256 hiNibbles = Avx2.And(Avx2.ShiftRightLogical(str.AsInt32(), 4).AsSByte(), maskSlashOrUnderscore); - if (!Avx.TestZ(lo, hi)) + if (!TBase64Decoder.TryDecode256Core(str, hiNibbles, maskSlashOrUnderscore, lutLo, lutHi, lutShift, shiftForUnderscore, out str)) { break; } - Vector256 eq2F = Avx2.CompareEqual(str, mask2F); - Vector256 shift = Avx2.Shuffle(lutShift, Avx2.Add(eq2F, hiNibbles)); - str = Avx2.Add(str, shift); - // in, lower lane, bits, upper case are most significant bits, lower case are least significant bits: // 00llllll 00kkkkLL 00jjKKKK 00JJJJJJ // 00iiiiii 00hhhhII 00ggHHHH 00GGGGGG @@ -843,7 +862,7 @@ private static unsafe void Avx2Decode(ref byte* srcBytes, ref byte* destBytes, b [MethodImpl(MethodImplOptions.AggressiveInlining)] [CompExactlyDependsOn(typeof(Ssse3))] [CompExactlyDependsOn(typeof(AdvSimd.Arm64))] - private static Vector128 SimdShuffle(Vector128 left, Vector128 right, Vector128 mask8F) + internal static Vector128 SimdShuffle(Vector128 left, Vector128 right, Vector128 mask8F) { Debug.Assert((Ssse3.IsSupported || AdvSimd.Arm64.IsSupported) && BitConverter.IsLittleEndian); @@ -857,7 +876,9 @@ private static Vector128 SimdShuffle(Vector128 left, Vector128 [MethodImpl(MethodImplOptions.AggressiveInlining)] [CompExactlyDependsOn(typeof(AdvSimd.Arm64))] - private static unsafe void AdvSimdDecode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) + private static unsafe void AdvSimdDecode(ref T* srcBytes, ref byte* destBytes, T* srcEnd, int sourceLength, int destLength, T* srcStart, byte* destStart) + where TBase64Decoder : IBase64Decoder + where T : unmanaged { // C# implementation of https://github.com/aklomp/base64/blob/3a5add8652076612a8407627a42c768736a4263f/lib/arch/neon64/dec_loop.c // If we have AdvSimd support, pick off 64 bytes at a time for as long as we can, @@ -890,7 +911,7 @@ private static unsafe void AdvSimdDecode(ref byte* srcBytes, ref byte* destBytes // 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 255, 255, 255, 255, 255, 255 var decLutOne = (Vector128.AllBitsSet, Vector128.AllBitsSet, - Vector128.Create(0xFFFFFFFF, 0xFFFFFFFF, 0x3EFFFFFF, 0x3FFFFFFF).AsByte(), + Vector128.Create(TBase64Decoder.AdvSimdLutOne3).AsByte(), Vector128.Create(0x37363534, 0x3B3A3938, 0xFFFF3D3C, 0xFFFFFFFF).AsByte()); // Values in 'decLutTwo' maps input values from 63 to 127. @@ -900,18 +921,21 @@ private static unsafe void AdvSimdDecode(ref byte* srcBytes, ref byte* destBytes // 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 255, 255, 255, 255 var decLutTwo = (Vector128.Create(0x0100FF00, 0x05040302, 0x09080706, 0x0D0C0B0A).AsByte(), Vector128.Create(0x11100F0E, 0x15141312, 0x19181716, 0xFFFFFFFF).AsByte(), - Vector128.Create(0x1B1AFFFF, 0x1F1E1D1C, 0x23222120, 0x27262524).AsByte(), + Vector128.Create(TBase64Decoder.AdvSimdLutTwo3Uint1, 0x1F1E1D1C, 0x23222120, 0x27262524).AsByte(), Vector128.Create(0x2B2A2928, 0x2F2E2D2C, 0x33323130, 0xFFFFFFFF).AsByte()); - byte* src = srcBytes; + T* src = srcBytes; byte* dest = destBytes; Vector128 offset = Vector128.Create(63); do { // Step 1: Load 64 bytes and de-interleave. - AssertRead>(src, srcStart, sourceLength); - var (str1, str2, str3, str4) = AdvSimd.Arm64.LoadVector128x4AndUnzip(src); + if (!TBase64Decoder.TryLoadArmVector128x4(src, srcStart, sourceLength, + out Vector128 str1, out Vector128 str2, out Vector128 str3, out Vector128 str4)) + { + break; + } // Step 2: Map each valid input to its Base64 value. // We use two look-ups to compute partial results and combine them later. @@ -993,7 +1017,9 @@ private static unsafe void AdvSimdDecode(ref byte* srcBytes, ref byte* destBytes [MethodImpl(MethodImplOptions.AggressiveInlining)] [CompExactlyDependsOn(typeof(AdvSimd.Arm64))] [CompExactlyDependsOn(typeof(Ssse3))] - private static unsafe void Vector128Decode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) + private static unsafe void Vector128Decode(ref T* srcBytes, ref byte* destBytes, T* srcEnd, int sourceLength, int destLength, T* srcStart, byte* destStart) + where TBase64Decoder : IBase64Decoder + where T : unmanaged { Debug.Assert((Ssse3.IsSupported || AdvSimd.Arm64.IsSupported) && BitConverter.IsLittleEndian); @@ -1070,44 +1096,35 @@ private static unsafe void Vector128Decode(ref byte* srcBytes, ref byte* destByt // 1111 0x10 andlut 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10 // The JIT won't hoist these "constants", so help it - Vector128 lutHi = Vector128.Create(0x02011010, 0x08040804, 0x10101010, 0x10101010).AsByte(); - Vector128 lutLo = Vector128.Create(0x11111115, 0x11111111, 0x1A131111, 0x1A1B1B1B).AsByte(); - Vector128 lutShift = Vector128.Create(0x04131000, 0xb9b9bfbf, 0x00000000, 0x00000000).AsSByte(); + Vector128 lutHi = Vector128.Create(TBase64Decoder.Vector128LutHigh).AsByte(); + Vector128 lutLo = Vector128.Create(TBase64Decoder.Vector128LutLow).AsByte(); + Vector128 lutShift = Vector128.Create(TBase64Decoder.Vector128LutShift).AsSByte(); Vector128 packBytesMask = Vector128.Create(0x06000102, 0x090A0405, 0x0C0D0E08, 0xffffffff).AsSByte(); - Vector128 mergeConstant0 = Vector128.Create(0x01400140).AsByte(); + Vector128 mergeConstant0 = Vector128.Create(0x01400140).AsByte(); Vector128 mergeConstant1 = Vector128.Create(0x00011000).AsInt16(); - Vector128 one = Vector128.Create((byte)1); - Vector128 mask2F = Vector128.Create((byte)'/'); - Vector128 mask8F = Vector128.Create((byte)0x8F); - - byte* src = srcBytes; + Vector128 one = Vector128.Create((byte)1); + Vector128 mask2F = Vector128.Create(TBase64Decoder.MaskSlashOrUnderscore); + Vector128 mask8F = Vector128.Create((byte)0x8F); + Vector128 shiftForUnderscore = Vector128.Create((byte)33); + T* src = srcBytes; byte* dest = destBytes; //while (remaining >= 24) do { - AssertRead>(src, srcStart, sourceLength); - Vector128 str = Vector128.LoadUnsafe(ref *src); + if (!TBase64Decoder.TryLoadVector128(src, srcStart, sourceLength, out Vector128 str)) + { + break; + } // lookup Vector128 hiNibbles = Vector128.ShiftRightLogical(str.AsInt32(), 4).AsByte() & mask2F; - Vector128 loNibbles = str & mask2F; - Vector128 hi = SimdShuffle(lutHi, hiNibbles, mask8F); - Vector128 lo = SimdShuffle(lutLo, loNibbles, mask8F); - // Check for invalid input: if any "and" values from lo and hi are not zero, - // fall back on bytewise code to do error checking and reporting: - if ((lo & hi) != Vector128.Zero) + if (!TBase64Decoder.TryDecode128Core(str, hiNibbles, mask2F, mask8F, lutLo, lutHi, lutShift, shiftForUnderscore, out str)) { break; } - Vector128 eq2F = Vector128.Equals(str, mask2F); - Vector128 shift = SimdShuffle(lutShift.AsByte(), (eq2F + hiNibbles), mask8F); - - // Now simply add the delta values to the input: - str += shift; - // in, bits, upper case are most significant bits, lower case are least significant bits // 00llllll 00kkkkLL 00jjKKKK 00JJJJJJ // 00iiiiii 00hhhhII 00ggHHHH 00GGGGGG @@ -1165,30 +1182,6 @@ private static unsafe void Vector128Decode(ref byte* srcBytes, ref byte* destByt destBytes = dest; } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe int Decode(byte* encodedBytes, ref sbyte decodingMap) - { - uint t0 = encodedBytes[0]; - uint t1 = encodedBytes[1]; - uint t2 = encodedBytes[2]; - uint t3 = encodedBytes[3]; - - int i0 = Unsafe.Add(ref decodingMap, t0); - int i1 = Unsafe.Add(ref decodingMap, t1); - int i2 = Unsafe.Add(ref decodingMap, t2); - int i3 = Unsafe.Add(ref decodingMap, t3); - - i0 <<= 18; - i1 <<= 12; - i2 <<= 6; - - i0 |= i3; - i1 |= i2; - - i0 |= i1; - return i0; - } - [MethodImpl(MethodImplOptions.AggressiveInlining)] private static unsafe void WriteThreeLowOrderBytes(byte* destination, int value) { @@ -1197,20 +1190,6 @@ private static unsafe void WriteThreeLowOrderBytes(byte* destination, int value) destination[2] = (byte)value; } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static int IndexOfAnyExceptWhiteSpace(ReadOnlySpan span) - { - for (int i = 0; i < span.Length; i++) - { - if (!IsWhiteSpace(span[i])) - { - return i; - } - } - - return -1; - } - [MethodImpl(MethodImplOptions.AggressiveInlining)] internal static bool IsWhiteSpace(int value) { @@ -1234,25 +1213,284 @@ internal static bool IsWhiteSpace(int value) return value == 32; } - // Pre-computing this table using a custom string(s_characters) and GenerateDecodingMapAndVerify (found in tests) - private static ReadOnlySpan DecodingMap => - [ - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 62, -1, -1, -1, 63, //62 is placed at index 43 (for +), 63 at index 47 (for /) - 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, -1, -1, -1, -1, -1, -1, //52-61 are placed at index 48-57 (for 0-9), 64 at index 61 (for =) - -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, - 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, -1, -1, -1, -1, -1, //0-25 are placed at index 65-90 (for A-Z) - -1, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, - 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, -1, -1, -1, -1, -1, //26-51 are placed at index 97-122 (for a-z) - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Bytes over 122 ('z') are invalid and cannot be decoded - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Hence, padding the map with 255, which indicates invalid input - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - ]; + internal readonly struct Base64DecoderByte : IBase64Decoder + { + // Pre-computing this table using a custom string(s_characters) and GenerateDecodingMapAndVerify (found in tests) + public static ReadOnlySpan DecodingMap => + [ + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 62, -1, -1, -1, 63, //62 is placed at index 43 (for +), 63 at index 47 (for /) + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, -1, -1, -1, -1, -1, -1, //52-61 are placed at index 48-57 (for 0-9) + -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, -1, -1, -1, -1, -1, //0-25 are placed at index 65-90 (for A-Z) + -1, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, + 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, -1, -1, -1, -1, -1, //26-51 are placed at index 97-122 (for a-z) + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Bytes over 122 ('z') are invalid and cannot be decoded + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Hence, padding the map with 255, which indicates invalid input + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + ]; + + public static ReadOnlySpan VbmiLookup0 => + [ + 0x80808080, 0x80808080, 0x80808080, 0x80808080, + 0x80808080, 0x80808080, 0x80808080, 0x80808080, + 0x80808080, 0x80808080, 0x3e808080, 0x3f808080, + 0x37363534, 0x3b3a3938, 0x80803d3c, 0x80808080 + ]; + + public static ReadOnlySpan VbmiLookup1 => + [ + 0x02010080, 0x06050403, 0x0a090807, 0x0e0d0c0b, + 0x1211100f, 0x16151413, 0x80191817, 0x80808080, + 0x1c1b1a80, 0x201f1e1d, 0x24232221, 0x28272625, + 0x2c2b2a29, 0x302f2e2d, 0x80333231, 0x80808080 + ]; + + public static ReadOnlySpan Avx2LutHigh => + [ + 0x10, 0x10, 0x01, 0x02, + 0x04, 0x08, 0x04, 0x08, + 0x10, 0x10, 0x10, 0x10, + 0x10, 0x10, 0x10, 0x10, + 0x10, 0x10, 0x01, 0x02, + 0x04, 0x08, 0x04, 0x08, + 0x10, 0x10, 0x10, 0x10, + 0x10, 0x10, 0x10, 0x10 + ]; + + public static ReadOnlySpan Avx2LutLow => + [ + 0x15, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x13, 0x1A, + 0x1B, 0x1B, 0x1B, 0x1A, + 0x15, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x13, 0x1A, + 0x1B, 0x1B, 0x1B, 0x1A + ]; + + public static ReadOnlySpan Avx2LutShift => + [ + 0, 16, 19, 4, + -65, -65, -71, -71, + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 16, 19, 4, + -65, -65, -71, -71, + 0, 0, 0, 0, + 0, 0, 0, 0 + ]; + + public static byte MaskSlashOrUnderscore => (byte)'/'; + + public static ReadOnlySpan Vector128LutHigh => [0x02011010, 0x08040804, 0x10101010, 0x10101010]; + + public static ReadOnlySpan Vector128LutLow => [0x11111115, 0x11111111, 0x1A131111, 0x1A1B1B1B]; + + public static ReadOnlySpan Vector128LutShift => [0x04131000, 0xb9b9bfbf, 0x00000000, 0x00000000]; + + public static ReadOnlySpan AdvSimdLutOne3 => [0xFFFFFFFF, 0xFFFFFFFF, 0x3EFFFFFF, 0x3FFFFFFF]; + + public static uint AdvSimdLutTwo3Uint1 => 0x1B1AFFFF; + + public static int GetMaxDecodedLength(int utf8Length) => GetMaxDecodedFromUtf8Length(utf8Length); + + public static bool IsInvalidLength(int bufferLength) => bufferLength % 4 != 0; // only decode input if it is a multiple of 4 + + public static bool IsValidPadding(uint padChar) => padChar == EncodingPad; + + public static int SrcLength(bool _, int utf8Length) => utf8Length & ~0x3; // only decode input up to the closest multiple of 4. + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CompExactlyDependsOn(typeof(AdvSimd.Arm64))] + [CompExactlyDependsOn(typeof(Ssse3))] + public static bool TryDecode128Core( + Vector128 str, + Vector128 hiNibbles, + Vector128 maskSlashOrUnderscore, + Vector128 mask8F, + Vector128 lutLow, + Vector128 lutHigh, + Vector128 lutShift, + Vector128 _, + out Vector128 result) + { + Vector128 loNibbles = str & maskSlashOrUnderscore; + Vector128 hi = SimdShuffle(lutHigh, hiNibbles, mask8F); + Vector128 lo = SimdShuffle(lutLow, loNibbles, mask8F); + + // Check for invalid input: if any "and" values from lo and hi are not zero, + // fall back on bytewise code to do error checking and reporting: + if ((lo & hi) != Vector128.Zero) + { + result = default; + return false; + } + + Vector128 eq2F = Vector128.Equals(str, maskSlashOrUnderscore); + Vector128 shift = SimdShuffle(lutShift.AsByte(), (eq2F + hiNibbles), mask8F); + + // Now simply add the delta values to the input: + result = str + shift; + + return true; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CompExactlyDependsOn(typeof(Avx2))] + public static bool TryDecode256Core( + Vector256 str, + Vector256 hiNibbles, + Vector256 maskSlashOrUnderscore, + Vector256 lutLow, + Vector256 lutHigh, + Vector256 lutShift, + Vector256 _, + out Vector256 result) + { + Vector256 loNibbles = Avx2.And(str, maskSlashOrUnderscore); + Vector256 hi = Avx2.Shuffle(lutHigh, hiNibbles); + Vector256 lo = Avx2.Shuffle(lutLow, loNibbles); + + if (!Avx.TestZ(lo, hi)) + { + result = default; + return false; + } + + Vector256 eq2F = Avx2.CompareEqual(str, maskSlashOrUnderscore); + Vector256 shift = Avx2.Shuffle(lutShift, Avx2.Add(eq2F, hiNibbles)); + + result = Avx2.Add(str, shift); + + return true; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe int DecodeFourElements(byte* source, ref sbyte decodingMap) + { + // The 'source' span expected to have at least 4 elements, and the 'decodingMap' consists 256 sbytes + uint t0 = source[0]; + uint t1 = source[1]; + uint t2 = source[2]; + uint t3 = source[3]; + + int i0 = Unsafe.Add(ref decodingMap, t0); + int i1 = Unsafe.Add(ref decodingMap, t1); + int i2 = Unsafe.Add(ref decodingMap, t2); + int i3 = Unsafe.Add(ref decodingMap, t3); + + i0 <<= 18; + i1 <<= 12; + i2 <<= 6; + + i0 |= i3; + i1 |= i2; + + i0 |= i1; + return i0; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe int DecodeRemaining(byte* srcEnd, ref sbyte decodingMap, long remaining, out uint t2, out uint t3) + { + uint t0; + uint t1; + t2 = EncodingPad; + t3 = EncodingPad; + switch (remaining) + { + case 2: + t0 = srcEnd[-2]; + t1 = srcEnd[-1]; + break; + case 3: + t0 = srcEnd[-3]; + t1 = srcEnd[-2]; + t2 = srcEnd[-1]; + break; + case 4: + t0 = srcEnd[-4]; + t1 = srcEnd[-3]; + t2 = srcEnd[-2]; + t3 = srcEnd[-1]; + break; + default: + return -1; + } + + int i0 = Unsafe.Add(ref decodingMap, (IntPtr)t0); + int i1 = Unsafe.Add(ref decodingMap, (IntPtr)t1); + + i0 <<= 18; + i1 <<= 12; + + i0 |= i1; + return i0; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int IndexOfAnyExceptWhiteSpace(ReadOnlySpan span) + { + for (int i = 0; i < span.Length; i++) + { + if (!IsWhiteSpace(span[i])) + { + return i; + } + } + + return -1; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static OperationStatus DecodeWithWhiteSpaceBlockwiseWrapper(ReadOnlySpan utf8, + Span bytes, ref int bytesConsumed, ref int bytesWritten, bool isFinalBlock = true) + where TBase64Decoder : IBase64Decoder => + DecodeWithWhiteSpaceBlockwise(utf8, bytes, ref bytesConsumed, ref bytesWritten, isFinalBlock); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe bool TryLoadVector512(byte* src, byte* srcStart, int sourceLength, out Vector512 str) + { + AssertRead>(src, srcStart, sourceLength); + str = Vector512.Load(src).AsSByte(); + return true; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CompExactlyDependsOn(typeof(Avx2))] + public static unsafe bool TryLoadAvxVector256(byte* src, byte* srcStart, int sourceLength, out Vector256 str) + { + AssertRead>(src, srcStart, sourceLength); + str = Avx.LoadVector256(src).AsSByte(); + return true; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe bool TryLoadVector128(byte* src, byte* srcStart, int sourceLength, out Vector128 str) + { + AssertRead>(src, srcStart, sourceLength); + str = Vector128.LoadUnsafe(ref *src); + return true; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CompExactlyDependsOn(typeof(AdvSimd.Arm64))] + public static unsafe bool TryLoadArmVector128x4(byte* src, byte* srcStart, int sourceLength, + out Vector128 str1, out Vector128 str2, out Vector128 str3, out Vector128 str4) + { + AssertRead>(src, srcStart, sourceLength); + (str1, str2, str3, str4) = AdvSimd.Arm64.LoadVector128x4AndUnzip(src); + + return true; + } + } } } diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Encoder.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Encoder.cs index b63c711e41032..9df864b5bf601 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Encoder.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Encoder.cs @@ -1,7 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Runtime; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Runtime.Intrinsics; @@ -29,39 +28,36 @@ public static partial class Base64 /// Set to when the source buffer contains the entirety of the data to encode. /// Set to if this method is being called in a loop and if more input data may follow. /// At the end of the loop, call this (potentially with an empty source buffer) passing . - /// It returns the OperationStatus enum values: + /// It returns the enum values: /// - Done - on successful processing of the entire input span /// - DestinationTooSmall - if there is not enough space in the output span to fit the encoded input /// - NeedMoreData - only if is , otherwise the output is padded if the input is not a multiple of 3 /// It does not return InvalidData since that is not possible for base64 encoding. /// - public static unsafe OperationStatus EncodeToUtf8(ReadOnlySpan bytes, Span utf8, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) + public static unsafe OperationStatus EncodeToUtf8(ReadOnlySpan bytes, Span utf8, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) => + EncodeTo(bytes, utf8, out bytesConsumed, out bytesWritten, isFinalBlock); + + internal static unsafe OperationStatus EncodeTo(ReadOnlySpan source, + Span destination, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) + where TBase64Encoder : IBase64Encoder + where T : unmanaged { - if (bytes.IsEmpty) + if (source.IsEmpty) { bytesConsumed = 0; bytesWritten = 0; return OperationStatus.Done; } - fixed (byte* srcBytes = &MemoryMarshal.GetReference(bytes)) - fixed (byte* destBytes = &MemoryMarshal.GetReference(utf8)) + fixed (byte* srcBytes = &MemoryMarshal.GetReference(source)) + fixed (T* destBytes = &MemoryMarshal.GetReference(destination)) { - int srcLength = bytes.Length; - int destLength = utf8.Length; - int maxSrcLength; - - if (srcLength <= MaximumEncodeLength && destLength >= GetMaxEncodedToUtf8Length(srcLength)) - { - maxSrcLength = srcLength; - } - else - { - maxSrcLength = (destLength >> 2) * 3; - } + int srcLength = source.Length; + int destLength = destination.Length; + int maxSrcLength = TBase64Encoder.GetMaxSrcLength(srcLength, destLength); byte* src = srcBytes; - byte* dest = destBytes; + T* dest = destBytes; byte* srcEnd = srcBytes + (uint)srcLength; byte* srcMax = srcBytes + (uint)maxSrcLength; @@ -70,16 +66,16 @@ public static unsafe OperationStatus EncodeToUtf8(ReadOnlySpan bytes, Span byte* end = srcMax - 64; if (Vector512.IsHardwareAccelerated && Avx512Vbmi.IsSupported && (end >= src)) { - Avx512Encode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); + Avx512Encode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); if (src == srcEnd) goto DoneExit; } - end = srcMax - 64; + end = srcMax - 32; if (Avx2.IsSupported && (end >= src)) { - Avx2Encode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); + Avx2Encode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); if (src == srcEnd) goto DoneExit; @@ -88,7 +84,7 @@ public static unsafe OperationStatus EncodeToUtf8(ReadOnlySpan bytes, Span end = srcMax - 48; if (AdvSimd.Arm64.IsSupported && (end >= src)) { - AdvSimdEncode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); + AdvSimdEncode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); if (src == srcEnd) goto DoneExit; @@ -97,21 +93,19 @@ public static unsafe OperationStatus EncodeToUtf8(ReadOnlySpan bytes, Span end = srcMax - 16; if ((Ssse3.IsSupported || AdvSimd.Arm64.IsSupported) && BitConverter.IsLittleEndian && (end >= src)) { - Vector128Encode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); + Vector128Encode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); if (src == srcEnd) goto DoneExit; } } - ref byte encodingMap = ref MemoryMarshal.GetReference(EncodingMap); - uint result = 0; + ref byte encodingMap = ref MemoryMarshal.GetReference(TBase64Encoder.EncodingMap); srcMax -= 2; while (src < srcMax) { - result = Encode(src, ref encodingMap); - Unsafe.WriteUnaligned(dest, result); + TBase64Encoder.EncodeThreeAndWrite(src, dest, ref encodingMap); src += 3; dest += 4; } @@ -129,17 +123,15 @@ public static unsafe OperationStatus EncodeToUtf8(ReadOnlySpan bytes, Span if (src + 1 == srcEnd) { - result = EncodeAndPadTwo(src, ref encodingMap); - Unsafe.WriteUnaligned(dest, result); + TBase64Encoder.EncodeOneOptionallyPadTwo(src, dest, ref encodingMap); src += 1; - dest += 4; + dest += TBase64Encoder.IncrementPadTwo; } else if (src + 2 == srcEnd) { - result = EncodeAndPadOne(src, ref encodingMap); - Unsafe.WriteUnaligned(dest, result); + TBase64Encoder.EncodeTwoOptionallyPadOne(src, dest, ref encodingMap); src += 2; - dest += 4; + dest += TBase64Encoder.IncrementPadOne; } DoneExit: @@ -168,8 +160,7 @@ public static unsafe OperationStatus EncodeToUtf8(ReadOnlySpan bytes, Span [MethodImpl(MethodImplOptions.AggressiveInlining)] public static int GetMaxEncodedToUtf8Length(int length) { - if ((uint)length > MaximumEncodeLength) - ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.length); + ArgumentOutOfRangeException.ThrowIfGreaterThan((uint)length, MaximumEncodeLength); return ((length + 2) / 3) * 4; } @@ -189,7 +180,11 @@ public static int GetMaxEncodedToUtf8Length(int length) /// It does not return NeedMoreData since this method tramples the data in the buffer and hence can only be called once with all the data in the buffer. /// It does not return InvalidData since that is not possible for base 64 encoding. /// - public static unsafe OperationStatus EncodeToUtf8InPlace(Span buffer, int dataLength, out int bytesWritten) + public static unsafe OperationStatus EncodeToUtf8InPlace(Span buffer, int dataLength, out int bytesWritten) => + EncodeToUtf8InPlace(buffer, dataLength, out bytesWritten); + + internal static unsafe OperationStatus EncodeToUtf8InPlace(Span buffer, int dataLength, out int bytesWritten) + where TBase64Encoder : IBase64Encoder { if (buffer.IsEmpty) { @@ -199,37 +194,38 @@ public static unsafe OperationStatus EncodeToUtf8InPlace(Span buffer, int fixed (byte* bufferBytes = &MemoryMarshal.GetReference(buffer)) { - int encodedLength = GetMaxEncodedToUtf8Length(dataLength); + int encodedLength = TBase64Encoder.GetMaxEncodedLength(dataLength); if (buffer.Length < encodedLength) - goto FalseExit; + { + bytesWritten = 0; + return OperationStatus.DestinationTooSmall; + } - int leftover = dataLength - (dataLength / 3) * 3; // how many bytes after packs of 3 + int leftover = (int)((uint)dataLength % 3); // how many bytes after packs of 3 - uint destinationIndex = (uint)(encodedLength - 4); + uint destinationIndex = TBase64Encoder.GetInPlaceDestinationLength(encodedLength, leftover); uint sourceIndex = (uint)(dataLength - leftover); - uint result = 0; - ref byte encodingMap = ref MemoryMarshal.GetReference(EncodingMap); + ref byte encodingMap = ref MemoryMarshal.GetReference(TBase64Encoder.EncodingMap); // encode last pack to avoid conditional in the main loop if (leftover != 0) { if (leftover == 1) { - result = EncodeAndPadTwo(bufferBytes + sourceIndex, ref encodingMap); + TBase64Encoder.EncodeOneOptionallyPadTwo(bufferBytes + sourceIndex, bufferBytes + destinationIndex, ref encodingMap); } else { - result = EncodeAndPadOne(bufferBytes + sourceIndex, ref encodingMap); + TBase64Encoder.EncodeTwoOptionallyPadOne(bufferBytes + sourceIndex, bufferBytes + destinationIndex, ref encodingMap); } - Unsafe.WriteUnaligned(bufferBytes + destinationIndex, result); destinationIndex -= 4; } sourceIndex -= 3; while ((int)sourceIndex >= 0) { - result = Encode(bufferBytes + sourceIndex, ref encodingMap); + uint result = Encode(bufferBytes + sourceIndex, ref encodingMap); Unsafe.WriteUnaligned(bufferBytes + destinationIndex, result); destinationIndex -= 4; sourceIndex -= 3; @@ -237,17 +233,15 @@ public static unsafe OperationStatus EncodeToUtf8InPlace(Span buffer, int bytesWritten = encodedLength; return OperationStatus.Done; - - FalseExit: - bytesWritten = 0; - return OperationStatus.DestinationTooSmall; } } [MethodImpl(MethodImplOptions.AggressiveInlining)] [CompExactlyDependsOn(typeof(Avx512BW))] [CompExactlyDependsOn(typeof(Avx512Vbmi))] - private static unsafe void Avx512Encode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) + private static unsafe void Avx512Encode(ref byte* srcBytes, ref T* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, T* destStart) + where TBase64Encoder : IBase64Encoder + where T : unmanaged { // Reference for VBMI implementation : https://github.com/WojciechMula/base64simd/tree/master/encode // If we have AVX512 support, pick off 48 bytes at a time for as long as we can. @@ -255,7 +249,7 @@ private static unsafe void Avx512Encode(ref byte* srcBytes, ref byte* destBytes, // full 64-byte read without segfaulting. byte* src = srcBytes; - byte* dest = destBytes; + T* dest = destBytes; // The JIT won't hoist these "constants", so help it Vector512 shuffleVecVbmi = Vector512.Create( @@ -263,7 +257,7 @@ private static unsafe void Avx512Encode(ref byte* srcBytes, ref byte* destBytes, 0x0d0e0c0d, 0x10110f10, 0x13141213, 0x16171516, 0x191a1819, 0x1c1d1b1c, 0x1f201e1f, 0x22232122, 0x25262425, 0x28292728, 0x2b2c2a2b, 0x2e2f2d2e).AsSByte(); - Vector512 vbmiLookup = Vector512.Create("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"u8).AsSByte(); + Vector512 vbmiLookup = Vector512.Create(TBase64Encoder.EncodingMap).AsSByte(); Vector512 maskAC = Vector512.Create((uint)0x0fc0fc00).AsUInt16(); Vector512 maskBB = Vector512.Create((uint)0x3f003f00); @@ -273,7 +267,7 @@ private static unsafe void Avx512Encode(ref byte* srcBytes, ref byte* destBytes, AssertRead>(src, srcStart, sourceLength); // This algorithm requires AVX512VBMI support. - // Vbmi was first introduced in CannonLake and is avaialable from IceLake on. + // Vbmi was first introduced in CannonLake and is available from IceLake on. // str = [...|PONM|LKJI|HGFE|DCBA] Vector512 str = Vector512.Load(src).AsSByte(); @@ -301,8 +295,7 @@ private static unsafe void Avx512Encode(ref byte* srcBytes, ref byte* destBytes, // Step 2: Now we have the indices calculated. Next step is to use these indices to translate. str = Avx512Vbmi.PermuteVar64x8(vbmiLookup, str); - AssertWrite>(dest, destStart, destLength); - str.Store((sbyte*)dest); + TBase64Encoder.StoreVector512ToDestination(dest, destStart, destLength, str.AsByte()); src += 48; dest += 64; @@ -320,7 +313,9 @@ private static unsafe void Avx512Encode(ref byte* srcBytes, ref byte* destBytes, [MethodImpl(MethodImplOptions.AggressiveInlining)] [CompExactlyDependsOn(typeof(Avx2))] - private static unsafe void Avx2Encode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) + private static unsafe void Avx2Encode(ref byte* srcBytes, ref T* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, T* destStart) + where TBase64Encoder : IBase64Encoder + where T : unmanaged { // If we have AVX2 support, pick off 24 bytes at a time for as long as we can. // But because we read 32 bytes at a time, ensure we have enough room to do a @@ -349,11 +344,11 @@ private static unsafe void Avx2Encode(ref byte* srcBytes, ref byte* destBytes, b 65, 71, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, - -19, -16, 0, 0, + TBase64Encoder.Avx2LutChar62, TBase64Encoder.Avx2LutChar63, 0, 0, 65, 71, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, - -19, -16, 0, 0); + TBase64Encoder.Avx2LutChar62, TBase64Encoder.Avx2LutChar63, 0, 0); Vector256 maskAC = Vector256.Create(0x0fc0fc00).AsSByte(); Vector256 maskBB = Vector256.Create(0x003f03f0).AsSByte(); @@ -363,7 +358,7 @@ private static unsafe void Avx2Encode(ref byte* srcBytes, ref byte* destBytes, b Vector256 const25 = Vector256.Create((sbyte)25); byte* src = srcBytes; - byte* dest = destBytes; + T* dest = destBytes; // first load is done at c-0 not to get a segfault AssertRead>(src, srcStart, sourceLength); @@ -471,8 +466,7 @@ private static unsafe void Avx2Encode(ref byte* srcBytes, ref byte* destBytes, b // Add offsets to input values: str = Avx2.Add(str, Avx2.Shuffle(lut, tmp)); - AssertWrite>(dest, destStart, destLength); - Avx.Store(dest, str.AsByte()); + TBase64Encoder.StoreVector256ToDestination(dest, destStart, destLength, str.AsByte()); src += 24; dest += 32; @@ -491,7 +485,9 @@ private static unsafe void Avx2Encode(ref byte* srcBytes, ref byte* destBytes, b [MethodImpl(MethodImplOptions.AggressiveInlining)] [CompExactlyDependsOn(typeof(AdvSimd.Arm64))] - private static unsafe void AdvSimdEncode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) + private static unsafe void AdvSimdEncode(ref byte* srcBytes, ref T* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, T* destStart) + where TBase64Encoder : IBase64Encoder + where T : unmanaged { // C# implementation of https://github.com/aklomp/base64/blob/3a5add8652076612a8407627a42c768736a4263f/lib/arch/neon64/enc_loop.c Vector128 str1; @@ -504,9 +500,9 @@ private static unsafe void AdvSimdEncode(ref byte* srcBytes, ref byte* destBytes Vector128 tblEnc1 = Vector128.Create("ABCDEFGHIJKLMNOP"u8).AsByte(); Vector128 tblEnc2 = Vector128.Create("QRSTUVWXYZabcdef"u8).AsByte(); Vector128 tblEnc3 = Vector128.Create("ghijklmnopqrstuv"u8).AsByte(); - Vector128 tblEnc4 = Vector128.Create("wxyz0123456789+/"u8).AsByte(); + Vector128 tblEnc4 = Vector128.Create(TBase64Encoder.AdvSimdLut4).AsByte(); byte* src = srcBytes; - byte* dest = destBytes; + T* dest = destBytes; // If we have Neon support, pick off 48 bytes at a time for as long as we can. do @@ -536,8 +532,7 @@ private static unsafe void AdvSimdEncode(ref byte* srcBytes, ref byte* destBytes res4 = AdvSimd.Arm64.VectorTableLookup((tblEnc1, tblEnc2, tblEnc3, tblEnc4), res4); // Interleave and store result: - AssertWrite>(dest, destStart, destLength); - AdvSimd.Arm64.StoreVector128x4AndZip(dest, (res1, res2, res3, res4)); + TBase64Encoder.StoreArmVector128x4ToDestination(dest, destStart, destLength, res1, res2, res3, res4); src += 48; dest += 64; @@ -550,7 +545,9 @@ private static unsafe void AdvSimdEncode(ref byte* srcBytes, ref byte* destBytes [MethodImpl(MethodImplOptions.AggressiveInlining)] [CompExactlyDependsOn(typeof(Ssse3))] [CompExactlyDependsOn(typeof(AdvSimd.Arm64))] - private static unsafe void Vector128Encode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) + private static unsafe void Vector128Encode(ref byte* srcBytes, ref T* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, T* destStart) + where TBase64Encoder : IBase64Encoder + where T : unmanaged { // If we have SSSE3 support, pick off 12 bytes at a time for as long as we can. // But because we read 16 bytes at a time, ensure we have enough room to do a @@ -561,7 +558,7 @@ private static unsafe void Vector128Encode(ref byte* srcBytes, ref byte* destByt // The JIT won't hoist these "constants", so help it Vector128 shuffleVec = Vector128.Create(0x01020001, 0x04050304, 0x07080607, 0x0A0B090A).AsByte(); - Vector128 lut = Vector128.Create(0xFCFC4741, 0xFCFCFCFC, 0xFCFCFCFC, 0x0000F0ED).AsByte(); + Vector128 lut = Vector128.Create(0xFCFC4741, 0xFCFCFCFC, 0xFCFCFCFC, TBase64Encoder.Ssse3AdvSimdLutE3).AsByte(); Vector128 maskAC = Vector128.Create(0x0fc0fc00).AsByte(); Vector128 maskBB = Vector128.Create(0x003f03f0).AsByte(); Vector128 shiftAC = Vector128.Create(0x04000040).AsUInt16(); @@ -571,7 +568,7 @@ private static unsafe void Vector128Encode(ref byte* srcBytes, ref byte* destByt Vector128 mask8F = Vector128.Create((byte)0x8F); byte* src = srcBytes; - byte* dest = destBytes; + T* dest = destBytes; //while (remaining >= 16) do @@ -659,8 +656,7 @@ private static unsafe void Vector128Encode(ref byte* srcBytes, ref byte* destByt // Add offsets to input values: str += SimdShuffle(lut, tmp.AsByte(), mask8F); - AssertWrite>(dest, destStart, destLength); - str.Store(dest); + TBase64Encoder.StoreVector128ToDestination(dest, destStart, destLength, str); src += 12; dest += 16; @@ -695,52 +691,149 @@ private static unsafe uint Encode(byte* threeBytes, ref byte encodingMap) } } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe uint EncodeAndPadOne(byte* twoBytes, ref byte encodingMap) + internal const uint EncodingPad = '='; // '=', for padding + + internal const int MaximumEncodeLength = (int.MaxValue / 4) * 3; // 1610612733 + + internal readonly struct Base64EncoderByte : IBase64Encoder { - uint t0 = twoBytes[0]; - uint t1 = twoBytes[1]; + public static ReadOnlySpan EncodingMap => "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"u8; - uint i = (t0 << 16) | (t1 << 8); + public static sbyte Avx2LutChar62 => -19; // char '+' diff - uint i0 = Unsafe.Add(ref encodingMap, (IntPtr)(i >> 18)); - uint i1 = Unsafe.Add(ref encodingMap, (IntPtr)((i >> 12) & 0x3F)); - uint i2 = Unsafe.Add(ref encodingMap, (IntPtr)((i >> 6) & 0x3F)); + public static sbyte Avx2LutChar63 => -16; // char '/' diff - if (BitConverter.IsLittleEndian) + public static ReadOnlySpan AdvSimdLut4 => "wxyz0123456789+/"u8; + + public static uint Ssse3AdvSimdLutE3 => 0x0000F0ED; + + public static int IncrementPadTwo => 4; + + public static int IncrementPadOne => 4; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int GetMaxSrcLength(int srcLength, int destLength) => + srcLength <= MaximumEncodeLength && destLength >= GetMaxEncodedToUtf8Length(srcLength) ? + srcLength : (destLength >> 2) * 3; + + public static uint GetInPlaceDestinationLength(int encodedLength, int _) => (uint)(encodedLength - 4); + + public static int GetMaxEncodedLength(int srcLength) => GetMaxEncodedToUtf8Length(srcLength); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe void EncodeOneOptionallyPadTwo(byte* oneByte, byte* dest, ref byte encodingMap) { - return i0 | (i1 << 8) | (i2 << 16) | (EncodingPad << 24); + uint t0 = oneByte[0]; + + uint i = t0 << 8; + + uint i0 = Unsafe.Add(ref encodingMap, (IntPtr)(i >> 10)); + uint i1 = Unsafe.Add(ref encodingMap, (IntPtr)((i >> 4) & 0x3F)); + + if (BitConverter.IsLittleEndian) + { + dest[0] = (byte)i0; + dest[1] = (byte)i1; + dest[2] = (byte)EncodingPad; + dest[3] = (byte)EncodingPad; + } + else + { + dest[3] = (byte)i0; + dest[2] = (byte)i1; + dest[1] = (byte)EncodingPad; + dest[0] = (byte)EncodingPad; + } } - else + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe void EncodeTwoOptionallyPadOne(byte* twoBytes, byte* dest, ref byte encodingMap) { - return (i0 << 24) | (i1 << 16) | (i2 << 8) | EncodingPad; - } - } + uint t0 = twoBytes[0]; + uint t1 = twoBytes[1]; - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe uint EncodeAndPadTwo(byte* oneByte, ref byte encodingMap) - { - uint t0 = oneByte[0]; + uint i = (t0 << 16) | (t1 << 8); - uint i = t0 << 8; + uint i0 = Unsafe.Add(ref encodingMap, (IntPtr)(i >> 18)); + uint i1 = Unsafe.Add(ref encodingMap, (IntPtr)((i >> 12) & 0x3F)); + uint i2 = Unsafe.Add(ref encodingMap, (IntPtr)((i >> 6) & 0x3F)); - uint i0 = Unsafe.Add(ref encodingMap, (IntPtr)(i >> 10)); - uint i1 = Unsafe.Add(ref encodingMap, (IntPtr)((i >> 4) & 0x3F)); + if (BitConverter.IsLittleEndian) + { + dest[0] = (byte)i0; + dest[1] = (byte)i1; + dest[2] = (byte)i2; + dest[3] = (byte)EncodingPad; + } + else + { + dest[3] = (byte)i0; + dest[2] = (byte)i1; + dest[1] = (byte)i2; + dest[0] = (byte)EncodingPad; + } + } - if (BitConverter.IsLittleEndian) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe void StoreVector512ToDestination(byte* dest, byte* destStart, int destLength, Vector512 str) { - return i0 | (i1 << 8) | (EncodingPad << 16) | (EncodingPad << 24); + AssertWrite>(dest, destStart, destLength); + str.Store(dest); } - else + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CompExactlyDependsOn(typeof(Avx2))] + public static unsafe void StoreVector256ToDestination(byte* dest, byte* destStart, int destLength, Vector256 str) { - return (i0 << 24) | (i1 << 16) | (EncodingPad << 8) | EncodingPad; + AssertWrite>(dest, destStart, destLength); + Avx.Store(dest, str.AsByte()); } - } - internal const uint EncodingPad = '='; // '=', for padding + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe void StoreVector128ToDestination(byte* dest, byte* destStart, int destLength, Vector128 str) + { + AssertWrite>(dest, destStart, destLength); + str.Store(dest); + } - private const int MaximumEncodeLength = (int.MaxValue / 4) * 3; // 1610612733 + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CompExactlyDependsOn(typeof(AdvSimd.Arm64))] + public static unsafe void StoreArmVector128x4ToDestination(byte* dest, byte* destStart, int destLength, + Vector128 res1, Vector128 res2, Vector128 res3, Vector128 res4) + { + AssertWrite>(dest, destStart, destLength); + AdvSimd.Arm64.StoreVector128x4AndZip(dest, (res1, res2, res3, res4)); + } - internal static ReadOnlySpan EncodingMap => "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"u8; + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe void EncodeThreeAndWrite(byte* threeBytes, byte* destination, ref byte encodingMap) + { + uint t0 = threeBytes[0]; + uint t1 = threeBytes[1]; + uint t2 = threeBytes[2]; + + uint i = (t0 << 16) | (t1 << 8) | t2; + + byte i0 = Unsafe.Add(ref encodingMap, (IntPtr)(i >> 18)); + byte i1 = Unsafe.Add(ref encodingMap, (IntPtr)((i >> 12) & 0x3F)); + byte i2 = Unsafe.Add(ref encodingMap, (IntPtr)((i >> 6) & 0x3F)); + byte i3 = Unsafe.Add(ref encodingMap, (IntPtr)(i & 0x3F)); + + if (BitConverter.IsLittleEndian) + { + destination[0] = i0; + destination[1] = i1; + destination[2] = i2; + destination[3] = i3; + } + else + { + destination[3] = i0; + destination[2] = i1; + destination[1] = i2; + destination[0] = i3; + } + } + } } } diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Url/Base64UrlDecoder.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Url/Base64UrlDecoder.cs new file mode 100644 index 0000000000000..6e7d97f79426c --- /dev/null +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Url/Base64UrlDecoder.cs @@ -0,0 +1,788 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.Arm; +using System.Runtime.Intrinsics.X86; +using System.Text; +using static System.Buffers.Text.Base64; + +namespace System.Buffers.Text +{ + // AVX2 and Vector128 version based on https://github.com/gfoidl/Base64/blob/5383320e28cac6c7ac6f86502fb05d23a048a21d/source/gfoidl.Base64/Internal/Encodings/Base64UrlEncoding.cs + + public static partial class Base64Url + { + private const int MaxStackallocThreshold = 256; + + /// + /// Returns the maximum length (in bytes) of the result if you were to decode base 64 encoded text from a span of size . + /// + /// The specified is less than 0. + /// + public static int GetMaxDecodedLength(int base64Length) + { + ArgumentOutOfRangeException.ThrowIfNegative(base64Length); + + (uint whole, uint remainder) = uint.DivRem((uint)base64Length, 4); + + return (int)(whole * 3 + (remainder > 0 ? remainder - 1 : 0)); + } + + /// + /// Decodes the span of UTF-8 encoded text represented as Base64Url into binary data. + /// + /// The input span which contains UTF-8 encoded text in Base64Url that needs to be decoded. + /// The output span which contains the result of the operation, i.e. the decoded binary data. + /// When this method returns, contains the number of input bytes consumed during the operation. This can be used to slice the input for subsequent calls, if necessary. This parameter is treated as uninitialized. + /// When this method returns, contains the number of bytes written into the output span. This can be used to slice the output for subsequent calls, if necessary. This parameter is treated as uninitialized. + /// when the input span contains the entirety of data to encode; when more data may follow, + /// such as when calling in a loop. Calls with should be followed up with another call where this parameter is call. The default is . + /// One of the enumeration values that indicates the success or failure of the operation. + /// + /// As padding is optional for Base64Url the length not required to be a multiple of 4 even if is . + /// If the length is not a multiple of 4 and is the remainders decoded accordingly: + /// - Remainder of 3 bytes - decoded into 2 bytes data, decoding succeeds. + /// - Remainder of 2 bytes - decoded into 1 byte data. decoding succeeds. + /// - Remainder of 1 byte - will cause OperationStatus.InvalidData result. + /// + public static OperationStatus DecodeFromUtf8(ReadOnlySpan source, Span destination, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) => + DecodeFrom(source, destination, out bytesConsumed, out bytesWritten, isFinalBlock, ignoreWhiteSpace: true); + + /// + /// Decodes the span of UTF-8 encoded text in Base64Url into binary data, in-place. + /// The decoded binary output is smaller than the text data contained in the input (the operation deflates the data). + /// + /// The input span which contains the base 64 text data that needs to be decoded. + /// The number of bytes written into . This can be used to slice the output for subsequent calls, if necessary. + /// contains an invalid Base64Url character, + /// more than two padding characters, or a non white space character among the padding characters. + /// + /// As padding is optional for Base64Url the length not required to be a multiple of 4. + /// If the length is not a multiple of 4 the remainders decoded accordingly: + /// - Remainder of 3 bytes - decoded into 2 bytes data, decoding succeeds. + /// - Remainder of 2 bytes - decoded into 1 byte data. decoding succeeds. + /// - Remainder of 1 byte - is invalid input, causes FormatException. + /// + public static int DecodeFromUtf8InPlace(Span buffer) + { + OperationStatus status = DecodeFromUtf8InPlace(buffer, out int bytesWritten, ignoreWhiteSpace: true); + + // Base64.DecodeFromUtf8InPlace returns OperationStatus, therefore doesn't throw. + // For Base64Url, this is not an OperationStatus API and thus throws. + if (status == OperationStatus.InvalidData) + { + throw new FormatException(SR.Format_BadBase64Char); + } + + Debug.Assert(status is OperationStatus.Done); + return bytesWritten; + } + + /// + /// Decodes the span of UTF-8 encoded text represented as Base64Url into binary data. + /// + /// The input span which contains UTF-8 encoded text in Base64Url that needs to be decoded. + /// The output span which contains the result of the operation, i.e. the decoded binary data. + /// The number of bytes written into . This can be used to slice the output for subsequent calls, if necessary. + /// The buffer in is too small to hold the encoded output. + /// contains an invalid Base64Url character, + /// more than two padding characters, or a non white space character among the padding characters. + /// + /// As padding is optional for Base64Url the length not required to be a multiple of 4. + /// If the length is not a multiple of 4 the remainders decoded accordingly: + /// - Remainder of 3 bytes - decoded into 2 bytes data, decoding succeeds. + /// - Remainder of 2 bytes - decoded into 1 byte data. decoding succeeds. + /// - Remainder of 1 byte - is invalid input, causes FormatException. + /// + public static int DecodeFromUtf8(ReadOnlySpan source, Span destination) + { + OperationStatus status = DecodeFromUtf8(source, destination, out _, out int bytesWritten); + + if (status == OperationStatus.Done) + { + return bytesWritten; + } + + if (status == OperationStatus.DestinationTooSmall) + { + throw new ArgumentException(SR.Argument_DestinationTooShort, nameof(destination)); + } + + Debug.Assert(status is OperationStatus.InvalidData); + throw new FormatException(SR.Format_BadBase64Char); + } + + /// + /// Decodes the span of UTF-8 encoded text represented as Base64Url into binary data. + /// + /// The input span which contains UTF-8 encoded text in Base64Url that needs to be decoded. + /// The output span which contains the result of the operation, i.e. the decoded binary data. + /// When this method returns, contains the number of bytes written into the output span. This can be used to slice the output for subsequent calls, if necessary. This parameter is treated as uninitialized. + /// if bytes decoded successfully, otherwise . + /// contains an invalid Base64Url character, + /// more than two padding characters, or a non white space character among the padding characters. + public static bool TryDecodeFromUtf8(ReadOnlySpan source, Span destination, out int bytesWritten) + { + OperationStatus status = DecodeFromUtf8(source, destination, out _, out bytesWritten); + + if (status == OperationStatus.InvalidData) + { + throw new FormatException(SR.Format_BadBase64Char); + } + + Debug.Assert(status is OperationStatus.Done or OperationStatus.DestinationTooSmall); + return status == OperationStatus.Done; + } + + /// + /// Decodes the span of UTF-8 encoded text represented as Base64Url into binary data. + /// + /// The input span which contains UTF-8 encoded text in Base64Url that needs to be decoded. + /// >A byte array which contains the result of the decoding operation. + /// contains an invalid Base64Url character, + /// more than two padding characters, or a non white space character among the padding characters. + public static byte[] DecodeFromUtf8(ReadOnlySpan source) + { + int upperBound = GetMaxDecodedLength(source.Length); + byte[]? rented = null; + + Span destination = upperBound <= MaxStackallocThreshold + ? stackalloc byte[MaxStackallocThreshold] + : (rented = ArrayPool.Shared.Rent(upperBound)); + + OperationStatus status = DecodeFromUtf8(source, destination, out _, out int bytesWritten); + Debug.Assert(status is OperationStatus.Done or OperationStatus.InvalidData); + byte[] ret = destination.Slice(0, bytesWritten).ToArray(); + + if (rented is not null) + { + ArrayPool.Shared.Return(rented); + } + + return status == OperationStatus.Done ? ret : throw new FormatException(SR.Format_BadBase64Char); + } + + /// + /// Decodes the span of unicode ASCII chars represented as Base64Url into binary data. + /// + /// The input span which contains unicode ASCII chars in Base64Url that needs to be decoded. + /// The output span which contains the result of the operation, i.e. the decoded binary data. + /// When this method returns, contains the number of input chars consumed during the operation. This can be used to slice the input for subsequent calls, if necessary. This parameter is treated as uninitialized. + /// When this method returns, contains the number of bytes written into the output span. This can be used to slice the output for subsequent calls, if necessary. This parameter is treated as uninitialized. + /// when the input span contains the entirety of data to encode; when more data may follow, + /// such as when calling in a loop. Calls with should be followed up with another call where this parameter is call. The default is . + /// One of the enumeration values that indicates the success or failure of the operation. + /// + /// As padding is optional for Base64Url the length not required to be a multiple of 4 even if is . + /// If the length is not a multiple of 4 and is the remainders decoded accordingly: + /// - Remainder of 3 chars - decoded into 2 bytes data, decoding succeeds. + /// - Remainder of 2 chars - decoded into 1 byte data. decoding succeeds. + /// - Remainder of 1 char - will cause OperationStatus.InvalidData result. + /// + public static OperationStatus DecodeFromChars(ReadOnlySpan source, Span destination, + out int charsConsumed, out int bytesWritten, bool isFinalBlock = true) => + DecodeFrom(MemoryMarshal.Cast(source), destination, out charsConsumed, out bytesWritten, isFinalBlock, ignoreWhiteSpace: true); + + private static OperationStatus DecodeWithWhiteSpaceBlockwise(ReadOnlySpan source, Span bytes, ref int bytesConsumed, ref int bytesWritten, bool isFinalBlock = true) + where TBase64Decoder : IBase64Decoder + { + const int BlockSize = 4; + Span buffer = stackalloc ushort[BlockSize]; + OperationStatus status = OperationStatus.Done; + + while (!source.IsEmpty) + { + int encodedIdx = 0; + int bufferIdx = 0; + int skipped = 0; + + for (; encodedIdx < source.Length && (uint)bufferIdx < (uint)buffer.Length; ++encodedIdx) + { + if (IsWhiteSpace(source[encodedIdx])) + { + skipped++; + } + else + { + buffer[bufferIdx] = source[encodedIdx]; + bufferIdx++; + } + } + + source = source.Slice(encodedIdx); + bytesConsumed += skipped; + + if (bufferIdx == 0) + { + continue; + } + + bool hasAnotherBlock = source.Length >= BlockSize && bufferIdx == BlockSize; + bool localIsFinalBlock = !hasAnotherBlock; + + // If this block contains padding and there's another block, then only whitespace may follow for being valid. + if (hasAnotherBlock) + { + int paddingCount = GetPaddingCount(ref buffer[^1]); + if (paddingCount > 0) + { + hasAnotherBlock = false; + localIsFinalBlock = true; + } + } + + if (localIsFinalBlock && !isFinalBlock) + { + localIsFinalBlock = false; + } + + status = DecodeFrom(buffer.Slice(0, bufferIdx), bytes, out int localConsumed, out int localWritten, localIsFinalBlock, ignoreWhiteSpace: false); + bytesConsumed += localConsumed; + bytesWritten += localWritten; + + if (status != OperationStatus.Done) + { + return status; + } + + // The remaining data must all be whitespace in order to be valid. + if (!hasAnotherBlock) + { + for (int i = 0; i < source.Length; ++i) + { + if (!IsWhiteSpace(source[i])) + { + // Revert previous dest increment, since an invalid state followed. + bytesConsumed -= localConsumed; + bytesWritten -= localWritten; + + return OperationStatus.InvalidData; + } + + bytesConsumed++; + } + + break; + } + + bytes = bytes.Slice(localWritten); + } + + return status; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int GetPaddingCount(ref ushort ptrToLastElement) + where TBase64Decoder : IBase64Decoder + { + int padding = 0; + + if (TBase64Decoder.IsValidPadding(ptrToLastElement)) + { + padding++; + } + + if (TBase64Decoder.IsValidPadding(Unsafe.Subtract(ref ptrToLastElement, 1))) + { + padding++; + } + + return padding; + } + + /// + /// Decodes the span of unicode ASCII chars represented as Base64Url into binary data. + /// + /// The input span which contains ASCII chars in Base64Url that needs to be decoded. + /// The output span which contains the result of the operation, i.e. the decoded binary data. + /// The number of bytes written into the output span. This can be used to slice the output for subsequent calls, if necessary. + /// The buffer in is too small to hold the encoded output. + /// contains a invalid Base64Url character, + /// more than two padding characters, or a non white space character among the padding characters. + public static int DecodeFromChars(ReadOnlySpan source, Span destination) + { + OperationStatus status = DecodeFromChars(source, destination, out _, out int bytesWritten); + + if (status == OperationStatus.Done) + { + return bytesWritten; + } + + if (status == OperationStatus.DestinationTooSmall) + { + throw new ArgumentException(SR.Argument_DestinationTooShort, nameof(destination)); + } + + Debug.Assert(status == OperationStatus.InvalidData); + throw new FormatException(SR.Format_BadBase64Char); + } + + /// + /// Decodes the span of unicode ASCII chars represented as Base64Url into binary data. + /// + /// The input span which contains ASCII chars in Base64Url that needs to be decoded. + /// The output span which contains the result of the operation, i.e. the decoded binary data. + /// When this method returns, contains the number of bytes written into the output span. This can be used to slice the output for subsequent calls, if necessary. This parameter is treated as uninitialized. + /// if bytes decoded successfully, otherwise . + /// contains an invalid Base64Url character, + /// more than two padding characters, or a non white space character among the padding characters. + public static bool TryDecodeFromChars(ReadOnlySpan source, Span destination, out int bytesWritten) + { + OperationStatus status = DecodeFromChars(source, destination, out _, out bytesWritten); + + if (status == OperationStatus.InvalidData) + { + throw new FormatException(SR.Format_BadBase64Char); + } + + return status == OperationStatus.Done; + } + + /// + /// Decodes the span of unicode ASCII chars represented as Base64Url into binary data. + /// + /// The input span which contains ASCII chars in Base64Url that needs to be decoded. + /// A byte array which contains the result of the decoding operation. + /// contains a invalid Base64Url character, + /// more than two padding characters, or a non white space character among the padding characters. + public static byte[] DecodeFromChars(ReadOnlySpan source) + { + int upperBound = GetMaxDecodedLength(source.Length); + byte[]? rented = null; + + Span destination = upperBound <= MaxStackallocThreshold + ? stackalloc byte[MaxStackallocThreshold] + : (rented = ArrayPool.Shared.Rent(upperBound)); + + OperationStatus status = DecodeFromChars(source, destination, out _, out int bytesWritten); + byte[] ret = destination.Slice(0, bytesWritten).ToArray(); + + if (rented is not null) + { + ArrayPool.Shared.Return(rented); + } + + return status == OperationStatus.Done ? ret : throw new FormatException(SR.Format_BadBase64Char); + } + + private readonly struct Base64UrlDecoderByte : IBase64Decoder + { + public static ReadOnlySpan DecodingMap => + [ + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 62, -1, -1, //62 is placed at index 45 (for -), 63 at index 95 (for _) + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, -1, -1, -1, -1, -1, -1, //52-61 are placed at index 48-57 (for 0-9) + -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, -1, -1, -1, -1, 63, //0-25 are placed at index 65-90 (for A-Z) + -1, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, + 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, -1, -1, -1, -1, -1, //26-51 are placed at index 97-122 (for a-z) + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Bytes over 122 ('z') are invalid and cannot be decoded + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Hence, padding the map with 255, which indicates invalid input + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + ]; + + public static ReadOnlySpan VbmiLookup0 => + [ + 0x80808080, 0x80808080, 0x80808080, 0x80808080, + 0x80808080, 0x80808080, 0x80808080, 0x80808080, + 0x80808080, 0x80808080, 0x80808080, 0x80803e80, + 0x37363534, 0x3b3a3938, 0x80803d3c, 0x80808080 + ]; + + public static ReadOnlySpan VbmiLookup1 => + [ + 0x02010080, 0x06050403, 0x0a090807, 0x0e0d0c0b, + 0x1211100f, 0x16151413, 0x80191817, 0x3f808080, + 0x1c1b1a80, 0x201f1e1d, 0x24232221, 0x28272625, + 0x2c2b2a29, 0x302f2e2d, 0x80333231, 0x80808080 + ]; + + public static ReadOnlySpan Avx2LutHigh => + [ + 0x00, 0x00, 0x2d, 0x39, + 0x4f, 0x5a, 0x6f, 0x7a, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x2d, 0x39, + 0x4f, 0x5a, 0x6f, 0x7a, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00 + ]; + + public static ReadOnlySpan Avx2LutLow => + [ + 0x01, 0x01, 0x2d, 0x30, + 0x41, 0x50, 0x61, 0x70, + 0x01, 0x01, 0x01, 0x01, + 0x01, 0x01, 0x01, 0x01, + 0x01, 0x01, 0x2d, 0x30, + 0x41, 0x50, 0x61, 0x70, + 0x01, 0x01, 0x01, 0x01, + 0x01, 0x01, 0x01, 0x01 + ]; + + public static ReadOnlySpan Avx2LutShift => + [ + 0, 0, 17, 4, + -65, -65, -71, -71, + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 17, 4, + -65, -65, -71, -71, + 0, 0, 0, 0, + 0, 0, 0, 0 + ]; + + public static byte MaskSlashOrUnderscore => (byte)'_'; // underscore + + public static ReadOnlySpan Vector128LutHigh => [0x392d0000, 0x7a6f5a4f, 0x00000000, 0x00000000]; + + public static ReadOnlySpan Vector128LutLow => [0x302d0101, 0x70615041, 0x01010101, 0x01010101]; + + public static ReadOnlySpan Vector128LutShift => [0x04110000, 0xb9b9bfbf, 0x00000000, 0x00000000]; + + public static ReadOnlySpan AdvSimdLutOne3 => [0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFF3EFF]; + + public static uint AdvSimdLutTwo3Uint1 => 0x1B1AFF3F; + + public static int GetMaxDecodedLength(int sourceLength) => Base64Url.GetMaxDecodedLength(sourceLength); + + public static bool IsInvalidLength(int bufferLength) => (bufferLength & 3) == 1; // One byte cannot be decoded completely + + public static bool IsValidPadding(uint padChar) => padChar == EncodingPad || padChar == UrlEncodingPad; + + public static int SrcLength(bool isFinalBlock, int sourceLength) => isFinalBlock ? sourceLength : sourceLength & ~0x3; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CompExactlyDependsOn(typeof(AdvSimd.Arm64))] + [CompExactlyDependsOn(typeof(Ssse3))] + public static bool TryDecode128Core( + Vector128 str, + Vector128 hiNibbles, + Vector128 maskSlashOrUnderscore, + Vector128 mask8F, + Vector128 lutLow, + Vector128 lutHigh, + Vector128 lutShift, + Vector128 shiftForUnderscore, + out Vector128 result) + { + Vector128 lowerBound = SimdShuffle(lutLow, hiNibbles, mask8F); + Vector128 upperBound = SimdShuffle(lutHigh, hiNibbles, mask8F); + + Vector128 below = Vector128.LessThan(str, lowerBound); + Vector128 above = Vector128.GreaterThan(str, upperBound); + Vector128 eq5F = Vector128.Equals(str, maskSlashOrUnderscore); + + // Take care as arguments are flipped in order! + Vector128 outside = Vector128.AndNot(below | above, eq5F); + + if (outside != Vector128.Zero) + { + result = default; + return false; + } + + Vector128 shift = SimdShuffle(lutShift.AsByte(), hiNibbles, mask8F); + str += shift; + + result = str + (eq5F & shiftForUnderscore); + return true; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CompExactlyDependsOn(typeof(Avx2))] + public static bool TryDecode256Core( + Vector256 str, + Vector256 hiNibbles, + Vector256 maskSlashOrUnderscore, + Vector256 lutLow, + Vector256 lutHigh, + Vector256 lutShift, + Vector256 shiftForUnderscore, + out Vector256 result) + { + Vector256 lowerBound = Avx2.Shuffle(lutLow, hiNibbles); + Vector256 upperBound = Avx2.Shuffle(lutHigh, hiNibbles); + + Vector256 below = Vector256.LessThan(str, lowerBound); + Vector256 above = Vector256.GreaterThan(str, upperBound); + Vector256 eq5F = Vector256.Equals(str, maskSlashOrUnderscore); + + // Take care as arguments are flipped in order! + Vector256 outside = Vector256.AndNot(below | above, eq5F); + + if (outside != Vector256.Zero) + { + result = default; + return false; + } + + Vector256 shift = Avx2.Shuffle(lutShift, hiNibbles); + str += shift; + + result = str + (eq5F & shiftForUnderscore); + return true; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe int DecodeFourElements(byte* source, ref sbyte decodingMap) => + Base64DecoderByte.DecodeFourElements(source, ref decodingMap); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe int DecodeRemaining(byte* srcEnd, ref sbyte decodingMap, long remaining, out uint t2, out uint t3) + => Base64DecoderByte.DecodeRemaining(srcEnd, ref decodingMap, remaining, out t2, out t3); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int IndexOfAnyExceptWhiteSpace(ReadOnlySpan span) => Base64DecoderByte.IndexOfAnyExceptWhiteSpace(span); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static OperationStatus DecodeWithWhiteSpaceBlockwiseWrapper(ReadOnlySpan utf8, Span bytes, + ref int bytesConsumed, ref int bytesWritten, bool isFinalBlock = true) where TBase64Decoder : IBase64Decoder => + Base64.DecodeWithWhiteSpaceBlockwise(utf8, bytes, ref bytesConsumed, ref bytesWritten, isFinalBlock); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe bool TryLoadVector512(byte* src, byte* srcStart, int sourceLength, out Vector512 str) => + Base64DecoderByte.TryLoadVector512(src, srcStart, sourceLength, out str); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CompExactlyDependsOn(typeof(Avx2))] + public static unsafe bool TryLoadAvxVector256(byte* src, byte* srcStart, int sourceLength, out Vector256 str) => + Base64DecoderByte.TryLoadAvxVector256(src, srcStart, sourceLength, out str); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe bool TryLoadVector128(byte* src, byte* srcStart, int sourceLength, out Vector128 str) => + Base64DecoderByte.TryLoadVector128(src, srcStart, sourceLength, out str); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CompExactlyDependsOn(typeof(AdvSimd.Arm64))] + public static unsafe bool TryLoadArmVector128x4(byte* src, byte* srcStart, int sourceLength, + out Vector128 str1, out Vector128 str2, out Vector128 str3, out Vector128 str4) => + Base64DecoderByte.TryLoadArmVector128x4(src, srcStart, sourceLength, out str1, out str2, out str3, out str4); + } + + private readonly struct Base64UrlDecoderChar : IBase64Decoder + { + public static ReadOnlySpan DecodingMap => Base64UrlDecoderByte.DecodingMap; + + public static ReadOnlySpan VbmiLookup0 => Base64UrlDecoderByte.VbmiLookup0; + + public static ReadOnlySpan VbmiLookup1 => Base64UrlDecoderByte.VbmiLookup1; + + public static ReadOnlySpan Avx2LutHigh => Base64UrlDecoderByte.Avx2LutHigh; + + public static ReadOnlySpan Avx2LutLow => Base64UrlDecoderByte.Avx2LutLow; + + public static ReadOnlySpan Avx2LutShift => Base64UrlDecoderByte.Avx2LutShift; + + public static byte MaskSlashOrUnderscore => Base64UrlDecoderByte.MaskSlashOrUnderscore; + + public static ReadOnlySpan Vector128LutHigh => Base64UrlDecoderByte.Vector128LutHigh; + + public static ReadOnlySpan Vector128LutLow => Base64UrlDecoderByte.Vector128LutLow; + + public static ReadOnlySpan Vector128LutShift => Base64UrlDecoderByte.Vector128LutShift; + + public static ReadOnlySpan AdvSimdLutOne3 => Base64UrlDecoderByte.AdvSimdLutOne3; + + public static uint AdvSimdLutTwo3Uint1 => Base64UrlDecoderByte.AdvSimdLutTwo3Uint1; + + public static int GetMaxDecodedLength(int sourceLength) => Base64UrlDecoderByte.GetMaxDecodedLength(sourceLength); + + public static bool IsInvalidLength(int bufferLength) => Base64DecoderByte.IsInvalidLength(bufferLength); + + public static bool IsValidPadding(uint padChar) => Base64UrlDecoderByte.IsValidPadding(padChar); + + public static int SrcLength(bool isFinalBlock, int sourceLength) => Base64UrlDecoderByte.SrcLength(isFinalBlock, sourceLength); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CompExactlyDependsOn(typeof(AdvSimd.Arm64))] + [CompExactlyDependsOn(typeof(Ssse3))] + public static bool TryDecode128Core(Vector128 str, Vector128 hiNibbles, Vector128 maskSlashOrUnderscore, Vector128 mask8F, + Vector128 lutLow, Vector128 lutHigh, Vector128 lutShift, Vector128 shiftForUnderscore, out Vector128 result) => + Base64UrlDecoderByte.TryDecode128Core(str, hiNibbles, maskSlashOrUnderscore, mask8F, lutLow, lutHigh, lutShift, shiftForUnderscore, out result); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CompExactlyDependsOn(typeof(Avx2))] + public static bool TryDecode256Core(Vector256 str, Vector256 hiNibbles, Vector256 maskSlashOrUnderscore, Vector256 lutLow, + Vector256 lutHigh, Vector256 lutShift, Vector256 shiftForUnderscore, out Vector256 result) => + Base64UrlDecoderByte.TryDecode256Core(str, hiNibbles, maskSlashOrUnderscore, lutLow, lutHigh, lutShift, shiftForUnderscore, out result); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe int DecodeFourElements(ushort* source, ref sbyte decodingMap) + { + // The 'source' span expected to have at least 4 elements, and the 'decodingMap' consists 256 sbytes + uint t0 = source[0]; + uint t1 = source[1]; + uint t2 = source[2]; + uint t3 = source[3]; + + if (((t0 | t1 | t2 | t3) & 0xffffff00) != 0) + { + return -1; // One or more chars falls outside the 00..ff range, invalid Base64Url character. + } + + int i0 = Unsafe.Add(ref decodingMap, t0); + int i1 = Unsafe.Add(ref decodingMap, t1); + int i2 = Unsafe.Add(ref decodingMap, t2); + int i3 = Unsafe.Add(ref decodingMap, t3); + + i0 <<= 18; + i1 <<= 12; + i2 <<= 6; + + i0 |= i3; + i1 |= i2; + + i0 |= i1; + return i0; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe int DecodeRemaining(ushort* srcEnd, ref sbyte decodingMap, long remaining, out uint t2, out uint t3) + { + uint t0; + uint t1; + t2 = EncodingPad; + t3 = EncodingPad; + switch (remaining) + { + case 2: + t0 = srcEnd[-2]; + t1 = srcEnd[-1]; + break; + case 3: + t0 = srcEnd[-3]; + t1 = srcEnd[-2]; + t2 = srcEnd[-1]; + break; + case 4: + t0 = srcEnd[-4]; + t1 = srcEnd[-3]; + t2 = srcEnd[-2]; + t3 = srcEnd[-1]; + break; + default: + return -1; + } + + if (((t0 | t1 | t2 | t3) & 0xffffff00) != 0) + { + return -1; + } + + int i0 = Unsafe.Add(ref decodingMap, (IntPtr)t0); + int i1 = Unsafe.Add(ref decodingMap, (IntPtr)t1); + + i0 <<= 18; + i1 <<= 12; + + i0 |= i1; + return i0; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int IndexOfAnyExceptWhiteSpace(ReadOnlySpan span) + { + for (int i = 0; i < span.Length; i++) + { + if (!IsWhiteSpace(span[i])) + { + return i; + } + } + + return -1; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static OperationStatus DecodeWithWhiteSpaceBlockwiseWrapper(ReadOnlySpan source, Span bytes, + ref int bytesConsumed, ref int bytesWritten, bool isFinalBlock = true) where TBase64Decoder : IBase64Decoder => + DecodeWithWhiteSpaceBlockwise(source, bytes, ref bytesConsumed, ref bytesWritten, isFinalBlock); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe bool TryLoadVector512(ushort* src, ushort* srcStart, int sourceLength, out Vector512 str) + { + AssertRead>(src, srcStart, sourceLength); + Vector512 utf16VectorLower = Vector512.Load(src); + Vector512 utf16VectorUpper = Vector512.Load(src + 32); + + if (Ascii.VectorContainsNonAsciiChar(utf16VectorLower | utf16VectorUpper)) + { + str = default; + return false; + } + + str = Vector512.Narrow(utf16VectorLower, utf16VectorUpper).AsSByte(); + return true; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CompExactlyDependsOn(typeof(Avx2))] + public static unsafe bool TryLoadAvxVector256(ushort* src, ushort* srcStart, int sourceLength, out Vector256 str) + { + AssertRead>(src, srcStart, sourceLength); + Vector256 utf16VectorLower = Avx.LoadVector256(src); + Vector256 utf16VectorUpper = Avx.LoadVector256(src + 16); + + if (Ascii.VectorContainsNonAsciiChar(utf16VectorLower | utf16VectorUpper)) + { + str = default; + return false; + } + + str = Vector256.Narrow(utf16VectorLower, utf16VectorUpper).AsSByte(); + return true; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe bool TryLoadVector128(ushort* src, ushort* srcStart, int sourceLength, out Vector128 str) + { + AssertRead>(src, srcStart, sourceLength); + Vector128 utf16VectorLower = Vector128.LoadUnsafe(ref *src); + Vector128 utf16VectorUpper = Vector128.LoadUnsafe(ref *src, 8); + if (Ascii.VectorContainsNonAsciiChar(utf16VectorLower | utf16VectorUpper)) + { + str = default; + return false; + } + + str = Ascii.ExtractAsciiVector(utf16VectorLower, utf16VectorUpper); + return true; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CompExactlyDependsOn(typeof(AdvSimd.Arm64))] + public static unsafe bool TryLoadArmVector128x4(ushort* src, ushort* srcStart, int sourceLength, + out Vector128 str1, out Vector128 str2, out Vector128 str3, out Vector128 str4) + { + AssertRead> (src, srcStart, sourceLength); + var (s11, s12, s21, s22) = AdvSimd.Arm64.LoadVector128x4AndUnzip(src); + var (s31, s32, s41, s42) = AdvSimd.Arm64.LoadVector128x4AndUnzip(src + 32); + + if (Ascii.VectorContainsNonAsciiChar(s11 | s12 | s21 | s22 | s31 | s32 | s41 | s42)) + { + str1 = str2 = str3 = str4 = default; + return false; + } + + str1 = Ascii.ExtractAsciiVector(s11, s31); + str2 = Ascii.ExtractAsciiVector(s12, s32); + str3 = Ascii.ExtractAsciiVector(s21, s41); + str4 = Ascii.ExtractAsciiVector(s22, s42); + + return true; + } + } + } +} diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Url/Base64UrlEncoder.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Url/Base64UrlEncoder.cs new file mode 100644 index 0000000000000..13d210638fee3 --- /dev/null +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Url/Base64UrlEncoder.cs @@ -0,0 +1,441 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.Arm; +using System.Runtime.Intrinsics.X86; +using static System.Buffers.Text.Base64; + +namespace System.Buffers.Text +{ + public static partial class Base64Url + { + /// + /// Encodes the span of binary data into UTF-8 encoded text represented as Base64Url. + /// + /// The input span which contains binary data that needs to be encoded. + /// The output span which contains the result of the operation, i.e. the UTF-8 encoded text in Base64Url. + /// When this method returns, contains the number of input bytes consumed during the operation. This can be used to slice the input for subsequent calls, if necessary. This parameter is treated as uninitialized. + /// When this method returns, contains the number of bytes written into the output span. This can be used to slice the output for subsequent calls, if necessary. This parameter is treated as uninitialized. + /// when the input span contains the entirety of data to encode; when more data may follow, + /// such as when calling in a loop, subsequent calls with should end with call. The default is . + /// One of the enumeration values that indicates the success or failure of the operation. + /// This implementation of the base64url encoding omits the optional padding characters. + public static unsafe OperationStatus EncodeToUtf8(ReadOnlySpan source, + Span destination, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) => + EncodeTo(source, destination, out bytesConsumed, out bytesWritten, isFinalBlock); + + /// + /// Returns the length (in bytes) of the result if you were to encode binary data within a byte span of size . + /// + /// + /// is less than 0 or greater than 1610612733. + /// + public static int GetEncodedLength(int bytesLength) + { + ArgumentOutOfRangeException.ThrowIfGreaterThan((uint)bytesLength, Base64.MaximumEncodeLength); + + (uint whole, uint remainder) = uint.DivRem((uint)bytesLength, 3); + + return (int)(whole * 4 + (remainder > 0 ? remainder + 1 : 0)); // if remainder is 1 or 2, the encoded length will be 1 byte longer. + } + + /// + /// Encodes the span of binary data into UTF-8 encoded text represented as Base64Url. + /// + /// The input span which contains binary data that needs to be encoded. + /// The output span which contains the result of the operation, i.e. the UTF-8 encoded text in Base64Url. + /// The number of bytes written into the destination span. This can be used to slice the output for subsequent calls, if necessary. + /// The buffer in is too small to hold the encoded output. + /// This implementation of the base64url encoding omits the optional padding characters. + public static int EncodeToUtf8(ReadOnlySpan source, Span destination) + { + OperationStatus status = EncodeToUtf8(source, destination, out _, out int bytesWritten); + + if (status == OperationStatus.Done) + { + return bytesWritten; + } + + Debug.Assert(status == OperationStatus.DestinationTooSmall); + throw new ArgumentException(SR.Argument_DestinationTooShort, nameof(destination)); + } + + /// + /// Encodes the span of binary data into UTF-8 encoded text represented as Base64Url. + /// + /// The input span which contains binary data that needs to be encoded. + /// The output byte array which contains the result of the operation, i.e. the UTF-8 encoded text in Base64Url. + /// This implementation of the base64url encoding omits the optional padding characters. + public static byte[] EncodeToUtf8(ReadOnlySpan source) + { + byte[] destination = new byte[GetEncodedLength(source.Length)]; + EncodeToUtf8(source, destination, out _, out int bytesWritten); + Debug.Assert(destination.Length == bytesWritten); + + return destination; + } + + /// + /// Encodes the span of binary data into unicode ASCII chars represented as Base64Url. + /// + /// The input span which contains binary data that needs to be encoded. + /// The output span which contains the result of the operation, i.e. the ASCII chars in Base64Url. + /// >When this method returns, contains the number of input bytes consumed during the operation. This can be used to slice the input for subsequent calls, if necessary. This parameter is treated as uninitialized. + /// >When this method returns, contains the number of chars written into the output span. This can be used to slice the output for subsequent calls, if necessary. This parameter is treated as uninitialized. + /// when the input span contains the entirety of data to encode; when more data may follow, + /// such as when calling in a loop, subsequent calls with should end with call. The default is . + /// One of the enumeration values that indicates the success or failure of the operation. + /// This implementation of the base64url encoding omits the optional padding characters. + public static OperationStatus EncodeToChars(ReadOnlySpan source, Span destination, + out int bytesConsumed, out int charsWritten, bool isFinalBlock = true) => + EncodeTo(source, MemoryMarshal.Cast(destination), out bytesConsumed, out charsWritten, isFinalBlock); + + /// + /// Encodes the span of binary data into unicode ASCII chars represented as Base64Url. + /// + /// The input span which contains binary data that needs to be encoded. + /// The output span which contains the result of the operation, i.e. the ASCII chars in Base64Url. + /// The number of bytes written into the destination span. This can be used to slice the output for subsequent calls, if necessary. + /// The buffer in is too small to hold the encoded output. + /// This implementation of the base64url encoding omits the optional padding characters. + public static int EncodeToChars(ReadOnlySpan source, Span destination) + { + OperationStatus status = EncodeToChars(source, destination, out _, out int charsWritten); + + if (status == OperationStatus.Done) + { + return charsWritten; + } + + Debug.Assert(status == OperationStatus.DestinationTooSmall); + throw new ArgumentException(SR.Argument_DestinationTooShort, nameof(destination)); + } + + /// + /// Encodes the span of binary data into unicode ASCII chars represented as Base64Url. + /// + /// The input span which contains binary data that needs to be encoded. + /// A char array which contains the result of the operation, i.e. the ASCII chars in Base64Url. + /// This implementation of the base64url encoding omits the optional padding characters. + public static char[] EncodeToChars(ReadOnlySpan source) + { + char[] destination = new char[GetEncodedLength(source.Length)]; + EncodeToChars(source, destination, out _, out int charsWritten); + Debug.Assert(destination.Length == charsWritten); + + return destination; + } + + /// + /// Encodes the span of binary data into unicode string represented as Base64Url ASCII chars. + /// + /// The input span which contains binary data that needs to be encoded. + /// A string which contains the result of the operation, i.e. the ASCII string in Base64Url. + /// This implementation of the base64url encoding omits the optional padding characters. + public static unsafe string EncodeToString(ReadOnlySpan source) + { + int encodedLength = GetEncodedLength(source.Length); + +#pragma warning disable CS8500 // This takes the address of, gets the size of, or declares a pointer to a managed type + return string.Create(encodedLength, (IntPtr)(&source), static (buffer, spanPtr) => + { + ReadOnlySpan source = *(ReadOnlySpan*)spanPtr; + EncodeToChars(source, buffer, out _, out int charsWritten); + Debug.Assert(buffer.Length == charsWritten, $"The source length: {source.Length}, bytes written: {charsWritten}"); + }); +#pragma warning restore CS8500 // This takes the address of, gets the size of, or declares a pointer to a managed type + } + + /// + /// Encodes the span of binary data into unicode ASCII chars represented as Base64Url. + /// + /// The input span which contains binary data that needs to be encoded. + /// The output span which contains the result of the operation, i.e. the ASCII chars in Base64Url. + /// When this method returns, contains the number of chars written into the output span. This can be used to slice the output for subsequent calls, if necessary. This parameter is treated as uninitialized. + /// if chars encoded successfully, otherwise . + /// This implementation of the base64url encoding omits the optional padding characters. + public static bool TryEncodeToChars(ReadOnlySpan source, Span destination, out int charsWritten) + { + OperationStatus status = EncodeToChars(source, destination, out _, out charsWritten); + + return status == OperationStatus.Done; + } + + /// + /// Encodes the span of binary data into UTF-8 encoded chars represented as Base64Url. + /// + /// The input span which contains binary data that needs to be encoded. + /// The output span which contains the result of the operation, i.e. the UTF-8 encoded text in Base64Url. + /// When this method returns, contains the number of chars written into the output span. This can be used to slice the output for subsequent calls, if necessary. This parameter is treated as uninitialized. + /// if bytes encoded successfully, otherwise . + /// This implementation of the base64url encoding omits the optional padding characters. + public static bool TryEncodeToUtf8(ReadOnlySpan source, Span destination, out int bytesWritten) + { + OperationStatus status = EncodeToUtf8(source, destination, out _, out bytesWritten); + + return status == OperationStatus.Done; + } + + /// + /// Encodes the span of binary data (in-place) into UTF-8 encoded text represented as base 64. + /// The encoded text output is larger than the binary data contained in the input (the operation inflates the data). + /// + /// The input span which contains binary data that needs to be encoded. + /// It needs to be large enough to fit the result of the operation. + /// The amount of binary data contained within the buffer that needs to be encoded + /// (and needs to be smaller than the buffer length). + /// When this method returns, contains the number of bytes written into the buffer. This parameter is treated as uninitialized. + /// if bytes encoded successfully, otherwise . + /// This implementation of the base64url encoding omits the optional padding characters. + public static unsafe bool TryEncodeToUtf8InPlace(Span buffer, int dataLength, out int bytesWritten) + { + OperationStatus status = EncodeToUtf8InPlace(buffer, dataLength, out bytesWritten); + + return status == OperationStatus.Done; + } + + private readonly struct Base64UrlEncoderByte : IBase64Encoder + { + public static ReadOnlySpan EncodingMap => "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"u8; + + public static sbyte Avx2LutChar62 => -17; // char '-' diff + + public static sbyte Avx2LutChar63 => 32; // char '_' diff + + public static ReadOnlySpan AdvSimdLut4 => "wxyz0123456789-_"u8; + + public static uint Ssse3AdvSimdLutE3 => 0x000020EF; + + public static int IncrementPadTwo => 2; + + public static int IncrementPadOne => 3; + + public static int GetMaxSrcLength(int srcLength, int destLength) => + srcLength <= MaximumEncodeLength && destLength >= GetEncodedLength(srcLength) ? + srcLength : GetMaxDecodedLength(destLength); + + public static uint GetInPlaceDestinationLength(int encodedLength, int leftOver) => + leftOver > 0 ? (uint)(encodedLength - leftOver - 1) : (uint)(encodedLength - 4); + + public static int GetMaxEncodedLength(int srcLength) => GetEncodedLength(srcLength); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe void EncodeOneOptionallyPadTwo(byte* oneByte, byte* dest, ref byte encodingMap) + { + uint t0 = oneByte[0]; + + uint i = t0 << 8; + + uint i0 = Unsafe.Add(ref encodingMap, (IntPtr)(i >> 10)); + uint i1 = Unsafe.Add(ref encodingMap, (IntPtr)((i >> 4) & 0x3F)); + + if (BitConverter.IsLittleEndian) + { + dest[0] = (byte)i0; + dest[1] = (byte)i1; + } + else + { + dest[1] = (byte)i0; + dest[0] = (byte)i1; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe void EncodeTwoOptionallyPadOne(byte* twoBytes, byte* dest, ref byte encodingMap) + { + uint t0 = twoBytes[0]; + uint t1 = twoBytes[1]; + + uint i = (t0 << 16) | (t1 << 8); + + uint i0 = Unsafe.Add(ref encodingMap, (IntPtr)(i >> 18)); + uint i1 = Unsafe.Add(ref encodingMap, (IntPtr)((i >> 12) & 0x3F)); + uint i2 = Unsafe.Add(ref encodingMap, (IntPtr)((i >> 6) & 0x3F)); + + if (BitConverter.IsLittleEndian) + { + dest[0] = (byte)i0; + dest[1] = (byte)i1; + dest[2] = (byte)i2; + } + else + { + dest[2] = (byte)i0; + dest[1] = (byte)i1; + dest[0] = (byte)i2; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe void StoreVector512ToDestination(byte* dest, byte* destStart, int destLength, Vector512 str) => + Base64EncoderByte.StoreVector512ToDestination(dest, destStart, destLength, str); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CompExactlyDependsOn(typeof(Avx2))] + public static unsafe void StoreVector256ToDestination(byte* dest, byte* destStart, int destLength, Vector256 str) => + Base64EncoderByte.StoreVector256ToDestination(dest, destStart, destLength, str); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe void StoreVector128ToDestination(byte* dest, byte* destStart, int destLength, Vector128 str) => + Base64EncoderByte.StoreVector128ToDestination(dest, destStart, destLength, str); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CompExactlyDependsOn(typeof(AdvSimd.Arm64))] + public static unsafe void StoreArmVector128x4ToDestination(byte* dest, byte* destStart, int destLength, + Vector128 res1, Vector128 res2, Vector128 res3, Vector128 res4) => + Base64EncoderByte.StoreArmVector128x4ToDestination(dest, destStart, destLength, res1, res2, res3, res4); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe void EncodeThreeAndWrite(byte* threeBytes, byte* destination, ref byte encodingMap) => + Base64EncoderByte.EncodeThreeAndWrite(threeBytes, destination, ref encodingMap); + } + + private readonly struct Base64UrlEncoderChar : IBase64Encoder + { + public static ReadOnlySpan EncodingMap => Base64UrlEncoderByte.EncodingMap; + + public static sbyte Avx2LutChar62 => Base64UrlEncoderByte.Avx2LutChar62; + + public static sbyte Avx2LutChar63 => Base64UrlEncoderByte.Avx2LutChar63; + + public static ReadOnlySpan AdvSimdLut4 => Base64UrlEncoderByte.AdvSimdLut4; + + public static uint Ssse3AdvSimdLutE3 => Base64UrlEncoderByte.Ssse3AdvSimdLutE3; + + public static int IncrementPadTwo => Base64UrlEncoderByte.IncrementPadTwo; + + public static int IncrementPadOne => Base64UrlEncoderByte.IncrementPadOne; + + public static int GetMaxSrcLength(int srcLength, int destLength) => + Base64UrlEncoderByte.GetMaxSrcLength(srcLength, destLength); + + public static uint GetInPlaceDestinationLength(int encodedLength, int _) => 0; // not used for char encoding + + public static int GetMaxEncodedLength(int _) => 0; // not used for char encoding + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe void EncodeOneOptionallyPadTwo(byte* oneByte, ushort* dest, ref byte encodingMap) + { + uint t0 = oneByte[0]; + + uint i = t0 << 8; + + uint i0 = Unsafe.Add(ref encodingMap, (IntPtr)(i >> 10)); + uint i1 = Unsafe.Add(ref encodingMap, (IntPtr)((i >> 4) & 0x3F)); + + if (BitConverter.IsLittleEndian) + { + dest[0] = (ushort)i0; + dest[1] = (ushort)i1; + } + else + { + dest[1] = (ushort)i0; + dest[0] = (ushort)i1; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe void EncodeTwoOptionallyPadOne(byte* twoBytes, ushort* dest, ref byte encodingMap) + { + uint t0 = twoBytes[0]; + uint t1 = twoBytes[1]; + + uint i = (t0 << 16) | (t1 << 8); + + uint i0 = Unsafe.Add(ref encodingMap, (IntPtr)(i >> 18)); + uint i1 = Unsafe.Add(ref encodingMap, (IntPtr)((i >> 12) & 0x3F)); + uint i2 = Unsafe.Add(ref encodingMap, (IntPtr)((i >> 6) & 0x3F)); + + if (BitConverter.IsLittleEndian) + { + dest[0] = (ushort)i0; + dest[1] = (ushort)i1; + dest[2] = (ushort)i2; + } + else + { + dest[2] = (ushort)i0; + dest[1] = (ushort)i1; + dest[0] = (ushort)i2; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe void StoreVector512ToDestination(ushort* dest, ushort* destStart, int destLength, Vector512 str) + { + AssertWrite>(dest, destStart, destLength); + (Vector512 utf16LowVector, Vector512 utf16HighVector) = Vector512.Widen(str); + utf16LowVector.Store(dest); + utf16HighVector.Store(dest + Vector512.Count); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe void StoreVector256ToDestination(ushort* dest, ushort* destStart, int destLength, Vector256 str) + { + AssertWrite>(dest, destStart, destLength); + (Vector256 utf16LowVector, Vector256 utf16HighVector) = Vector256.Widen(str); + utf16LowVector.Store(dest); + utf16HighVector.Store(dest + Vector256.Count); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe void StoreVector128ToDestination(ushort* dest, ushort* destStart, int destLength, Vector128 str) + { + AssertWrite>(dest, destStart, destLength); + (Vector128 utf16LowVector, Vector128 utf16HighVector) = Vector128.Widen(str); + utf16LowVector.Store(dest); + utf16HighVector.Store(dest + Vector128.Count); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CompExactlyDependsOn(typeof(AdvSimd.Arm64))] + public static unsafe void StoreArmVector128x4ToDestination(ushort* dest, ushort* destStart, int destLength, + Vector128 res1, Vector128 res2, Vector128 res3, Vector128 res4) + { + AssertWrite>(dest, destStart, destLength); + (Vector128 utf16LowVector1, Vector128 utf16HighVector1) = Vector128.Widen(res1); + (Vector128 utf16LowVector2, Vector128 utf16HighVector2) = Vector128.Widen(res2); + (Vector128 utf16LowVector3, Vector128 utf16HighVector3) = Vector128.Widen(res3); + (Vector128 utf16LowVector4, Vector128 utf16HighVector4) = Vector128.Widen(res4); + AdvSimd.Arm64.StoreVector128x4AndZip(dest, (utf16LowVector1, utf16LowVector2, utf16LowVector3, utf16LowVector4)); + AdvSimd.Arm64.StoreVector128x4AndZip(dest + 32, (utf16HighVector1, utf16HighVector2, utf16HighVector3, utf16HighVector4)); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe void EncodeThreeAndWrite(byte* threeBytes, ushort* destination, ref byte encodingMap) + { + uint t0 = threeBytes[0]; + uint t1 = threeBytes[1]; + uint t2 = threeBytes[2]; + + uint i = (t0 << 16) | (t1 << 8) | t2; + + byte i0 = Unsafe.Add(ref encodingMap, (IntPtr)(i >> 18)); + byte i1 = Unsafe.Add(ref encodingMap, (IntPtr)((i >> 12) & 0x3F)); + byte i2 = Unsafe.Add(ref encodingMap, (IntPtr)((i >> 6) & 0x3F)); + byte i3 = Unsafe.Add(ref encodingMap, (IntPtr)(i & 0x3F)); + + if (BitConverter.IsLittleEndian) + { + destination[0] = i0; + destination[1] = i1; + destination[2] = i2; + destination[3] = i3; + } + else + { + destination[3] = i0; + destination[2] = i1; + destination[1] = i2; + destination[0] = i3; + } + } + } + } +} diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Url/Base64UrlValidator.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Url/Base64UrlValidator.cs new file mode 100644 index 0000000000000..20eb0e926d62a --- /dev/null +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Url/Base64UrlValidator.cs @@ -0,0 +1,91 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Runtime.CompilerServices; + +namespace System.Buffers.Text +{ + public static partial class Base64Url + { + /// Validates that the specified span of text is comprised of valid base-64 encoded data. + /// A span of text to validate. + /// if contains a valid, decodable sequence of base-64 encoded data; otherwise, . + /// + /// If the method returns , the same text passed to and + /// would successfully decode (in the case + /// of assuming sufficient output space). + /// Any amount of whitespace is allowed anywhere in the input, where whitespace is defined as the characters ' ', '\t', '\r', or '\n'. + /// + public static bool IsValid(ReadOnlySpan base64UrlText) => + Base64.IsValid(base64UrlText, out _); + + /// Validates that the specified span of text is comprised of valid base-64 encoded data. + /// A span of text to validate. + /// If the method returns true, the number of decoded bytes that will result from decoding the input text. + /// if contains a valid, decodable sequence of base-64 encoded data; otherwise, . + /// + /// If the method returns , the same text passed to and + /// would successfully decode (in the case + /// of assuming sufficient output space). + /// Any amount of whitespace is allowed anywhere in the input, where whitespace is defined as the characters ' ', '\t', '\r', or '\n'. + /// + public static bool IsValid(ReadOnlySpan base64UrlText, out int decodedLength) => + Base64.IsValid(base64UrlText, out decodedLength); + + /// Validates that the specified span of UTF-8 text is comprised of valid base-64 encoded data. + /// A span of UTF-8 text to validate. + /// if contains a valid, decodable sequence of base-64 encoded data; otherwise, . + /// + /// where whitespace is defined as the characters ' ', '\t', '\r', or '\n' (as bytes). + /// + public static bool IsValid(ReadOnlySpan utf8Base64UrlText) => + Base64.IsValid(utf8Base64UrlText, out _); + + /// Validates that the specified span of UTF-8 text is comprised of valid base-64 encoded data. + /// A span of UTF-8 text to validate. + /// If the method returns true, the number of decoded bytes that will result from decoding the input UTF-8 text. + /// if contains a valid, decodable sequence of base-64 encoded data; otherwise, . + /// + /// where whitespace is defined as the characters ' ', '\t', '\r', or '\n' (as bytes). + /// + public static bool IsValid(ReadOnlySpan utf8Base64UrlText, out int decodedLength) => + Base64.IsValid(utf8Base64UrlText, out decodedLength); + + private const uint UrlEncodingPad = '%'; // allowed for url padding + + private readonly struct Base64UrlCharValidatable : Base64.IBase64Validatable + { + private static readonly SearchValues s_validBase64UrlChars = SearchValues.Create("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"); + + public static int IndexOfAnyExcept(ReadOnlySpan span) => span.IndexOfAnyExcept(s_validBase64UrlChars); + public static bool IsWhiteSpace(char value) => Base64.IsWhiteSpace(value); + public static bool IsEncodingPad(char value) => value == Base64.EncodingPad || value == UrlEncodingPad; + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool ValidateAndDecodeLength(int length, int paddingCount, out int decodedLength) => + Base64UrlByteValidatable.ValidateAndDecodeLength(length, paddingCount, out decodedLength); + } + + private readonly struct Base64UrlByteValidatable : Base64.IBase64Validatable + { + private static readonly SearchValues s_validBase64UrlChars = SearchValues.Create(Base64UrlEncoderByte.EncodingMap); + + public static int IndexOfAnyExcept(ReadOnlySpan span) => span.IndexOfAnyExcept(s_validBase64UrlChars); + public static bool IsWhiteSpace(byte value) => Base64.IsWhiteSpace(value); + public static bool IsEncodingPad(byte value) => value == Base64.EncodingPad || value == UrlEncodingPad; + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool ValidateAndDecodeLength(int length, int paddingCount, out int decodedLength) + { + // Padding is optional for Base64Url, so need to account remainder. If remainder is 1, then it's invalid. + (uint whole, uint remainder) = uint.DivRem((uint)(length), 4); + if (remainder == 1 || (remainder > 1 && (remainder - paddingCount == 1 || paddingCount == remainder))) + { + decodedLength = 0; + return false; + } + + decodedLength = (int)((whole * 3) + (remainder > 0 ? remainder - 1 : 0) - paddingCount); + return true; + } + } + } +} diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs index 22071725a2352..3012bf6f4a0a5 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Runtime.CompilerServices; + namespace System.Buffers.Text { public static partial class Base64 @@ -53,7 +55,7 @@ public static bool IsValid(ReadOnlySpan base64TextUtf8) => public static bool IsValid(ReadOnlySpan base64TextUtf8, out int decodedLength) => IsValid(base64TextUtf8, out decodedLength); - private static bool IsValid(ReadOnlySpan base64Text, out int decodedLength) + internal static bool IsValid(ReadOnlySpan base64Text, out int decodedLength) where TBase64Validatable : IBase64Validatable { int length = 0, paddingCount = 0; @@ -116,14 +118,15 @@ private static bool IsValid(ReadOnlySpan base64Text, o break; } - if (length % 4 != 0) + if (!TBase64Validatable.ValidateAndDecodeLength(length, paddingCount, out decodedLength)) { goto Fail; } + + return true; } - // Remove padding to get exact length. - decodedLength = (int)((uint)length / 4 * 3) - paddingCount; + decodedLength = 0; return true; Fail: @@ -131,11 +134,12 @@ private static bool IsValid(ReadOnlySpan base64Text, o return false; } - private interface IBase64Validatable + internal interface IBase64Validatable { static abstract int IndexOfAnyExcept(ReadOnlySpan span); static abstract bool IsWhiteSpace(T value); static abstract bool IsEncodingPad(T value); + static abstract bool ValidateAndDecodeLength(int length, int paddingCount, out int decodedLength); } private readonly struct Base64CharValidatable : IBase64Validatable @@ -145,15 +149,30 @@ private interface IBase64Validatable public static int IndexOfAnyExcept(ReadOnlySpan span) => span.IndexOfAnyExcept(s_validBase64Chars); public static bool IsWhiteSpace(char value) => Base64.IsWhiteSpace(value); public static bool IsEncodingPad(char value) => value == EncodingPad; + public static bool ValidateAndDecodeLength(int length, int paddingCount, out int decodedLength) => + Base64ByteValidatable.ValidateAndDecodeLength(length, paddingCount, out decodedLength); } private readonly struct Base64ByteValidatable : IBase64Validatable { - private static readonly SearchValues s_validBase64Chars = SearchValues.Create(EncodingMap); + private static readonly SearchValues s_validBase64Chars = SearchValues.Create(Base64EncoderByte.EncodingMap); public static int IndexOfAnyExcept(ReadOnlySpan span) => span.IndexOfAnyExcept(s_validBase64Chars); public static bool IsWhiteSpace(byte value) => Base64.IsWhiteSpace(value); public static bool IsEncodingPad(byte value) => value == EncodingPad; + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool ValidateAndDecodeLength(int length, int paddingCount, out int decodedLength) + { + if (length % 4 == 0) + { + // Remove padding to get exact length. + decodedLength = (int)((uint)length / 4 * 3) - paddingCount; + return true; + } + + decodedLength = 0; + return false; + } } } } diff --git a/src/libraries/System.Private.CoreLib/src/System/Text/Ascii.Utility.cs b/src/libraries/System.Private.CoreLib/src/System/Text/Ascii.Utility.cs index b879feb073784..aaeec394495fd 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Text/Ascii.Utility.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Text/Ascii.Utility.cs @@ -1518,7 +1518,7 @@ private static bool VectorContainsNonAsciiChar(Vector128 asciiVector) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static bool VectorContainsNonAsciiChar(Vector128 utf16Vector) + internal static bool VectorContainsNonAsciiChar(Vector128 utf16Vector) { // prefer architecture specific intrinsic as they offer better perf if (Sse2.IsSupported) @@ -1555,7 +1555,7 @@ private static bool VectorContainsNonAsciiChar(Vector128 utf16Vector) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static bool VectorContainsNonAsciiChar(Vector256 utf16Vector) + internal static bool VectorContainsNonAsciiChar(Vector256 utf16Vector) { if (Avx.IsSupported) { @@ -1572,7 +1572,7 @@ private static bool VectorContainsNonAsciiChar(Vector256 utf16Vector) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static bool VectorContainsNonAsciiChar(Vector512 utf16Vector) + internal static bool VectorContainsNonAsciiChar(Vector512 utf16Vector) { const ushort asciiMask = ushort.MaxValue - 127; // 0xFF80 Vector512 zeroIsAscii = utf16Vector & Vector512.Create(asciiMask); diff --git a/src/libraries/System.Runtime/ref/System.Runtime.cs b/src/libraries/System.Runtime/ref/System.Runtime.cs index 9562467301a4a..2b03c0b011ff9 100644 --- a/src/libraries/System.Runtime/ref/System.Runtime.cs +++ b/src/libraries/System.Runtime/ref/System.Runtime.cs @@ -7849,6 +7849,34 @@ public static partial class Base64 public static bool IsValid(System.ReadOnlySpan base64Text) { throw null; } public static bool IsValid(System.ReadOnlySpan base64Text, out int decodedLength) { throw null; } } + public static class Base64Url + { + public static byte[] DecodeFromChars(System.ReadOnlySpan source) { throw null; } + public static int DecodeFromChars(System.ReadOnlySpan source, System.Span destination) { throw null; } + public static System.Buffers.OperationStatus DecodeFromChars(System.ReadOnlySpan source, System.Span destination, out int charsConsumed, out int bytesWritten, bool isFinalBlock = true) { throw null; } + public static byte[] DecodeFromUtf8(System.ReadOnlySpan source) { throw null; } + public static int DecodeFromUtf8(System.ReadOnlySpan source, System.Span destination) { throw null; } + public static System.Buffers.OperationStatus DecodeFromUtf8(System.ReadOnlySpan source, System.Span destination, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) { throw null; } + public static int DecodeFromUtf8InPlace(System.Span buffer) { throw null; } + public static char[] EncodeToChars(System.ReadOnlySpan source) { throw null; } + public static int EncodeToChars(System.ReadOnlySpan source, System.Span destination) { throw null; } + public static System.Buffers.OperationStatus EncodeToChars(System.ReadOnlySpan source, System.Span destination, out int bytesConsumed, out int charsWritten, bool isFinalBlock = true) { throw null; } + public static string EncodeToString(System.ReadOnlySpan source) { throw null; } + public static byte[] EncodeToUtf8(System.ReadOnlySpan source) { throw null; } + public static int EncodeToUtf8(System.ReadOnlySpan source, System.Span destination) { throw null; } + public static System.Buffers.OperationStatus EncodeToUtf8(System.ReadOnlySpan source, System.Span destination, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) { throw null; } + public static int GetEncodedLength(int bytesLength) { throw null; } + public static int GetMaxDecodedLength(int base64Length) { throw null; } + public static bool IsValid(System.ReadOnlySpan base64UrlText) { throw null; } + public static bool IsValid(System.ReadOnlySpan base64UrlText, out int decodedLength) { throw null; } + public static bool IsValid(System.ReadOnlySpan utf8Base64UrlText) { throw null; } + public static bool IsValid(System.ReadOnlySpan utf8Base64UrlText, out int decodedLength) { throw null; } + public static bool TryDecodeFromChars(System.ReadOnlySpan source, System.Span destination, out int bytesWritten) { throw null; } + public static bool TryDecodeFromUtf8(System.ReadOnlySpan source, System.Span destination, out int bytesWritten) { throw null; } + public static bool TryEncodeToChars(System.ReadOnlySpan source, System.Span destination, out int charsWritten) { throw null; } + public static bool TryEncodeToUtf8(System.ReadOnlySpan source, System.Span destination, out int bytesWritten) { throw null; } + public static bool TryEncodeToUtf8InPlace(System.Span buffer, int dataLength, out int bytesWritten) { throw null; } + } } namespace System.CodeDom.Compiler {