Skip to content

Commit

Permalink
Add a dedicated Ascii.IsValid path (#84881)
Browse files Browse the repository at this point in the history
* Add dedicated Ascii.IsValid path

* Remove wrong assert

* Use for loop instead of chasing end byref

* Add an extra assert

* Use nuint in loop for [0, 3] elements

* Remove leftover byref increment
  • Loading branch information
MihaZupan authored Apr 20, 2023
1 parent 42acf9e commit 46e2597
Show file tree
Hide file tree
Showing 3 changed files with 251 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ private static unsafe nuint ChangeCase<TFrom, TTo, TCasing>(TFrom* pSrc, TTo* pD
// Unaligned read and check for non-ASCII data.

Vector128<TFrom> srcVector = Vector128.LoadUnsafe(ref *pSrc);
if (VectorContainsAnyNonAsciiData(srcVector))
if (VectorContainsNonAsciiChar(srcVector))
{
goto Drain64;
}
Expand Down Expand Up @@ -291,7 +291,7 @@ private static unsafe nuint ChangeCase<TFrom, TTo, TCasing>(TFrom* pSrc, TTo* pD
// Unaligned read & check for non-ASCII data.

srcVector = Vector128.LoadUnsafe(ref *pSrc, i);
if (VectorContainsAnyNonAsciiData(srcVector))
if (VectorContainsNonAsciiChar(srcVector))
{
goto Drain64;
}
Expand Down Expand Up @@ -463,30 +463,6 @@ private static unsafe nuint ChangeCase<TFrom, TTo, TCasing>(TFrom* pSrc, TTo* pD
return i;
}

/// <summary>
/// Reports whether any element of <paramref name="vector"/> lies outside the ASCII range.
/// </summary>
/// <typeparam name="T">Element type; only 1-byte and 2-byte element sizes are handled.</typeparam>
/// <param name="vector">The 128-bit vector to inspect.</param>
/// <returns>True when at least one element is non-ASCII; otherwise false.</returns>
/// <exception cref="NotSupportedException">The element size is neither 1 nor 2 bytes.</exception>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe bool VectorContainsAnyNonAsciiData<T>(Vector128<T> vector)
    where T : unmanaged
{
    if (sizeof(T) == 1)
    {
        // Any lane with its most significant bit set makes the extracted mask non-zero.
        return vector.ExtractMostSignificantBits() != 0;
    }

    if (sizeof(T) == 2)
    {
        // 16-bit elements are delegated to the UInt16 overload.
        return VectorContainsNonAsciiChar(vector.AsUInt16());
    }

    Debug.Fail("Unknown types provided.");
    throw new NotSupportedException();
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe void Widen8To16AndAndWriteTo(Vector128<byte> narrowVector, char* pDest, nuint destOffset)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ private static bool AllCharsInUInt64AreAscii(ulong value)
return (value & ~0x007F007F_007F007Ful) == 0;
}

/// <summary>
/// Generic dispatcher that checks whether all elements packed into <paramref name="value"/> are ASCII.
/// The 64 bits are treated as bytes when <typeparamref name="T"/> is <see cref="byte"/> and as
/// 16-bit chars otherwise.
/// </summary>
/// <typeparam name="T">Must be <see cref="byte"/> or <see cref="ushort"/> (asserted in debug builds).</typeparam>
/// <param name="value">The packed elements to test.</param>
/// <returns>True if every packed element is ASCII; otherwise false.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool AllCharsInUInt64AreAscii<T>(ulong value)
    where T : unmanaged
{
    Debug.Assert(typeof(T) == typeof(byte) || typeof(T) == typeof(ushort));

    if (typeof(T) == typeof(byte))
    {
        return AllBytesInUInt64AreAscii(value);
    }

    // Non-generic overload: tests the value as four 16-bit chars.
    return AllCharsInUInt64AreAscii(value);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static int GetIndexOfFirstNonAsciiByteInLane_AdvSimd(Vector128<byte> value, Vector128<byte> bitmask)
{
Expand Down Expand Up @@ -1432,6 +1443,52 @@ private static bool VectorContainsNonAsciiChar(Vector128<ushort> utf16Vector)
}
}

/// <summary>
/// Generic dispatcher that reports whether <paramref name="vector"/> holds any non-ASCII element,
/// forwarding to the byte or UInt16 overload according to <typeparamref name="T"/>.
/// </summary>
/// <typeparam name="T">Must be <see cref="byte"/> or <see cref="ushort"/> (asserted in debug builds).</typeparam>
/// <param name="vector">The 128-bit vector to inspect.</param>
/// <returns>The result of the selected non-generic overload.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool VectorContainsNonAsciiChar<T>(Vector128<T> vector)
    where T : unmanaged
{
    Debug.Assert(typeof(T) == typeof(byte) || typeof(T) == typeof(ushort));

    if (typeof(T) == typeof(byte))
    {
        return VectorContainsNonAsciiChar(vector.AsByte());
    }

    return VectorContainsNonAsciiChar(vector.AsUInt16());
}

/// <summary>
/// Returns true when every element of <paramref name="vector"/> is in the ASCII range.
/// </summary>
/// <typeparam name="T">Must be <see cref="byte"/> or <see cref="ushort"/> (asserted in debug builds).</typeparam>
/// <param name="vector">The 128-bit vector to inspect.</param>
/// <returns>True if no element has a bit set outside the low seven bits; otherwise false.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool AllCharsInVectorAreAscii<T>(Vector128<T> vector)
    where T : unmanaged
{
    Debug.Assert(typeof(T) == typeof(byte) || typeof(T) == typeof(ushort));

    // This is a copy of VectorContainsNonAsciiChar with an inverted condition.
    if (typeof(T) == typeof(byte))
    {
        Vector128<byte> bytes = vector.AsByte();

        if (Sse41.IsSupported)
        {
            // True only when no byte has its 0x80 bit set.
            return Sse41.TestZ(bytes, Vector128.Create((byte)0x80));
        }

        if (AdvSimd.Arm64.IsSupported)
        {
            // Pairwise max of the vector with itself; the low 64 bits are then tested as eight bytes.
            ulong pairwiseMax = AdvSimd.Arm64.MaxPairwise(bytes, bytes).AsUInt64().ToScalar();
            return AllBytesInUInt64AreAscii(pairwiseMax);
        }

        return bytes.ExtractMostSignificantBits() == 0;
    }
    else
    {
        Vector128<ushort> chars = vector.AsUInt16();

        if (Sse41.IsSupported)
        {
            // (short)-128 is the bit pattern 0xFF80: every bit above the low seven.
            return Sse41.TestZ(vector.AsInt16(), Vector128.Create((short)-128));
        }

        if (AdvSimd.Arm64.IsSupported)
        {
            ulong pairwiseMax = AdvSimd.Arm64.MaxPairwise(chars, chars).AsUInt64().ToScalar();
            return AllCharsInUInt64AreAscii(pairwiseMax);
        }

        Vector128<ushort> nonAsciiBits = Vector128.Create((ushort)(ushort.MaxValue - 127));
        return (chars & nonAsciiBits) == Vector128<ushort>.Zero;
    }
}

/// <summary>
/// 256-bit counterpart of the ASCII check: returns true when every element of
/// <paramref name="vector"/> is in the ASCII range. Requires AVX (asserted in debug builds).
/// </summary>
/// <typeparam name="T">Must be <see cref="byte"/> or <see cref="ushort"/> (asserted in debug builds).</typeparam>
/// <param name="vector">The 256-bit vector to inspect.</param>
/// <returns>True if no element has a bit set outside the low seven bits; otherwise false.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool AllCharsInVectorAreAscii<T>(Vector256<T> vector)
    where T : unmanaged
{
    Debug.Assert(Avx.IsSupported);
    Debug.Assert(typeof(T) == typeof(byte) || typeof(T) == typeof(ushort));

    if (typeof(T) == typeof(byte))
    {
        // True only when no byte has its 0x80 bit set.
        return Avx.TestZ(vector.AsByte(), Vector256.Create((byte)0x80));
    }

    // (short)-128 is the bit pattern 0xFF80: every bit above the low seven of a 16-bit element.
    return Avx.TestZ(vector.AsInt16(), Vector256.Create((short)-128));
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static Vector128<byte> ExtractAsciiVector(Vector128<ushort> vectorFirst, Vector128<ushort> vectorSecond)
{
Expand Down
224 changes: 192 additions & 32 deletions src/libraries/System.Private.CoreLib/src/System/Text/Ascii.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;

namespace System.Text
{
Expand All @@ -14,56 +16,214 @@ public static partial class Ascii
/// <param name="value">The value to inspect.</param>
/// <returns>True if <paramref name="value"/> contains only ASCII bytes or is
/// empty; False otherwise.</returns>
public static unsafe bool IsValid(ReadOnlySpan<byte> value)
{
    nuint elementCount = (uint)value.Length;
    if (elementCount == 0)
    {
        // An empty span is reported as valid.
        return true;
    }

    fixed (byte* pStart = &MemoryMarshal.GetReference(value))
    {
        // The input is valid exactly when the reported index equals the element count.
        nuint firstNonAsciiIndex = GetIndexOfFirstNonAsciiByte(pStart, elementCount);
        Debug.Assert(firstNonAsciiIndex <= elementCount);
        return firstNonAsciiIndex == elementCount;
    }
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool IsValid(ReadOnlySpan<byte> value)
{
    // Hand the span to the shared core as a (reference, length) pair.
    ref byte first = ref MemoryMarshal.GetReference(value);
    return IsValidCore(ref first, value.Length);
}

/// <summary>
/// Determines whether the provided value contains only ASCII chars.
/// </summary>
/// <param name="value">The value to inspect.</param>
/// <returns>True if <paramref name="value"/> contains only ASCII chars or is
/// empty; False otherwise.</returns>
public static unsafe bool IsValid(ReadOnlySpan<char> value)
{
    nuint elementCount = (uint)value.Length;
    if (elementCount == 0)
    {
        // An empty span is reported as valid.
        return true;
    }

    fixed (char* pStart = &MemoryMarshal.GetReference(value))
    {
        // The input is valid exactly when the reported index equals the element count.
        nuint firstNonAsciiIndex = GetIndexOfFirstNonAsciiChar(pStart, elementCount);
        Debug.Assert(firstNonAsciiIndex <= elementCount);
        return firstNonAsciiIndex == elementCount;
    }
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool IsValid(ReadOnlySpan<char> value)
{
    // IsValidCore asserts that T is byte or ushort, so the chars are viewed as ushort values.
    ref ushort first = ref Unsafe.As<char, ushort>(ref MemoryMarshal.GetReference(value));
    return IsValidCore(ref first, value.Length);
}

/// <summary>
/// Determines whether the provided value is an ASCII byte.
/// </summary>
/// <param name="value">The value to inspect.</param>
/// <returns>True if <paramref name="value"/> is ASCII, False otherwise.</returns>
public static unsafe bool IsValid(byte value)
{
    // ASCII occupies 0x00 through 0x7F.
    return value < 0x80;
}
public static bool IsValid(byte value)
{
    // ASCII occupies 0x00 through 0x7F.
    return value < 0x80;
}

/// <summary>
/// Determines whether the provided value is an ASCII char.
/// </summary>
/// <param name="value">The value to inspect.</param>
/// <returns>True if <paramref name="value"/> is ASCII, False otherwise.</returns>
public static unsafe bool IsValid(char value)
{
    // ASCII occupies U+0000 through U+007F.
    return value < 0x80;
}
public static bool IsValid(char value)
{
    // ASCII occupies U+0000 through U+007F.
    return value < 0x80;
}

// Shared implementation behind the span-based IsValid overloads.
// Returns true when every one of the `length` elements starting at `searchSpace` is ASCII (<= 127);
// a zero-length input returns true. T must be byte or ushort (asserted below).
//
// Every stage ORs several loads together and tests the combined value once: a bit outside the ASCII
// range in any of the loaded elements survives the OR, so a single test covers all of them.
// Tails are handled by re-reading the last N elements (which may overlap data that was already
// checked) instead of with a scalar remainder loop.
private static unsafe bool IsValidCore<T>(ref T searchSpace, int length) where T : unmanaged
{
Debug.Assert(typeof(T) == typeof(byte) || typeof(T) == typeof(ushort));

// Scalar path: taken when 128-bit vectors are not hardware accelerated, or when the input is
// shorter than a single Vector128<T>.
if (!Vector128.IsHardwareAccelerated || length < Vector128<T>.Count)
{
// 8 for byte, 4 for ushort.
uint elementsPerUlong = (uint)(sizeof(ulong) / sizeof(T));

if (length < elementsPerUlong)
{
if (typeof(T) == typeof(byte) && length >= sizeof(uint))
{
// Process byte inputs with lengths [4, 7]
// Two 4-byte reads: the first four bytes and the last four bytes of the input.
return AllBytesInUInt32AreAscii(
Unsafe.ReadUnaligned<uint>(ref Unsafe.As<T, byte>(ref searchSpace)) |
Unsafe.ReadUnaligned<uint>(ref Unsafe.As<T, byte>(ref Unsafe.Add(ref searchSpace, length - sizeof(uint)))));
}

// Process inputs with lengths [0, 3]
// Only the branch matching T is evaluated for each element.
for (nuint j = 0; j < (uint)length; j++)
{
if (typeof(T) == typeof(byte)
? (Unsafe.BitCast<T, byte>(Unsafe.Add(ref searchSpace, j)) > 127)
: (Unsafe.BitCast<T, char>(Unsafe.Add(ref searchSpace, j)) > 127))
{
return false;
}
}

return true;
}

nuint i = 0;

// If vectorization isn't supported, process 16 bytes at a time.
if (!Vector128.IsHardwareAccelerated && length > 2 * elementsPerUlong)
{
// Start of the last 16-byte window; the loop stops before it and the tail check below covers it.
nuint finalStart = (nuint)length - 2 * elementsPerUlong;

for (; i < finalStart; i += 2 * elementsPerUlong)
{
if (!AllCharsInUInt64AreAscii<T>(
Unsafe.ReadUnaligned<ulong>(ref Unsafe.As<T, byte>(ref Unsafe.Add(ref searchSpace, i))) |
Unsafe.ReadUnaligned<ulong>(ref Unsafe.As<T, byte>(ref Unsafe.Add(ref searchSpace, i + elementsPerUlong)))))
{
return false;
}
}

i = finalStart;
}

// Process the last [8, 16] bytes.
// One 8-byte read at element i and one 8-byte read ending at the end of the input.
return AllCharsInUInt64AreAscii<T>(
Unsafe.ReadUnaligned<ulong>(ref Unsafe.As<T, byte>(ref Unsafe.Add(ref searchSpace, i))) |
Unsafe.ReadUnaligned<ulong>(ref Unsafe.Subtract(ref Unsafe.As<T, byte>(ref Unsafe.Add(ref searchSpace, length)), sizeof(ulong))));
}

// Reference just past the last element; the tail loads below are addressed backwards from it.
ref T searchSpaceEnd = ref Unsafe.Add(ref searchSpace, length);

// Process inputs with lengths [16, 32] bytes.
// First vector plus last vector.
if (length <= 2 * Vector128<T>.Count)
{
return AllCharsInVectorAreAscii(
Vector128.LoadUnsafe(ref searchSpace) |
Vector128.LoadUnsafe(ref Unsafe.Subtract(ref searchSpaceEnd, Vector128<T>.Count)));
}

if (Vector256.IsHardwareAccelerated)
{
// Process inputs with lengths [33, 64] bytes.
// First 256-bit vector plus last 256-bit vector.
if (length <= 2 * Vector256<T>.Count)
{
return AllCharsInVectorAreAscii(
Vector256.LoadUnsafe(ref searchSpace) |
Vector256.LoadUnsafe(ref Unsafe.Subtract(ref searchSpaceEnd, Vector256<T>.Count)));
}

// Process long inputs 128 bytes at a time.
if (length > 4 * Vector256<T>.Count)
{
// Process the first 128 bytes.
if (!AllCharsInVectorAreAscii(
Vector256.LoadUnsafe(ref searchSpace) |
Vector256.LoadUnsafe(ref searchSpace, (nuint)Vector256<T>.Count) |
Vector256.LoadUnsafe(ref searchSpace, 2 * (nuint)Vector256<T>.Count) |
Vector256.LoadUnsafe(ref searchSpace, 3 * (nuint)Vector256<T>.Count)))
{
return false;
}

nuint i = 4 * (nuint)Vector256<T>.Count;

// Try to opportunistically align the reads below. The input isn't pinned, so the GC
// is free to move the references. We're therefore assuming that reads may still be unaligned.
// They may also be unaligned if the input chars aren't 2-byte aligned.
// Stepping i back by the misalignment re-checks some elements already covered by the first 128 bytes.
nuint misalignedElements = ((nuint)Unsafe.AsPointer(ref searchSpace) & (nuint)(Vector256<byte>.Count - 1)) / (nuint)sizeof(T);
i -= misalignedElements;
Debug.Assert((int)i > 3 * Vector256<T>.Count);

// Start of the last 4-vector window; the loop stops before it and the tail check below covers it.
nuint finalStart = (nuint)length - 4 * (nuint)Vector256<T>.Count;

for (; i < finalStart; i += 4 * (nuint)Vector256<T>.Count)
{
ref T current = ref Unsafe.Add(ref searchSpace, i);

if (!AllCharsInVectorAreAscii(
Vector256.LoadUnsafe(ref current) |
Vector256.LoadUnsafe(ref current, (nuint)Vector256<T>.Count) |
Vector256.LoadUnsafe(ref current, 2 * (nuint)Vector256<T>.Count) |
Vector256.LoadUnsafe(ref current, 3 * (nuint)Vector256<T>.Count)))
{
return false;
}
}

// Rebase searchSpace onto the last 4 vectors so the shared tail check below finishes the input.
searchSpace = ref Unsafe.Add(ref searchSpace, finalStart);
}

// Process the last [1, 128] bytes.
// The search space has at least 2 * Vector256 bytes available to read.
// We process the first 2 and last 2 vectors, which may overlap.
return AllCharsInVectorAreAscii(
Vector256.LoadUnsafe(ref searchSpace) |
Vector256.LoadUnsafe(ref searchSpace, (nuint)Vector256<T>.Count) |
Vector256.LoadUnsafe(ref Unsafe.Subtract(ref searchSpaceEnd, 2 * Vector256<T>.Count)) |
Vector256.LoadUnsafe(ref Unsafe.Subtract(ref searchSpaceEnd, Vector256<T>.Count)));
}
else
{
// Same structure as the Vector256 branch above, using 128-bit vectors.
// Process long inputs 64 bytes at a time.
if (length > 4 * Vector128<T>.Count)
{
// Process the first 64 bytes.
if (!AllCharsInVectorAreAscii(
Vector128.LoadUnsafe(ref searchSpace) |
Vector128.LoadUnsafe(ref searchSpace, (nuint)Vector128<T>.Count) |
Vector128.LoadUnsafe(ref searchSpace, 2 * (nuint)Vector128<T>.Count) |
Vector128.LoadUnsafe(ref searchSpace, 3 * (nuint)Vector128<T>.Count)))
{
return false;
}

nuint i = 4 * (nuint)Vector128<T>.Count;

// Try to opportunistically align the reads below. The input isn't pinned, so the GC
// is free to move the references. We're therefore assuming that reads may still be unaligned.
// They may also be unaligned if the input chars aren't 2-byte aligned.
// Stepping i back by the misalignment re-checks some elements already covered by the first 64 bytes.
nuint misalignedElements = ((nuint)Unsafe.AsPointer(ref searchSpace) & (nuint)(Vector128<byte>.Count - 1)) / (nuint)sizeof(T);
i -= misalignedElements;
Debug.Assert((int)i > 3 * Vector128<T>.Count);

// Start of the last 4-vector window; the loop stops before it and the tail check below covers it.
nuint finalStart = (nuint)length - 4 * (nuint)Vector128<T>.Count;

for (; i < finalStart; i += 4 * (nuint)Vector128<T>.Count)
{
ref T current = ref Unsafe.Add(ref searchSpace, i);

if (!AllCharsInVectorAreAscii(
Vector128.LoadUnsafe(ref current) |
Vector128.LoadUnsafe(ref current, (nuint)Vector128<T>.Count) |
Vector128.LoadUnsafe(ref current, 2 * (nuint)Vector128<T>.Count) |
Vector128.LoadUnsafe(ref current, 3 * (nuint)Vector128<T>.Count)))
{
return false;
}
}

// Rebase searchSpace onto the last 4 vectors so the shared tail check below finishes the input.
searchSpace = ref Unsafe.Add(ref searchSpace, finalStart);
}

// Process the last [1, 64] bytes.
// The search space has at least 2 * Vector128 bytes available to read.
// We process the first 2 and last 2 vectors, which may overlap.
return AllCharsInVectorAreAscii(
Vector128.LoadUnsafe(ref searchSpace) |
Vector128.LoadUnsafe(ref searchSpace, (nuint)Vector128<T>.Count) |
Vector128.LoadUnsafe(ref Unsafe.Subtract(ref searchSpaceEnd, 2 * Vector128<T>.Count)) |
Vector128.LoadUnsafe(ref Unsafe.Subtract(ref searchSpaceEnd, Vector128<T>.Count)));
}
}
}
}

0 comments on commit 46e2597

Please sign in to comment.