Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Span.Reverse fast path performance #70944

Merged
merged 11 commits into from
Nov 18, 2022
99 changes: 66 additions & 33 deletions src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1132,22 +1132,21 @@ public static void Reverse(ref byte buf, nuint length)
{
Debug.Assert(length > 1);

ref byte first = ref buf;
ref byte last = ref Unsafe.Add(ref first, length);
nint remainder = (nint)length;
nint offset = 0;

if (Avx2.IsSupported && length >= (nuint)Vector256<byte>.Count)
if (Avx2.IsSupported && remainder >= Vector256<byte>.Count)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The trick with taking the vectorized path for sizes smaller than 2*Vector256.Count is neat, but it hurts performance in a number of cases. There can be a lot of redundant work done for certain sizes with this PR.

I have opened #78604 on this.

{
Vector256<byte> reverseMask = Vector256.Create(
(byte)15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, // first 128-bit lane
15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); // second 128-bit lane

last = ref Unsafe.Subtract(ref last, Vector256<byte>.Count);
nint lastOffset = remainder - Vector256<byte>.Count;
do
{
// Load the values into vectors
Vector256<byte> tempFirst = Vector256.LoadUnsafe(ref first);
Vector256<byte> tempLast = Vector256.LoadUnsafe(ref last);
Vector256<byte> tempFirst = Vector256.LoadUnsafe(ref buf, (nuint)offset);
Vector256<byte> tempLast = Vector256.LoadUnsafe(ref buf, (nuint)lastOffset);

// Avx2 operates on two 128-bit lanes rather than the full 256-bit vector.
// Perform a shuffle to reverse each 128-bit lane, then permute to finish reversing the vector:
Expand All @@ -1174,55 +1173,89 @@ public static void Reverse(ref byte buf, nuint length)
tempLast = Avx2.Permute2x128(tempLast, tempLast, 0b00_01);

// Store the reversed vectors
tempLast.StoreUnsafe(ref first);
tempFirst.StoreUnsafe(ref last);
tempLast.StoreUnsafe(ref buf, (nuint)offset);
tempFirst.StoreUnsafe(ref buf, (nuint)lastOffset);

offset += Vector256<byte>.Count;
lastOffset -= Vector256<byte>.Count;
} while (lastOffset >= offset);

remainder = lastOffset + Vector256<byte>.Count - offset;
}
else if (Vector128.IsHardwareAccelerated && remainder >= Vector128<byte>.Count)
{
nint lastOffset = remainder - Vector128<byte>.Count;
do
{
// Load the values into vectors
Vector128<byte> tempFirst = Vector128.LoadUnsafe(ref buf, (nuint)offset);
Vector128<byte> tempLast = Vector128.LoadUnsafe(ref buf, (nuint)lastOffset);

// Shuffle to reverse each vector:
// +---------------------------------------------------------------+
// | A | B | C | D | E | F | G | H | I | J | K | L | M | N | O | P |
// +---------------------------------------------------------------+
// --->
// +---------------------------------------------------------------+
// | P | O | N | M | L | K | J | I | H | G | F | E | D | C | B | A |
// +---------------------------------------------------------------+
tempFirst = Vector128.Shuffle(tempFirst, Vector128.Create(
(byte)15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0));
tempLast = Vector128.Shuffle(tempLast, Vector128.Create(
(byte)15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0));

// Store the reversed vectors
tempLast.StoreUnsafe(ref buf, (nuint)offset);
tempFirst.StoreUnsafe(ref buf, (nuint)lastOffset);

first = ref Unsafe.Add(ref first, Vector256<byte>.Count);
last = ref Unsafe.Subtract(ref last, Vector256<byte>.Count);
} while (!Unsafe.IsAddressGreaterThan(ref first, ref last));
offset += Vector128<byte>.Count;
lastOffset -= Vector128<byte>.Count;
} while (lastOffset >= offset);

remainder = Unsafe.ByteOffset(ref first, ref Unsafe.Add(ref last, Vector256<byte>.Count));
remainder = lastOffset + Vector128<byte>.Count - offset;
}
else if (remainder >= sizeof(long))

if (remainder >= sizeof(long))
{
last = ref Unsafe.Subtract(ref last, sizeof(long));
nint lastOffset = (nint)length - offset - sizeof(long);
do
{
long tempFirst = Unsafe.ReadUnaligned<long>(ref first);
long tempLast = Unsafe.ReadUnaligned<long>(ref last);
long tempFirst = Unsafe.ReadUnaligned<long>(ref Unsafe.Add(ref buf, offset));
long tempLast = Unsafe.ReadUnaligned<long>(ref Unsafe.Add(ref buf, lastOffset));

// swap and store in reversed position
Unsafe.WriteUnaligned(ref first, BinaryPrimitives.ReverseEndianness(tempLast));
Unsafe.WriteUnaligned(ref last, BinaryPrimitives.ReverseEndianness(tempFirst));
Unsafe.WriteUnaligned(ref Unsafe.Add(ref buf, offset), BinaryPrimitives.ReverseEndianness(tempLast));
Unsafe.WriteUnaligned(ref Unsafe.Add(ref buf, lastOffset), BinaryPrimitives.ReverseEndianness(tempFirst));

first = ref Unsafe.Add(ref first, sizeof(long));
last = ref Unsafe.Subtract(ref last, sizeof(long));
} while (!Unsafe.IsAddressGreaterThan(ref first, ref last));
offset += sizeof(long);
lastOffset -= sizeof(long);
} while (lastOffset >= offset);

remainder = Unsafe.ByteOffset(ref first, ref Unsafe.Add(ref last, sizeof(long)));
remainder = lastOffset + sizeof(long) - offset;
}
else if (remainder >= sizeof(int))

if (remainder >= sizeof(int))
{
last = ref Unsafe.Subtract(ref last, sizeof(int));
nint lastOffset = (nint)length - offset - sizeof(int);
do
{
int tempFirst = Unsafe.ReadUnaligned<int>(ref first);
int tempLast = Unsafe.ReadUnaligned<int>(ref last);
int tempFirst = Unsafe.ReadUnaligned<int>(ref Unsafe.Add(ref buf, offset));
int tempLast = Unsafe.ReadUnaligned<int>(ref Unsafe.Add(ref buf, lastOffset));

// swap and store in reversed position
Unsafe.WriteUnaligned(ref first, BinaryPrimitives.ReverseEndianness(tempLast));
Unsafe.WriteUnaligned(ref last, BinaryPrimitives.ReverseEndianness(tempFirst));
Unsafe.WriteUnaligned(ref Unsafe.Add(ref buf, offset), BinaryPrimitives.ReverseEndianness(tempLast));
Unsafe.WriteUnaligned(ref Unsafe.Add(ref buf, lastOffset), BinaryPrimitives.ReverseEndianness(tempFirst));

first = ref Unsafe.Add(ref first, sizeof(int));
last = ref Unsafe.Subtract(ref last, sizeof(int));
} while (!Unsafe.IsAddressGreaterThan(ref first, ref last));
offset += sizeof(int);
lastOffset -= sizeof(int);
} while (lastOffset >= offset);

remainder = Unsafe.ByteOffset(ref first, ref Unsafe.Add(ref last, sizeof(int)));
remainder = lastOffset + sizeof(int) - offset;
}

if (remainder > 1)
{
ReverseInner(ref first, (nuint)remainder);
ReverseInner(ref Unsafe.Add(ref buf, offset), (nuint)remainder);
}
}
}
Expand Down
53 changes: 28 additions & 25 deletions src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Char.cs
Original file line number Diff line number Diff line change
Expand Up @@ -735,21 +735,23 @@ public static void Reverse(ref char buf, nuint length)
{
Debug.Assert(length > 1);

ref char first = ref buf;
ref char last = ref Unsafe.Add(ref first, length);
nint remainder = (nint)length;
nint offset = 0;

if (Avx2.IsSupported && length >= (nuint)Vector256<ushort>.Count)
if (Avx2.IsSupported && remainder >= Vector256<ushort>.Count)
{
Vector256<byte> reverseMask = Vector256.Create(
(byte)14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1, // first 128-bit lane
14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1); // second 128-bit lane

last = ref Unsafe.Subtract(ref last, Vector256<ushort>.Count);
nint lastOffset = remainder - Vector256<ushort>.Count;
do
{
Vector256<byte> tempFirst = Vector256.LoadUnsafe(ref Unsafe.As<char, byte>(ref first));
Vector256<byte> tempLast = Vector256.LoadUnsafe(ref Unsafe.As<char, byte>(ref last));
ref byte first = ref Unsafe.As<char, byte>(ref Unsafe.Add(ref buf, offset));
ref byte last = ref Unsafe.As<char, byte>(ref Unsafe.Add(ref buf, lastOffset));

Vector256<byte> tempFirst = Vector256.LoadUnsafe(ref first);
Vector256<byte> tempLast = Vector256.LoadUnsafe(ref last);

// Avx2 operates on two 128-bit lanes rather than the full 256-bit vector.
// Perform a shuffle to reverse each 128-bit lane, then permute to finish reversing the vector:
Expand All @@ -770,23 +772,25 @@ public static void Reverse(ref char buf, nuint length)
tempLast = Avx2.Permute2x128(tempLast, tempLast, 0b00_01);

// Store the reversed vectors
tempLast.StoreUnsafe(ref Unsafe.As<char, byte>(ref first));
tempFirst.StoreUnsafe(ref Unsafe.As<char, byte>(ref last));
tempLast.StoreUnsafe(ref first);
tempFirst.StoreUnsafe(ref last);

first = ref Unsafe.Add(ref first, Vector256<ushort>.Count);
last = ref Unsafe.Subtract(ref last, Vector256<ushort>.Count);
} while (!Unsafe.IsAddressGreaterThan(ref first, ref last));
offset += Vector256<ushort>.Count;
lastOffset -= Vector256<ushort>.Count;
} while (lastOffset >= offset);

// shift for division is fine here since we don't care about what any negative number end up being
remainder = Unsafe.ByteOffset(ref first, ref Unsafe.Add(ref last, Vector256<ushort>.Count)) >> 1;
remainder = (lastOffset + Vector256<ushort>.Count - offset);
}
else if (Vector128.IsHardwareAccelerated && length >= (nuint)Vector128<ushort>.Count)
else if (Vector128.IsHardwareAccelerated && remainder >= Vector128<ushort>.Count)
{
last = ref Unsafe.Subtract(ref last, Vector128<ushort>.Count);
nint lastOffset = remainder - Vector128<ushort>.Count;
do
{
Vector128<ushort> tempFirst = Vector128.LoadUnsafe(ref Unsafe.As<char, ushort>(ref first));
Vector128<ushort> tempLast = Vector128.LoadUnsafe(ref Unsafe.As<char, ushort>(ref last));
ref ushort first = ref Unsafe.As<char, ushort>(ref Unsafe.Add(ref buf, offset));
ref ushort last = ref Unsafe.As<char, ushort>(ref Unsafe.Add(ref buf, lastOffset));

Vector128<ushort> tempFirst = Vector128.LoadUnsafe(ref first);
Vector128<ushort> tempLast = Vector128.LoadUnsafe(ref last);

// Shuffle to reverse each vector:
// +-------------------------------+
Expand All @@ -800,21 +804,20 @@ public static void Reverse(ref char buf, nuint length)
tempLast = Vector128.Shuffle(tempLast, Vector128.Create((ushort)7, 6, 5, 4, 3, 2, 1, 0));

// Store the reversed vectors
tempLast.StoreUnsafe(ref Unsafe.As<char, ushort>(ref first));
tempFirst.StoreUnsafe(ref Unsafe.As<char, ushort>(ref last));
tempLast.StoreUnsafe(ref first);
tempFirst.StoreUnsafe(ref last);

first = ref Unsafe.Add(ref first, Vector128<ushort>.Count);
last = ref Unsafe.Subtract(ref last, Vector128<ushort>.Count);
} while (!Unsafe.IsAddressGreaterThan(ref first, ref last));
offset += Vector128<ushort>.Count;
lastOffset -= Vector128<ushort>.Count;
} while (lastOffset >= offset);

// shift for division is fine here since we don't care about what any negative number end up being
remainder = Unsafe.ByteOffset(ref first, ref Unsafe.Add(ref last, Vector128<ushort>.Count)) >> 1;
remainder = (lastOffset + Vector128<ushort>.Count - offset);
}

// Store any remaining values one-by-one
if (remainder > 1)
{
ReverseInner(ref first, (nuint)remainder);
ReverseInner(ref Unsafe.Add(ref buf, offset), (nuint)remainder);
}
}
}
Expand Down
Loading