Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize Span<T>.Fill implementation #51365

Merged
merged 5 commits into from
Apr 17, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions src/libraries/System.Memory/tests/Span/Fill.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Linq;
using System.Runtime.InteropServices;
using Xunit;
using static System.TestHelpers;
Expand Down Expand Up @@ -147,5 +148,114 @@ public static unsafe void FillNativeBytes()
Marshal.FreeHGlobal(new IntPtr(ptr));
}
}

[Fact]
public static void FillWithRecognizedType()
{
    // Integral, character and floating-point primitives of every width.
    RunTest<sbyte>(0x20);
    RunTest<byte>(0x20);
    RunTest<bool>(true);
    RunTest<short>(0x1234);
    RunTest<ushort>(0x1234);
    RunTest<char>('x');
    RunTest<int>(0x12345678);
    RunTest<uint>(0x12345678);
    RunTest<long>(0x0123456789abcdef);
    RunTest<ulong>(0x0123456789abcdef);
    RunTest<nint>(unchecked((nint)0x0123456789abcdef));
    RunTest<nuint>(unchecked((nuint)0x0123456789abcdef));
    RunTest<Half>((Half)1.0);
    RunTest<float>(1.0f);
    RunTest<double>(1.0);

    // An enum, a reference type, and structs of assorted sizes.
    RunTest<StringComparison>(StringComparison.CurrentCultureIgnoreCase); // enum: should be handled like its underlying primitive
    RunTest<string>("Hello world!"); // reference type, no SIMD
    RunTest<decimal>(1.0m); // 128-bit struct
    RunTest<Guid>(new Guid("29e07627-2481-4f43-8fbf-09cf21180239")); // 128-bit struct
    RunTest<My96BitStruct>(new(0x11111111, 0x22222222, 0x33333333)); // 96-bit struct, no SIMD
    RunTest<My256BitStruct>(new(0x1111111111111111, 0x2222222222222222, 0x3333333333333333, 0x4444444444444444));
    RunTest<My512BitStruct>(new(
        0x1111111111111111, 0x2222222222222222, 0x3333333333333333, 0x4444444444444444,
        0x5555555555555555, 0x6666666666666666, 0x7777777777777777, 0x8888888888888888)); // 512-bit struct, no SIMD
    RunTest<MyRefContainingStruct>(new("Hello world!")); // struct contains refs, no SIMD

    // Fills every prefix length from 0 through 64 of a 128-element array and verifies that
    // exactly that prefix holds 'value' while everything past it is still default(T),
    // i.e. that Fill never writes beyond the span it was given.
    static void RunTest<T>(T value)
    {
        const int LongestFill = 64;
        T[] buffer = new T[128];

        int count = 0;
        while (count <= LongestFill)
        {
            Span<T> target = new Span<T>(buffer, 0, count);
            target.Fill(value);

            // The first 'count' entries must all equal 'value'.
            var expectedPrefix = Enumerable.Repeat(value, count);
            var actualPrefix = buffer.Take(count);
            Assert.Equal(expectedPrefix, actualPrefix);

            // Everything after the filled prefix must be untouched.
            var expectedSuffix = Enumerable.Repeat(default(T), buffer.Length - count);
            var actualSuffix = buffer.Skip(count);
            Assert.Equal(expectedSuffix, actualSuffix);

            // Reset the buffer so the next iteration starts from a clean slate.
            Array.Clear(buffer, 0, buffer.Length);
            count++;
        }
    }
}

private readonly struct My96BitStruct
{
    // Three 32-bit fields, giving 96 bits of payload in declaration order.
    public readonly int Data0, Data1, Data2;

    // Stores each argument in the field with the matching index.
    public My96BitStruct(int data0, int data1, int data2)
        => (Data0, Data1, Data2) = (data0, data1, data2);
}

private readonly struct My256BitStruct
{
    // Four 64-bit fields, giving 256 bits of payload in declaration order.
    public readonly ulong Data0, Data1, Data2, Data3;

    // Stores each argument in the field with the matching index.
    public My256BitStruct(ulong data0, ulong data1, ulong data2, ulong data3)
        => (Data0, Data1, Data2, Data3) = (data0, data1, data2, data3);
}

private readonly struct My512BitStruct
{
    // Eight 64-bit fields, giving 512 bits of payload in declaration order.
    public readonly ulong Data0, Data1, Data2, Data3, Data4, Data5, Data6, Data7;

    // Stores each argument in the field with the matching index.
    public My512BitStruct(
        ulong data0, ulong data1, ulong data2, ulong data3,
        ulong data4, ulong data5, ulong data6, ulong data7)
    {
        (Data0, Data1, Data2, Data3) = (data0, data1, data2, data3);
        (Data4, Data5, Data6, Data7) = (data4, data5, data6, data7);
    }
}

private readonly struct MyRefContainingStruct
{
    // The single object reference held by this struct.
    public readonly object Data;

    // Stores the supplied reference as-is; no copy or validation is performed.
    public MyRefContainingStruct(object data) => Data = data;
}
}
}
49 changes: 8 additions & 41 deletions src/libraries/System.Private.CoreLib/src/System/Span.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Versioning;
using System.Text;
using EditorBrowsableAttribute = System.ComponentModel.EditorBrowsableAttribute;
using EditorBrowsableState = System.ComponentModel.EditorBrowsableState;
using Internal.Runtime.CompilerServices;
Expand Down Expand Up @@ -280,53 +279,21 @@ public unsafe void Clear()
/// <summary>
/// Fills the contents of this span with the given value.
/// </summary>
/// <param name="value">The value written to the elements of this span.</param>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void Fill(T value)
{
// NOTE(review): this listing appears to contain both the pre-change and the post-change
// statements of a diff whose +/- markers were lost (e.g. two InitBlockUnaligned calls in the
// first branch, and a manual unrolled loop followed by SpanHelpers.Fill in the second).
// Confirm which statements are live before relying on the body as written.
if (Unsafe.SizeOf<T>() == 1)
{
uint length = (uint)_length;
if (length == 0)
return;

T tmp = value; // Avoid taking address of the "value" argument. It would regress performance of the loop below.
// Reinterprets the one-byte value as a byte and block-initializes 'length' bytes of the span with it.
Unsafe.InitBlockUnaligned(ref Unsafe.As<T, byte>(ref _pointer.Value), Unsafe.As<T, byte>(ref tmp), length);
// Special-case single-byte types like byte / sbyte / bool.
// The runtime eventually calls memset, which can efficiently support large buffers.
// We don't need to check IsReferenceOrContainsReferences because no references
// can ever be stored in types this small.
Unsafe.InitBlockUnaligned(ref Unsafe.As<T, byte>(ref _pointer.Value), Unsafe.As<T, byte>(ref value), (uint)_length);
}
else
{
// Do all math as nuint to avoid unnecessary 64->32->64 bit integer truncations
nuint length = (uint)_length;
if (length == 0)
return;

ref T r = ref _pointer.Value;

// TODO: Create block fill for value types of power of two sizes e.g. 2,4,8,16

nuint elementSize = (uint)Unsafe.SizeOf<T>();
nuint i = 0;
// Main loop: store eight elements per iteration; 'length & ~7' is the largest multiple of 8 <= length.
for (; i < (length & ~(nuint)7); i += 8)
{
Unsafe.AddByteOffset<T>(ref r, (i + 0) * elementSize) = value;
Unsafe.AddByteOffset<T>(ref r, (i + 1) * elementSize) = value;
Unsafe.AddByteOffset<T>(ref r, (i + 2) * elementSize) = value;
Unsafe.AddByteOffset<T>(ref r, (i + 3) * elementSize) = value;
Unsafe.AddByteOffset<T>(ref r, (i + 4) * elementSize) = value;
Unsafe.AddByteOffset<T>(ref r, (i + 5) * elementSize) = value;
Unsafe.AddByteOffset<T>(ref r, (i + 6) * elementSize) = value;
Unsafe.AddByteOffset<T>(ref r, (i + 7) * elementSize) = value;
}
// If at least four elements remain after the 8-wide loop, store four of them in one step.
if (i < (length & ~(nuint)3))
{
Unsafe.AddByteOffset<T>(ref r, (i + 0) * elementSize) = value;
Unsafe.AddByteOffset<T>(ref r, (i + 1) * elementSize) = value;
Unsafe.AddByteOffset<T>(ref r, (i + 2) * elementSize) = value;
Unsafe.AddByteOffset<T>(ref r, (i + 3) * elementSize) = value;
i += 4;
}
// Store whatever is left (at most three elements) one at a time.
for (; i < length; i++)
{
Unsafe.AddByteOffset<T>(ref r, i * elementSize) = value;
}
// Call our optimized workhorse method for all other types.
SpanHelpers.Fill(ref _pointer.Value, (uint)_length, value);
}
}

Expand Down
169 changes: 168 additions & 1 deletion src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,180 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;

using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics;
using Internal.Runtime.CompilerServices;

namespace System
{
internal static partial class SpanHelpers // .T
{
/// <summary>
/// Writes <paramref name="value"/> to <paramref name="numElements"/> consecutive elements,
/// beginning at <paramref name="refData"/>.
/// </summary>
/// <typeparam name="T">The element type.</typeparam>
/// <param name="refData">Reference to the first element to overwrite.</param>
/// <param name="numElements">The number of elements to overwrite; zero is a no-op.</param>
/// <param name="value">The value stored into each element.</param>
/// <remarks>
/// When <typeparamref name="T"/> contains no references, vectors are hardware-accelerated, and
/// <typeparamref name="T"/>'s size is a power of two no larger than a vector, the buffer is filled
/// with whole-vector writes. Otherwise an unrolled element-by-element loop is used.
/// </remarks>
public static void Fill<T>(ref T refData, nuint numElements, T value)
{
    // Early checks to see if it's even possible to vectorize - JIT will turn these checks into consts.
    // - T cannot contain references (GC can't track references in vectors)
    // - Vectorization must be hardware-accelerated
    // - T's size must not exceed the vector's size and must be a whole power of 2

    if (RuntimeHelpers.IsReferenceOrContainsReferences<T>()) { goto CannotVectorize; }
    if (!Vector.IsHardwareAccelerated) { goto CannotVectorize; }
    if (Unsafe.SizeOf<T>() > Vector<byte>.Count) { goto CannotVectorize; }
    if ((Unsafe.SizeOf<T>() & (Unsafe.SizeOf<T>() - 1)) != 0) { goto CannotVectorize; } // power of 2 check

    // Use '>=' rather than '>': a buffer holding exactly one vector's worth of elements can be
    // filled with a single vector write. In that case the "odd vector count" write below and the
    // trailing write both land at offset 0, which is redundant but correct.
    if (numElements >= (uint)(Vector<byte>.Count / Unsafe.SizeOf<T>()))
    {
        // We have enough data for at least one vectorized write.

        T tmp = value; // Avoid taking address of the "value" argument. It would regress performance of the loops below.
        Vector<byte> vector;

        // Broadcast 'tmp' across every lane of a vector, choosing the lane width from sizeof(T).
        if (Unsafe.SizeOf<T>() == 1)
        {
            vector = new Vector<byte>(Unsafe.As<T, byte>(ref tmp));
        }
        else if (Unsafe.SizeOf<T>() == 2)
        {
            vector = (Vector<byte>)(new Vector<ushort>(Unsafe.As<T, ushort>(ref tmp)));
        }
        else if (Unsafe.SizeOf<T>() == 4)
        {
            // special-case float since it's already passed in a SIMD reg
            vector = (typeof(T) == typeof(float))
                ? (Vector<byte>)(new Vector<float>((float)(object)tmp!))
                : (Vector<byte>)(new Vector<uint>(Unsafe.As<T, uint>(ref tmp)));
        }
        else if (Unsafe.SizeOf<T>() == 8)
        {
            // special-case double since it's already passed in a SIMD reg
            vector = (typeof(T) == typeof(double))
                ? (Vector<byte>)(new Vector<double>((double)(object)tmp!))
                : (Vector<byte>)(new Vector<ulong>(Unsafe.As<T, ulong>(ref tmp)));
        }
        else if (Unsafe.SizeOf<T>() == 16)
        {
            Vector128<byte> vec128 = Unsafe.As<T, Vector128<byte>>(ref tmp);
            if (Vector<byte>.Count == 16)
            {
                vector = vec128.AsVector();
            }
            else if (Vector<byte>.Count == 32)
            {
                // Duplicate the 128-bit value into both halves of a 256-bit vector.
                vector = Vector256.Create(vec128, vec128).AsVector();
            }
            else
            {
                Debug.Fail("Vector<T> isn't 128 or 256 bits in size?");
                goto CannotVectorize;
            }
        }
        else if (Unsafe.SizeOf<T>() == 32)
        {
            // Only reachable when Vector<byte>.Count >= 32 because of the size check above.
            vector = Unsafe.As<T, Vector256<byte>>(ref tmp).AsVector();
        }
        else
        {
            Debug.Fail("Vector<T> is greater than 256 bits in size?");
            goto CannotVectorize;
        }

        ref byte refDataAsBytes = ref Unsafe.As<T, byte>(ref refData);
        nuint totalByteLength = numElements * (nuint)Unsafe.SizeOf<T>(); // get this calculation ready ahead of time
        nuint stopLoopAtOffset = totalByteLength & (nuint)(nint)(2 * (int)-Vector<byte>.Count); // intentional sign extension carries the negative bit
        nuint offset = 0;

        // Loop, writing 2 vectors at a time.
        // Compare 'numElements' rather than 'stopLoopAtOffset' because we don't want a dependency
        // on the very recently calculated 'stopLoopAtOffset' value.

        if (numElements >= (uint)(2 * Vector<byte>.Count / Unsafe.SizeOf<T>()))
        {
            do
            {
                Unsafe.WriteUnaligned(ref Unsafe.AddByteOffset(ref refDataAsBytes, offset), vector);
                Unsafe.WriteUnaligned(ref Unsafe.AddByteOffset(ref refDataAsBytes, offset + (nuint)Vector<byte>.Count), vector);
                offset += (uint)(2 * Vector<byte>.Count);
            } while (offset < stopLoopAtOffset);
        }

        // At this point, if any data remains to be written, it's strictly less than
        // 2 * sizeof(Vector) bytes. The loop above had us write an even number of vectors.
        // If the total byte length instead involves us writing an odd number of vectors, write
        // one additional vector now. The bit check below tells us if we're in an "odd vector
        // count" situation.

        if ((totalByteLength & (nuint)Vector<byte>.Count) != 0)
        {
            Unsafe.WriteUnaligned(ref Unsafe.AddByteOffset(ref refDataAsBytes, offset), vector);
        }

        // It's possible that some small buffer remains to be populated - something that won't
        // fit an entire vector's worth of data. Instead of falling back to a loop, we'll write
        // a vector at the very end of the buffer. This may involve overwriting previously
        // populated data, which is fine since we're splatting the same value for all entries.
        // There's no need to perform a length check here because we already performed this
        // check before entering the vectorized code path (totalByteLength >= Vector<byte>.Count).

        Unsafe.WriteUnaligned(ref Unsafe.AddByteOffset(ref refDataAsBytes, totalByteLength - (nuint)Vector<byte>.Count), vector);

        // And we're done!

        return;
    }

CannotVectorize:

    // If we reached this point, we cannot vectorize this T, or there are too few
    // elements for us to vectorize. Fall back to an unrolled loop.

    nuint i = 0;

    // Write 8 elements at a time

    if (numElements >= 8)
    {
        nuint stopLoopAtOffset = numElements & ~(nuint)7;
        do
        {
            Unsafe.Add(ref refData, (nint)i + 0) = value;
            Unsafe.Add(ref refData, (nint)i + 1) = value;
            Unsafe.Add(ref refData, (nint)i + 2) = value;
            Unsafe.Add(ref refData, (nint)i + 3) = value;
            Unsafe.Add(ref refData, (nint)i + 4) = value;
            Unsafe.Add(ref refData, (nint)i + 5) = value;
            Unsafe.Add(ref refData, (nint)i + 6) = value;
            Unsafe.Add(ref refData, (nint)i + 7) = value;
        } while ((i += 8) < stopLoopAtOffset);
    }

    // The remaining count (numElements mod 8) is handled by testing its individual bits:
    // bit 2 => four more elements, bit 1 => two more, bit 0 => one more.

    // Write next 4 elements if needed

    if ((numElements & 4) != 0)
    {
        Unsafe.Add(ref refData, (nint)i + 0) = value;
        Unsafe.Add(ref refData, (nint)i + 1) = value;
        Unsafe.Add(ref refData, (nint)i + 2) = value;
        Unsafe.Add(ref refData, (nint)i + 3) = value;
        i += 4;
    }

    // Write next 2 elements if needed

    if ((numElements & 2) != 0)
    {
        Unsafe.Add(ref refData, (nint)i + 0) = value;
        Unsafe.Add(ref refData, (nint)i + 1) = value;
        i += 2;
    }

    // Write final element if needed

    if ((numElements & 1) != 0)
    {
        Unsafe.Add(ref refData, (nint)i) = value;
    }
}

public static int IndexOf<T>(ref T searchSpace, int searchSpaceLength, ref T value, int valueLength) where T : IEquatable<T>
{
Debug.Assert(searchSpaceLength >= 0);
Expand Down