Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Light up Ascii.Equality.Equals and Ascii.Equality.EqualsIgnoreCase with Vector512 code path #88650

Merged
merged 14 commits into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 158 additions & 5 deletions src/libraries/System.Private.CoreLib/src/System/Text/Ascii.Equality.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,36 @@ private static bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight ri
}
}
}
else if (Vector512.IsHardwareAccelerated && length >= (uint)Vector512<TLeft>.Count)
{
ref TLeft currentLeftSearchSpace = ref left;
ref TRight currentRightSearchSpace = ref right;
// Add Vector512<TLeft>.Count because TLeft == TRight
// Or we are in the Widen case where we iterate 2 * TRight.Count which is the same as TLeft.Count
Debug.Assert(Vector512<TLeft>.Count == Vector512<TRight>.Count
|| (typeof(TLoader) == typeof(WideningLoader) && Vector512<TLeft>.Count == Vector512<TRight>.Count * 2));
ref TRight oneVectorAwayFromRightEnd = ref Unsafe.Add(ref currentRightSearchSpace, length - (uint)Vector512<TLeft>.Count);

// Loop until either we've finished all elements or there's less than a vector's-worth remaining.
do
{
if (!TLoader.EqualAndAscii512(ref currentLeftSearchSpace, ref currentRightSearchSpace))
{
return false;
}

currentRightSearchSpace = ref Unsafe.Add(ref currentRightSearchSpace, Vector512<TLeft>.Count);
currentLeftSearchSpace = ref Unsafe.Add(ref currentLeftSearchSpace, Vector512<TLeft>.Count);
}
while (!Unsafe.IsAddressGreaterThan(ref currentRightSearchSpace, ref oneVectorAwayFromRightEnd));

// If any elements remain, process the last vector in the search space.
if (length % (uint)Vector512<TLeft>.Count != 0)
{
ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe.Add(ref left, length - (uint)Vector512<TLeft>.Count);
return TLoader.EqualAndAscii512(ref oneVectorAwayFromLeftEnd, ref oneVectorAwayFromRightEnd);
}
Comment on lines +85 to +92
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Not specific to this PR since it's just following the existing pattern)

Since we're already doing the ref arithmetic here, we might be able to save a few instructions by changing such loops to

while (Unsafe.IsAddressLessThan(ref currentRightSearchSpace, ref oneVectorAwayFromRightEnd))
{ ... }

ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe.Add(ref left, length - (uint)Vector512<TLeft>.Count);
return TLoader.EqualAndAscii512(ref oneVectorAwayFromLeftEnd, ref oneVectorAwayFromRightEnd);

}
else if (Avx.IsSupported && length >= (uint)Vector256<TLeft>.Count)
{
ref TLeft currentLeftSearchSpace = ref left;
Expand All @@ -74,7 +104,7 @@ private static bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight ri
// Loop until either we've finished all elements or there's less than a vector's-worth remaining.
do
{
if (!TLoader.EqualAndAscii(ref currentLeftSearchSpace, ref currentRightSearchSpace))
if (!TLoader.EqualAndAscii256(ref currentLeftSearchSpace, ref currentRightSearchSpace))
{
return false;
}
Expand All @@ -88,7 +118,7 @@ private static bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight ri
if (length % (uint)Vector256<TLeft>.Count != 0)
{
ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe.Add(ref left, length - (uint)Vector256<TLeft>.Count);
return TLoader.EqualAndAscii(ref oneVectorAwayFromLeftEnd, ref oneVectorAwayFromRightEnd);
return TLoader.EqualAndAscii256(ref oneVectorAwayFromLeftEnd, ref oneVectorAwayFromRightEnd);
}
}
else
Expand Down Expand Up @@ -198,6 +228,77 @@ private static bool EqualsIgnoreCase<TLeft, TRight, TLoader>(ref TLeft left, ref
}
}
}
else if (Vector512.IsHardwareAccelerated && length >= (uint)Vector512<TRight>.Count)
{
ref TLeft currentLeftSearchSpace = ref left;
ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe.Add(ref currentLeftSearchSpace, length - TLoader.Count512);
ref TRight currentRightSearchSpace = ref right;
ref TRight oneVectorAwayFromRightEnd = ref Unsafe.Add(ref currentRightSearchSpace, length - (uint)Vector512<TRight>.Count);

Vector512<TRight> leftValues;
Vector512<TRight> rightValues;

Vector512<TRight> loweringMask = Vector512.Create(TRight.CreateTruncating(0x20));
Vector512<TRight> vecA = Vector512.Create(TRight.CreateTruncating('a'));
Vector512<TRight> vecZMinusA = Vector512.Create(TRight.CreateTruncating(('z' - 'a')));

// Loop until either we've finished all elements or there's less than a vector's-worth remaining.
do
{
leftValues = TLoader.Load512(ref currentLeftSearchSpace);
rightValues = Vector512.LoadUnsafe(ref currentRightSearchSpace);
if (!AllCharsInVectorAreAscii(leftValues | rightValues))
{
return false;
}

Vector512<TRight> notEquals = ~Vector512.Equals(leftValues, rightValues);

if (notEquals != Vector512<TRight>.Zero)
{
// not exact match

leftValues |= loweringMask;
rightValues |= loweringMask;

if (Vector512.GreaterThanAny((leftValues - vecA) & notEquals, vecZMinusA) || leftValues != rightValues)
{
return false; // first input isn't in [A-Za-z], and not exact match of lowered
}
}

currentRightSearchSpace = ref Unsafe.Add(ref currentRightSearchSpace, (uint)Vector512<TRight>.Count);
currentLeftSearchSpace = ref Unsafe.Add(ref currentLeftSearchSpace, TLoader.Count512);
}
while (!Unsafe.IsAddressGreaterThan(ref currentRightSearchSpace, ref oneVectorAwayFromRightEnd));

// If any elements remain, process the last vector in the search space.
if (length % (uint)Vector512<TRight>.Count != 0)
{
leftValues = TLoader.Load512(ref oneVectorAwayFromLeftEnd);
rightValues = Vector512.LoadUnsafe(ref oneVectorAwayFromRightEnd);

if (!AllCharsInVectorAreAscii(leftValues | rightValues))
{
return false;
}

Vector512<TRight> notEquals = ~Vector512.Equals(leftValues, rightValues);

if (notEquals != Vector512<TRight>.Zero)
{
// not exact match

leftValues |= loweringMask;
rightValues |= loweringMask;

if (Vector512.GreaterThanAny((leftValues - vecA) & notEquals, vecZMinusA) || leftValues != rightValues)
{
return false; // first input isn't in [A-Za-z], and not exact match of lowered
}
}
}
}
else if (Avx.IsSupported && length >= (uint)Vector256<TRight>.Count)
{
ref TLeft currentLeftSearchSpace = ref left;
Expand Down Expand Up @@ -353,21 +454,26 @@ private interface ILoader<TLeft, TRight>
{
static abstract nuint Count128 { get; }
static abstract nuint Count256 { get; }
static abstract nuint Count512 { get; }
static abstract Vector128<TRight> Load128(ref TLeft ptr);
static abstract Vector256<TRight> Load256(ref TLeft ptr);
static abstract bool EqualAndAscii(ref TLeft left, ref TRight right);
static abstract Vector512<TRight> Load512(ref TLeft ptr);
static abstract bool EqualAndAscii256(ref TLeft left, ref TRight right);
static abstract bool EqualAndAscii512(ref TLeft left, ref TRight right);
}

private readonly struct PlainLoader<T> : ILoader<T, T> where T : unmanaged, INumberBase<T>
{
public static nuint Count128 => (uint)Vector128<T>.Count;
public static nuint Count256 => (uint)Vector256<T>.Count;
public static nuint Count512 => (uint)Vector512<T>.Count;
public static Vector128<T> Load128(ref T ptr) => Vector128.LoadUnsafe(ref ptr);
public static Vector256<T> Load256(ref T ptr) => Vector256.LoadUnsafe(ref ptr);
public static Vector512<T> Load512(ref T ptr) => Vector512.LoadUnsafe(ref ptr);

[MethodImpl(MethodImplOptions.AggressiveInlining)]
[CompExactlyDependsOn(typeof(Avx))]
public static bool EqualAndAscii(ref T left, ref T right)
public static bool EqualAndAscii256(ref T left, ref T right)
{
Vector256<T> leftValues = Vector256.LoadUnsafe(ref left);
Vector256<T> rightValues = Vector256.LoadUnsafe(ref right);
Expand All @@ -379,12 +485,27 @@ public static bool EqualAndAscii(ref T left, ref T right)

return true;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool EqualAndAscii512(ref T left, ref T right)
{
Vector512<T> leftValues = Vector512.LoadUnsafe(ref left);
Vector512<T> rightValues = Vector512.LoadUnsafe(ref right);

if (leftValues != rightValues || !AllCharsInVectorAreAscii(leftValues))
{
return false;
}

return true;
}
}

private readonly struct WideningLoader : ILoader<byte, ushort>
{
public static nuint Count128 => sizeof(long);
public static nuint Count256 => (uint)Vector128<byte>.Count;
public static nuint Count512 => (uint)Vector256<byte>.Count;

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector128<ushort> Load128(ref byte ptr)
Expand Down Expand Up @@ -412,9 +533,16 @@ public static Vector256<ushort> Load256(ref byte ptr)
return Vector256.Create(lower, upper);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector512<ushort> Load512(ref byte ptr)
{
(Vector512<ushort> lower, Vector512<ushort> _) = Vector512.Widen(Vector256.LoadUnsafe(ref ptr).ToVector512());
return lower;
tannergooding marked this conversation as resolved.
Show resolved Hide resolved
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
[CompExactlyDependsOn(typeof(Avx))]
public static bool EqualAndAscii(ref byte utf8, ref ushort utf16)
public static bool EqualAndAscii256(ref byte utf8, ref ushort utf16)
{
// We widen the utf8 param so we can compare it to utf16, this doubles how much of the utf16 vector we search
Debug.Assert(Vector256<byte>.Count == Vector256<ushort>.Count * 2);
Expand All @@ -437,6 +565,31 @@ public static bool EqualAndAscii(ref byte utf8, ref ushort utf16)

return true;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool EqualAndAscii512(ref byte utf8, ref ushort utf16)
{
// We widen the utf8 param so we can compare it to utf16, this doubles how much of the utf16 vector we search
Debug.Assert(Vector512<byte>.Count == Vector512<ushort>.Count * 2);

Vector512<byte> leftNotWidened = Vector512.LoadUnsafe(ref utf8);
if (!AllCharsInVectorAreAscii(leftNotWidened))
{
return false;
}

(Vector512<ushort> leftLower, Vector512<ushort> leftUpper) = Vector512.Widen(leftNotWidened);
Vector512<ushort> right = Vector512.LoadUnsafe(ref utf16);
Vector512<ushort> rightNext = Vector512.LoadUnsafe(ref utf16, (uint)Vector512<ushort>.Count);

// A branchless version of "leftLower != right || leftUpper != rightNext"
if (((leftLower ^ right) | (leftUpper ^ rightNext)) != Vector512<ushort>.Zero)
{
return false;
}

return true;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1499,6 +1499,22 @@ private static bool AllCharsInVectorAreAscii<T>(Vector256<T> vector)
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool AllCharsInVectorAreAscii<T>(Vector512<T> vector)
where T : unmanaged
{
Debug.Assert(typeof(T) == typeof(byte) || typeof(T) == typeof(ushort));

if (typeof(T) == typeof(byte))
{
return vector.AsByte().ExtractMostSignificantBits() == 0;
}
else
{
return (vector.AsUInt16() & Vector512.Create((ushort)0xFF80)) == Vector512<ushort>.Zero;
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static Vector128<byte> ExtractAsciiVector(Vector128<ushort> vectorFirst, Vector128<ushort> vectorSecond)
{
Expand Down