Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Light up String.Manipulation APIs with Vector512 codepath #93043

Merged
merged 2 commits into from
Dec 22, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 96 additions & 30 deletions src/libraries/System.Private.CoreLib/src/System/String.Manipulation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1188,7 +1188,21 @@ public string Replace(char oldChar, char newChar)
// process the remaining elements vectorized too.
// Thus we adjust the pointers so that at least one full vector from the end can be processed.
nuint length = (uint)Length;
if (Vector128.IsHardwareAccelerated && length >= (uint)Vector128<ushort>.Count)
if (Vector512.IsHardwareAccelerated && length >= (uint)Vector512<ushort>.Count)
{
nuint adjust = (length - remainingLength) & ((uint)Vector512<ushort>.Count - 1);
pSrc = ref Unsafe.Subtract(ref pSrc, adjust);
pDst = ref Unsafe.Subtract(ref pDst, adjust);
remainingLength += adjust;
}
else if (Vector256.IsHardwareAccelerated && length >= (uint)Vector256<ushort>.Count)
{
nuint adjust = (length - remainingLength) & ((uint)Vector256<ushort>.Count - 1);
pSrc = ref Unsafe.Subtract(ref pSrc, adjust);
pDst = ref Unsafe.Subtract(ref pDst, adjust);
remainingLength += adjust;
}
else if (Vector128.IsHardwareAccelerated && length >= (uint)Vector128<ushort>.Count)
{
nuint adjust = (length - remainingLength) & ((uint)Vector128<ushort>.Count - 1);
pSrc = ref Unsafe.Subtract(ref pSrc, adjust);
Expand Down Expand Up @@ -1899,46 +1913,98 @@ internal static void MakeSeparatorListAny(ReadOnlySpan<char> source, ReadOnlySpa

private static void MakeSeparatorListVectorized(ReadOnlySpan<char> sourceSpan, ref ValueListBuilder<int> sepListBuilder, char c, char c2, char c3)
{
// Redundant test so we won't prejit remainder of this method
// on platforms where it is not supported
if (!Vector128.IsHardwareAccelerated)
Debug.Assert(sourceSpan.Length >= Vector128<ushort>.Count);
nuint lengthToExamine = (uint)sourceSpan.Length;
nuint offset = 0;
ref char source = ref MemoryMarshal.GetReference(sourceSpan);

if (Vector512.IsHardwareAccelerated && lengthToExamine >= (uint)Vector512<ushort>.Count*2)
{
throw new PlatformNotSupportedException();
}
tannergooding marked this conversation as resolved.
Show resolved Hide resolved
Vector512<ushort> v1 = Vector512.Create((ushort)c);
Vector512<ushort> v2 = Vector512.Create((ushort)c2);
Vector512<ushort> v3 = Vector512.Create((ushort)c3);

Debug.Assert(sourceSpan.Length >= Vector128<ushort>.Count);
do
{
Vector512<ushort> vector = Vector512.LoadUnsafe(ref source, offset);
Vector512<ushort> v1Eq = Vector512.Equals(vector, v1);
Vector512<ushort> v2Eq = Vector512.Equals(vector, v2);
Vector512<ushort> v3Eq = Vector512.Equals(vector, v3);
Vector512<byte> cmp = (v1Eq | v2Eq | v3Eq).AsByte();
tannergooding marked this conversation as resolved.
Show resolved Hide resolved

nuint offset = 0;
nuint lengthToExamine = (uint)sourceSpan.Length;
if (cmp != Vector512<byte>.Zero)
{
// Skip every other bit
ulong mask = cmp.ExtractMostSignificantBits() & 0x5555555555555555;
do
{
uint bitPos = (uint)BitOperations.TrailingZeroCount(mask) / sizeof(char);
sepListBuilder.Append((int)(offset + bitPos));
mask = BitOperations.ResetLowestSetBit(mask);
} while (mask != 0);
}

ref char source = ref MemoryMarshal.GetReference(sourceSpan);
offset += (nuint)Vector512<ushort>.Count;
} while (offset <= lengthToExamine - (nuint)Vector512<ushort>.Count);
}
else if (Vector256.IsHardwareAccelerated && lengthToExamine >= (uint)Vector256<ushort>.Count*2)
{
Vector256<ushort> v1 = Vector256.Create((ushort)c);
Vector256<ushort> v2 = Vector256.Create((ushort)c2);
Vector256<ushort> v3 = Vector256.Create((ushort)c3);

Vector128<ushort> v1 = Vector128.Create((ushort)c);
Vector128<ushort> v2 = Vector128.Create((ushort)c2);
Vector128<ushort> v3 = Vector128.Create((ushort)c3);
do
{
Vector256<ushort> vector = Vector256.LoadUnsafe(ref source, offset);
Vector256<ushort> v1Eq = Vector256.Equals(vector, v1);
Vector256<ushort> v2Eq = Vector256.Equals(vector, v2);
Vector256<ushort> v3Eq = Vector256.Equals(vector, v3);
Vector256<byte> cmp = (v1Eq | v2Eq | v3Eq).AsByte();

if (cmp != Vector256<byte>.Zero)
{
// Skip every other bit
uint mask = cmp.ExtractMostSignificantBits() & 0x55555555;
do
{
uint bitPos = (uint)BitOperations.TrailingZeroCount(mask) / sizeof(char);
sepListBuilder.Append((int)(offset + bitPos));
mask = BitOperations.ResetLowestSetBit(mask);
} while (mask != 0);
}

do
offset += (nuint)Vector256<ushort>.Count;
} while (offset <= lengthToExamine - (nuint)Vector256<ushort>.Count);
}
else if (Vector128.IsHardwareAccelerated)
{
Vector128<ushort> vector = Vector128.LoadUnsafe(ref source, offset);
Vector128<ushort> v1Eq = Vector128.Equals(vector, v1);
Vector128<ushort> v2Eq = Vector128.Equals(vector, v2);
Vector128<ushort> v3Eq = Vector128.Equals(vector, v3);
Vector128<byte> cmp = (v1Eq | v2Eq | v3Eq).AsByte();
Vector128<ushort> v1 = Vector128.Create((ushort)c);
Vector128<ushort> v2 = Vector128.Create((ushort)c2);
Vector128<ushort> v3 = Vector128.Create((ushort)c3);

if (cmp != Vector128<byte>.Zero)
do
{
// Skip every other bit
uint mask = cmp.ExtractMostSignificantBits() & 0x5555;
do
Vector128<ushort> vector = Vector128.LoadUnsafe(ref source, offset);
Vector128<ushort> v1Eq = Vector128.Equals(vector, v1);
Vector128<ushort> v2Eq = Vector128.Equals(vector, v2);
Vector128<ushort> v3Eq = Vector128.Equals(vector, v3);
Vector128<byte> cmp = (v1Eq | v2Eq | v3Eq).AsByte();

if (cmp != Vector128<byte>.Zero)
{
uint bitPos = (uint)BitOperations.TrailingZeroCount(mask) / sizeof(char);
sepListBuilder.Append((int)(offset + bitPos));
mask = BitOperations.ResetLowestSetBit(mask);
} while (mask != 0);
}
// Skip every other bit
uint mask = cmp.ExtractMostSignificantBits() & 0x5555;
do
{
uint bitPos = (uint)BitOperations.TrailingZeroCount(mask) / sizeof(char);
sepListBuilder.Append((int)(offset + bitPos));
mask = BitOperations.ResetLowestSetBit(mask);
} while (mask != 0);
}

offset += (nuint)Vector128<ushort>.Count;
} while (offset <= lengthToExamine - (nuint)Vector128<ushort>.Count);
offset += (nuint)Vector128<ushort>.Count;
} while (offset <= lengthToExamine - (nuint)Vector128<ushort>.Count);
}

while (offset < lengthToExamine)
{
Expand Down