From 46e2597cbb9cbf68cf935ded056597afb492111e Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Thu, 20 Apr 2023 02:14:19 +0200 Subject: [PATCH] Add a dedicated Ascii.IsValid path (#84881) * Add dedicated Ascii.IsValid path * Remove wrong assert * Use for loop instead of chasing end byref * Add an extra assert * Use nuint in loop for [0, 3] elements * Remove leftover byref increment --- .../src/System/Text/Ascii.CaseConversion.cs | 28 +-- .../src/System/Text/Ascii.Utility.cs | 57 +++++ .../src/System/Text/Ascii.cs | 224 +++++++++++++++--- 3 files changed, 251 insertions(+), 58 deletions(-) diff --git a/src/libraries/System.Private.CoreLib/src/System/Text/Ascii.CaseConversion.cs b/src/libraries/System.Private.CoreLib/src/System/Text/Ascii.CaseConversion.cs index 9fa47f66fbded..c226161ec5749 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Text/Ascii.CaseConversion.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Text/Ascii.CaseConversion.cs @@ -238,7 +238,7 @@ private static unsafe nuint ChangeCase(TFrom* pSrc, TTo* pD // Unaligned read and check for non-ASCII data. Vector128 srcVector = Vector128.LoadUnsafe(ref *pSrc); - if (VectorContainsAnyNonAsciiData(srcVector)) + if (VectorContainsNonAsciiChar(srcVector)) { goto Drain64; } @@ -291,7 +291,7 @@ private static unsafe nuint ChangeCase(TFrom* pSrc, TTo* pD // Unaligned read & check for non-ASCII data. srcVector = Vector128.LoadUnsafe(ref *pSrc, i); - if (VectorContainsAnyNonAsciiData(srcVector)) + if (VectorContainsNonAsciiChar(srcVector)) { goto Drain64; } @@ -463,30 +463,6 @@ private static unsafe nuint ChangeCase(TFrom* pSrc, TTo* pD return i; } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe bool VectorContainsAnyNonAsciiData(Vector128 vector) - where T : unmanaged - { - if (sizeof(T) == 1) - { - if (vector.ExtractMostSignificantBits() != 0) { return true; } - } - else if (sizeof(T) == 2) - { - if (VectorContainsNonAsciiChar(vector.AsUInt16())) - { - return true; - } - } - else - { - Debug.Fail("Unknown types provided."); - throw new NotSupportedException(); - } - - return false; - } - [MethodImpl(MethodImplOptions.AggressiveInlining)] private static unsafe void Widen8To16AndAndWriteTo(Vector128 narrowVector, char* pDest, nuint destOffset) { diff --git a/src/libraries/System.Private.CoreLib/src/System/Text/Ascii.Utility.cs b/src/libraries/System.Private.CoreLib/src/System/Text/Ascii.Utility.cs index e8413c3ffe3af..4acf6e82baa6a 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Text/Ascii.Utility.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Text/Ascii.Utility.cs @@ -41,6 +41,17 @@ private static bool AllCharsInUInt64AreAscii(ulong value) return (value & ~0x007F007F_007F007Ful) == 0; } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool AllCharsInUInt64AreAscii(ulong value) + where T : unmanaged + { + Debug.Assert(typeof(T) == typeof(byte) || typeof(T) == typeof(ushort)); + + return typeof(T) == typeof(byte) + ? AllBytesInUInt64AreAscii(value) + : AllCharsInUInt64AreAscii(value); + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] private static int GetIndexOfFirstNonAsciiByteInLane_AdvSimd(Vector128 value, Vector128 bitmask) { @@ -1432,6 +1443,52 @@ private static bool VectorContainsNonAsciiChar(Vector128 utf16Vector) } } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool VectorContainsNonAsciiChar(Vector128 vector) + where T : unmanaged + { + Debug.Assert(typeof(T) == typeof(byte) || typeof(T) == typeof(ushort)); + + return typeof(T) == typeof(byte) + ? VectorContainsNonAsciiChar(vector.AsByte()) + : VectorContainsNonAsciiChar(vector.AsUInt16()); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool AllCharsInVectorAreAscii(Vector128 vector) + where T : unmanaged + { + Debug.Assert(typeof(T) == typeof(byte) || typeof(T) == typeof(ushort)); + + // This is a copy of VectorContainsNonAsciiChar with an inverted condition. + if (typeof(T) == typeof(byte)) + { + return + Sse41.IsSupported ? Sse41.TestZ(vector.AsByte(), Vector128.Create((byte)0x80)) : + AdvSimd.Arm64.IsSupported ? AllBytesInUInt64AreAscii(AdvSimd.Arm64.MaxPairwise(vector.AsByte(), vector.AsByte()).AsUInt64().ToScalar()) : + vector.AsByte().ExtractMostSignificantBits() == 0; + } + else + { + return + Sse41.IsSupported ? Sse41.TestZ(vector.AsInt16(), Vector128.Create((short)-128)) : + AdvSimd.Arm64.IsSupported ? AllCharsInUInt64AreAscii(AdvSimd.Arm64.MaxPairwise(vector.AsUInt16(), vector.AsUInt16()).AsUInt64().ToScalar()) : + (vector.AsUInt16() & Vector128.Create((ushort)(ushort.MaxValue - 127))) == Vector128.Zero; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool AllCharsInVectorAreAscii(Vector256 vector) + where T : unmanaged + { + Debug.Assert(Avx.IsSupported); + Debug.Assert(typeof(T) == typeof(byte) || typeof(T) == typeof(ushort)); + + return typeof(T) == typeof(byte) + ? Avx.TestZ(vector.AsByte(), Vector256.Create((byte)0x80)) + : Avx.TestZ(vector.AsInt16(), Vector256.Create((short)-128)); + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] private static Vector128 ExtractAsciiVector(Vector128 vectorFirst, Vector128 vectorSecond) { diff --git a/src/libraries/System.Private.CoreLib/src/System/Text/Ascii.cs b/src/libraries/System.Private.CoreLib/src/System/Text/Ascii.cs index 8934bd91ada59..0b9af8cf0c6ae 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Text/Ascii.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Text/Ascii.cs @@ -2,7 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; +using System.Runtime.Intrinsics; namespace System.Text { @@ -14,21 +16,9 @@ public static partial class Ascii /// The value to inspect. /// True if contains only ASCII bytes or is /// empty; False otherwise. - public static unsafe bool IsValid(ReadOnlySpan value) - { - if (value.IsEmpty) - { - return true; - } - - nuint bufferLength = (uint)value.Length; - fixed (byte* pBuffer = &MemoryMarshal.GetReference(value)) - { - nuint idxOfFirstNonAsciiElement = GetIndexOfFirstNonAsciiByte(pBuffer, bufferLength); - Debug.Assert(idxOfFirstNonAsciiElement <= bufferLength); - return idxOfFirstNonAsciiElement == bufferLength; - } - } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool IsValid(ReadOnlySpan value) => + IsValidCore(ref MemoryMarshal.GetReference(value), value.Length); /// /// Determines whether the provided value contains only ASCII chars. @@ -36,34 +26,204 @@ public static unsafe bool IsValid(ReadOnlySpan value) /// The value to inspect. /// True if contains only ASCII chars or is /// empty; False otherwise. - public static unsafe bool IsValid(ReadOnlySpan value) - { - if (value.IsEmpty) - { - return true; - } - - nuint bufferLength = (uint)value.Length; - fixed (char* pBuffer = &MemoryMarshal.GetReference(value)) - { - nuint idxOfFirstNonAsciiElement = GetIndexOfFirstNonAsciiChar(pBuffer, bufferLength); - Debug.Assert(idxOfFirstNonAsciiElement <= bufferLength); - return idxOfFirstNonAsciiElement == bufferLength; - } - } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool IsValid(ReadOnlySpan value) => + IsValidCore(ref Unsafe.As(ref MemoryMarshal.GetReference(value)), value.Length); /// /// Determines whether the provided value is ASCII byte. /// /// The value to inspect. /// True if is ASCII, False otherwise. - public static unsafe bool IsValid(byte value) => value <= 127; + public static bool IsValid(byte value) => value <= 127; /// /// Determines whether the provided value is ASCII char. /// /// The value to inspect. /// True if is ASCII, False otherwise. - public static unsafe bool IsValid(char value) => value <= 127; + public static bool IsValid(char value) => value <= 127; + + private static unsafe bool IsValidCore(ref T searchSpace, int length) where T : unmanaged + { + Debug.Assert(typeof(T) == typeof(byte) || typeof(T) == typeof(ushort)); + + if (!Vector128.IsHardwareAccelerated || length < Vector128.Count) + { + uint elementsPerUlong = (uint)(sizeof(ulong) / sizeof(T)); + + if (length < elementsPerUlong) + { + if (typeof(T) == typeof(byte) && length >= sizeof(uint)) + { + // Process byte inputs with lengths [4, 7] + return AllBytesInUInt32AreAscii( + Unsafe.ReadUnaligned(ref Unsafe.As(ref searchSpace)) | + Unsafe.ReadUnaligned(ref Unsafe.As(ref Unsafe.Add(ref searchSpace, length - sizeof(uint))))); + } + + // Process inputs with lengths [0, 3] + for (nuint j = 0; j < (uint)length; j++) + { + if (typeof(T) == typeof(byte) + ? (Unsafe.BitCast(Unsafe.Add(ref searchSpace, j)) > 127) + : (Unsafe.BitCast(Unsafe.Add(ref searchSpace, j)) > 127)) + { + return false; + } + } + + return true; + } + + nuint i = 0; + + // If vectorization isn't supported, process 16 bytes at a time. + if (!Vector128.IsHardwareAccelerated && length > 2 * elementsPerUlong) + { + nuint finalStart = (nuint)length - 2 * elementsPerUlong; + + for (; i < finalStart; i += 2 * elementsPerUlong) + { + if (!AllCharsInUInt64AreAscii( + Unsafe.ReadUnaligned(ref Unsafe.As(ref Unsafe.Add(ref searchSpace, i))) | + Unsafe.ReadUnaligned(ref Unsafe.As(ref Unsafe.Add(ref searchSpace, i + elementsPerUlong))))) + { + return false; + } + } + + i = finalStart; + } + + // Process the last [8, 16] bytes. + return AllCharsInUInt64AreAscii( + Unsafe.ReadUnaligned(ref Unsafe.As(ref Unsafe.Add(ref searchSpace, i))) | + Unsafe.ReadUnaligned(ref Unsafe.Subtract(ref Unsafe.As(ref Unsafe.Add(ref searchSpace, length)), sizeof(ulong)))); + } + + ref T searchSpaceEnd = ref Unsafe.Add(ref searchSpace, length); + + // Process inputs with lengths [16, 32] bytes. + if (length <= 2 * Vector128.Count) + { + return AllCharsInVectorAreAscii( + Vector128.LoadUnsafe(ref searchSpace) | + Vector128.LoadUnsafe(ref Unsafe.Subtract(ref searchSpaceEnd, Vector128.Count))); + } + + if (Vector256.IsHardwareAccelerated) + { + // Process inputs with lengths [33, 64] bytes. + if (length <= 2 * Vector256.Count) + { + return AllCharsInVectorAreAscii( + Vector256.LoadUnsafe(ref searchSpace) | + Vector256.LoadUnsafe(ref Unsafe.Subtract(ref searchSpaceEnd, Vector256.Count))); + } + + // Process long inputs 128 bytes at a time. + if (length > 4 * Vector256.Count) + { + // Process the first 128 bytes. + if (!AllCharsInVectorAreAscii( + Vector256.LoadUnsafe(ref searchSpace) | + Vector256.LoadUnsafe(ref searchSpace, (nuint)Vector256.Count) | + Vector256.LoadUnsafe(ref searchSpace, 2 * (nuint)Vector256.Count) | + Vector256.LoadUnsafe(ref searchSpace, 3 * (nuint)Vector256.Count))) + { + return false; + } + + nuint i = 4 * (nuint)Vector256.Count; + + // Try to opportunistically align the reads below. The input isn't pinned, so the GC + // is free to move the references. We're therefore assuming that reads may still be unaligned. + // They may also be unaligned if the input chars aren't 2-byte aligned. + nuint misalignedElements = ((nuint)Unsafe.AsPointer(ref searchSpace) & (nuint)(Vector256.Count - 1)) / (nuint)sizeof(T); + i -= misalignedElements; + Debug.Assert((int)i > 3 * Vector256.Count); + + nuint finalStart = (nuint)length - 4 * (nuint)Vector256.Count; + + for (; i < finalStart; i += 4 * (nuint)Vector256.Count) + { + ref T current = ref Unsafe.Add(ref searchSpace, i); + + if (!AllCharsInVectorAreAscii( + Vector256.LoadUnsafe(ref current) | + Vector256.LoadUnsafe(ref current, (nuint)Vector256.Count) | + Vector256.LoadUnsafe(ref current, 2 * (nuint)Vector256.Count) | + Vector256.LoadUnsafe(ref current, 3 * (nuint)Vector256.Count))) + { + return false; + } + } + + searchSpace = ref Unsafe.Add(ref searchSpace, finalStart); + } + + // Process the last [1, 128] bytes. + // The search space has at least 2 * Vector256 bytes available to read. + // We process the first 2 and last 2 vectors, which may overlap. + return AllCharsInVectorAreAscii( + Vector256.LoadUnsafe(ref searchSpace) | + Vector256.LoadUnsafe(ref searchSpace, (nuint)Vector256.Count) | + Vector256.LoadUnsafe(ref Unsafe.Subtract(ref searchSpaceEnd, 2 * Vector256.Count)) | + Vector256.LoadUnsafe(ref Unsafe.Subtract(ref searchSpaceEnd, Vector256.Count))); + } + else + { + // Process long inputs 64 bytes at a time. + if (length > 4 * Vector128.Count) + { + // Process the first 64 bytes. + if (!AllCharsInVectorAreAscii( + Vector128.LoadUnsafe(ref searchSpace) | + Vector128.LoadUnsafe(ref searchSpace, (nuint)Vector128.Count) | + Vector128.LoadUnsafe(ref searchSpace, 2 * (nuint)Vector128.Count) | + Vector128.LoadUnsafe(ref searchSpace, 3 * (nuint)Vector128.Count))) + { + return false; + } + + nuint i = 4 * (nuint)Vector128.Count; + + // Try to opportunistically align the reads below. The input isn't pinned, so the GC + // is free to move the references. We're therefore assuming that reads may still be unaligned. + // They may also be unaligned if the input chars aren't 2-byte aligned. + nuint misalignedElements = ((nuint)Unsafe.AsPointer(ref searchSpace) & (nuint)(Vector128.Count - 1)) / (nuint)sizeof(T); + i -= misalignedElements; + Debug.Assert((int)i > 3 * Vector128.Count); + + nuint finalStart = (nuint)length - 4 * (nuint)Vector128.Count; + + for (; i < finalStart; i += 4 * (nuint)Vector128.Count) + { + ref T current = ref Unsafe.Add(ref searchSpace, i); + + if (!AllCharsInVectorAreAscii( + Vector128.LoadUnsafe(ref current) | + Vector128.LoadUnsafe(ref current, (nuint)Vector128.Count) | + Vector128.LoadUnsafe(ref current, 2 * (nuint)Vector128.Count) | + Vector128.LoadUnsafe(ref current, 3 * (nuint)Vector128.Count))) + { + return false; + } + } + + searchSpace = ref Unsafe.Add(ref searchSpace, finalStart); + } + + // Process the last [1, 64] bytes. + // The search space has at least 2 * Vector128 bytes available to read. + // We process the first 2 and last 2 vectors, which may overlap. + return AllCharsInVectorAreAscii( + Vector128.LoadUnsafe(ref searchSpace) | + Vector128.LoadUnsafe(ref searchSpace, (nuint)Vector128.Count) | + Vector128.LoadUnsafe(ref Unsafe.Subtract(ref searchSpaceEnd, 2 * Vector128.Count)) | + Vector128.LoadUnsafe(ref Unsafe.Subtract(ref searchSpaceEnd, Vector128.Count))); + } + } } }