From c09edb61f1fa0f65165a8b52d6424a8190d1d0af Mon Sep 17 00:00:00 2001 From: Alan Hayward Date: Fri, 20 May 2022 17:15:25 +0100 Subject: [PATCH] Implement System.Buffers.Text.Base64.DecodeFromUtf8 for Arm64 Like the AVX2 and SSE3 versions, this is based off the Aklomp base64 algorithm. The AdvSimd API does not yet have support for squential multi register instructions, such as TBL4/LD4/ST3. This code implements the those instructions using single register instructions. Once API support is added, this code can be greatly simplified and get an additional performance boost. --- .../src/System/Buffers/Text/Base64Decoder.cs | 138 ++++++++++++++++++ 1 file changed, 138 insertions(+) diff --git a/src/libraries/System.Memory/src/System/Buffers/Text/Base64Decoder.cs b/src/libraries/System.Memory/src/System/Buffers/Text/Base64Decoder.cs index 36a7afa536956..ee67a360ca7cf 100644 --- a/src/libraries/System.Memory/src/System/Buffers/Text/Base64Decoder.cs +++ b/src/libraries/System.Memory/src/System/Buffers/Text/Base64Decoder.cs @@ -5,12 +5,14 @@ using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.Arm; using System.Runtime.Intrinsics.X86; namespace System.Buffers.Text { // AVX2 version based on https://github.com/aklomp/base64/tree/e516d769a2a432c08404f1981e73b431566057be/lib/arch/avx2 // SSSE3 version based on https://github.com/aklomp/base64/tree/e516d769a2a432c08404f1981e73b431566057be/lib/arch/ssse3 + // AdvSimd version based on https://github.com/aklomp/base64/blob/e516d769a2a432c08404f1981e73b431566057be/lib/arch/neon64 public static partial class Base64 { @@ -81,6 +83,15 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa if (src == srcEnd) goto DoneExit; } + + end = srcMax - 96; + if (BitConverter.IsLittleEndian && AdvSimd.Arm64.IsSupported && (end >= src)) + { + AdvSimdDecode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); + + if (src == srcEnd) + goto DoneExit; + } } // Last bytes could have padding characters, so process them separately and treat them as valid only if isFinalBlock is true @@ -644,6 +655,133 @@ private static unsafe void Ssse3Decode(ref byte* srcBytes, ref byte* destBytes, destBytes = dest; } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe Vector128 AdvSimdTbx8Byte(Vector128 defaults, Vector128 table0, Vector128 table1, Vector128 table2, Vector128 table3, Vector128 table4, Vector128 table5, Vector128 table6, Vector128 table7, Vector128 indicies, Vector128 offset) + { + // Implement an 8 way table lookup. + // This could be reduced by using two NEON TBX4 instructions. + + Debug.Assert(AdvSimd.Arm64.IsSupported && BitConverter.IsLittleEndian); + + Vector128 dest = defaults; + Vector128 indicies_sub = indicies; + + dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table0, indicies_sub); + indicies_sub = AdvSimd.Subtract(indicies_sub, offset); + dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table1, indicies_sub); + indicies_sub = AdvSimd.Subtract(indicies_sub, offset); + dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table2, indicies_sub); + indicies_sub = AdvSimd.Subtract(indicies_sub, offset); + dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table3, indicies_sub); + indicies_sub = AdvSimd.Subtract(indicies_sub, offset); + dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table4, indicies_sub); + indicies_sub = AdvSimd.Subtract(indicies_sub, offset); + dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table5, indicies_sub); + indicies_sub = AdvSimd.Subtract(indicies_sub, offset); + dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table6, indicies_sub); + indicies_sub = AdvSimd.Subtract(indicies_sub, offset); + dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table7, indicies_sub); + + return dest; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe Vector128 AdvSimdTbx3Byte(Vector128 defaults, Vector128 table0, Vector128 table1, Vector128 table2, Vector128 indicies, Vector128 offset) + { + // Implement a 3 way table lookup. + + Debug.Assert(AdvSimd.Arm64.IsSupported && BitConverter.IsLittleEndian); + + Vector128 dest = defaults; + Vector128 indicies_sub = indicies; + + dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table0, indicies_sub); + indicies_sub = AdvSimd.Subtract(indicies_sub, offset); + dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table1, indicies_sub); + indicies_sub = AdvSimd.Subtract(indicies_sub, offset); + dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table2, indicies_sub); + + return dest; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe void AdvSimdDecode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) + { + Debug.Assert(AdvSimd.Arm64.IsSupported && BitConverter.IsLittleEndian); + + // Complete lookup table - similar to that used in the SS3 decode. + Vector128 dec_lut0 = Vector128.Create(255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255); + Vector128 dec_lut1 = Vector128.Create(255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255); + Vector128 dec_lut2 = Vector128.Create(255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62, 255, 255, 255, 63); + Vector128 dec_lut3 = Vector128.Create( 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 255, 255, 255, 255, 255, 255); + Vector128 dec_lut4 = Vector128.Create(255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14); + Vector128 dec_lut5 = Vector128.Create( 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 255, 255, 255, 255, 255); + Vector128 dec_lut6 = Vector128.Create(255, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40); + Vector128 dec_lut7 = Vector128.Create( 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 255, 255, 255, 255, 255); + + // Interleave pattern for the ST3. + Vector128 st3_interleave_index0 = Vector128.Create((byte) 0, 16, 32, 1, 17, 33, 2, 18, 34, 3, 19, 35, 4, 20, 36, 5); + Vector128 st3_interleave_index1 = Vector128.Create((byte)21, 37, 6, 22, 38, 7, 23, 39, 8, 24, 40, 9, 25, 41, 10, 26); + Vector128 st3_interleave_index2 = Vector128.Create((byte)42, 11, 27, 43, 12, 28, 44, 13, 29, 45, 14, 30, 46, 15, 31, 47); + + // Some constants. + Vector128 vzero = Vector128.Create((byte)0); + Vector128 v255 = Vector128.Create((byte)255U); + Vector128 v16 = Vector128.Create((byte)16U); + + byte* src = srcBytes; + byte* dest = destBytes; + + do + { + // Load 64 bytes of data and deinterleave the result. + // This is equivalent to a NEON LD4 instruction. + Vector128 str0 = Vector128.LoadUnsafe(ref *src); + Vector128 str1 = Vector128.LoadUnsafe(ref *src, 16); + Vector128 str2 = Vector128.LoadUnsafe(ref *src, 32); + Vector128 str3 = Vector128.LoadUnsafe(ref *src, 48); + Vector128 tmp0 = AdvSimd.Arm64.UnzipEven(str0.AsInt16(), str1.AsInt16()); + Vector128 tmp1 = AdvSimd.Arm64.UnzipOdd(str0.AsInt16(), str1.AsInt16()); + Vector128 tmp2 = AdvSimd.Arm64.UnzipEven(str2.AsInt16(), str3.AsInt16()); + Vector128 tmp3 = AdvSimd.Arm64.UnzipOdd(str2.AsInt16(), str3.AsInt16()); + str0 = AdvSimd.Arm64.UnzipEven(tmp0.AsByte(), tmp2.AsByte()); + str1 = AdvSimd.Arm64.UnzipOdd(tmp0.AsByte(), tmp2.AsByte()); + str2 = AdvSimd.Arm64.UnzipEven(tmp1.AsByte(), tmp3.AsByte()); + str3 = AdvSimd.Arm64.UnzipOdd(tmp1.AsByte(), tmp3.AsByte()); + + // Table lookup on each 16 bytes. + str0 = AdvSimdTbx8Byte(v255, dec_lut0, dec_lut1, dec_lut2, dec_lut3, dec_lut4, dec_lut5, dec_lut6, dec_lut7, str0, v16); + str1 = AdvSimdTbx8Byte(v255, dec_lut0, dec_lut1, dec_lut2, dec_lut3, dec_lut4, dec_lut5, dec_lut6, dec_lut7, str1, v16); + str2 = AdvSimdTbx8Byte(v255, dec_lut0, dec_lut1, dec_lut2, dec_lut3, dec_lut4, dec_lut5, dec_lut6, dec_lut7, str2, v16); + str3 = AdvSimdTbx8Byte(v255, dec_lut0, dec_lut1, dec_lut2, dec_lut3, dec_lut4, dec_lut5, dec_lut6, dec_lut7, str3, v16); + + // Check for invalid input, any value larger than 63. + Vector128 classified0 = AdvSimd.Arm64.MaxPairwise(str0, str1); + Vector128 classified1 = AdvSimd.Arm64.MaxPairwise(str2, str3); + Vector128 maxChars = AdvSimd.Arm64.MaxPairwise(classified0, classified1); + if ((maxChars.AsUInt64().ToScalar() & 0xc0c0c0c0c0c0c0c0) != 0) + break; + + // Compress each four bytes into three. + Vector128 dec0 = Vector128.BitwiseOr(Vector128.ShiftLeft(str0, 2), Vector128.ShiftRightLogical(str1, 4)); + Vector128 dec1 = Vector128.BitwiseOr(Vector128.ShiftLeft(str1, 4), Vector128.ShiftRightLogical(str2, 2)); + Vector128 dec2 = Vector128.BitwiseOr(Vector128.ShiftLeft(str2, 6), str3); + + // Interleave the decoded result and store out. + // This is equivalent to a NEON ST3 instruction. + AdvSimdTbx3Byte(vzero, dec0, dec1, dec2, st3_interleave_index0, v16).Store(dest); + AdvSimdTbx3Byte(vzero, dec0, dec1, dec2, st3_interleave_index1, v16).Store(dest + 16); + AdvSimdTbx3Byte(vzero, dec0, dec1, dec2, st3_interleave_index2, v16).Store(dest + 32); + + src += 64; + dest += 48; + } + while (src <= srcEnd); + + srcBytes = src; + destBytes = dest; + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] private static unsafe int Decode(byte* encodedBytes, ref sbyte decodingMap) {