Skip to content

Commit

Permalink
Address remaining PR feedback
Browse files Browse the repository at this point in the history
Also move some tests to outerloop to decrease inner loop testing time
  • Loading branch information
stephentoub committed Feb 10, 2024
1 parent fa2daf2 commit 81564a6
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13274,34 +13274,27 @@ ref Unsafe.As<ulong, T>(ref MemoryMarshal.GetReference(RemainderUInt64Mask_8x9))
throw new NotSupportedException();
}

// TODO: The uses of these ApplyScalar methods are all as part of operators when handling edge cases (NaN, Infinity, really large inputs, etc.)
// Currently, these edge cases are not handled in a vectorized way and instead fall back to scalar processing. We can look into
// handling those in a vectorized manner as well.

private static Vector128<float> ApplyScalar<TOperator>(Vector128<float> floats) where TOperator : IUnaryOperator<float, float> =>
Vector128.Create(
TOperator.Invoke(floats[0]), TOperator.Invoke(floats[1]), TOperator.Invoke(floats[2]), TOperator.Invoke(floats[3]));
Vector128.Create(TOperator.Invoke(floats[0]), TOperator.Invoke(floats[1]), TOperator.Invoke(floats[2]), TOperator.Invoke(floats[3]));

private static Vector256<float> ApplyScalar<TOperator>(Vector256<float> floats) where TOperator : IUnaryOperator<float, float> =>
Vector256.Create(
TOperator.Invoke(floats[0]), TOperator.Invoke(floats[1]), TOperator.Invoke(floats[2]), TOperator.Invoke(floats[3]),
TOperator.Invoke(floats[4]), TOperator.Invoke(floats[5]), TOperator.Invoke(floats[6]), TOperator.Invoke(floats[7]));
Vector256.Create(ApplyScalar<TOperator>(floats.GetLower()), ApplyScalar<TOperator>(floats.GetUpper()));

private static Vector512<float> ApplyScalar<TOperator>(Vector512<float> floats) where TOperator : IUnaryOperator<float, float> =>
Vector512.Create(
TOperator.Invoke(floats[0]), TOperator.Invoke(floats[1]), TOperator.Invoke(floats[2]), TOperator.Invoke(floats[3]),
TOperator.Invoke(floats[4]), TOperator.Invoke(floats[5]), TOperator.Invoke(floats[6]), TOperator.Invoke(floats[7]),
TOperator.Invoke(floats[8]), TOperator.Invoke(floats[9]), TOperator.Invoke(floats[10]), TOperator.Invoke(floats[11]),
TOperator.Invoke(floats[12]), TOperator.Invoke(floats[13]), TOperator.Invoke(floats[14]), TOperator.Invoke(floats[15]));
Vector512.Create(ApplyScalar<TOperator>(floats.GetLower()), ApplyScalar<TOperator>(floats.GetUpper()));

private static Vector128<double> ApplyScalar<TOperator>(Vector128<double> doubles) where TOperator : IUnaryOperator<double, double> =>
Vector128.Create(
TOperator.Invoke(doubles[0]), TOperator.Invoke(doubles[1]));
Vector128.Create(TOperator.Invoke(doubles[0]), TOperator.Invoke(doubles[1]));

private static Vector256<double> ApplyScalar<TOperator>(Vector256<double> doubles) where TOperator : IUnaryOperator<double, double> =>
Vector256.Create(
TOperator.Invoke(doubles[0]), TOperator.Invoke(doubles[1]), TOperator.Invoke(doubles[2]), TOperator.Invoke(doubles[3]));
Vector256.Create(ApplyScalar<TOperator>(doubles.GetLower()), ApplyScalar<TOperator>(doubles.GetUpper()));

private static Vector512<double> ApplyScalar<TOperator>(Vector512<double> doubles) where TOperator : IUnaryOperator<double, double> =>
Vector512.Create(
TOperator.Invoke(doubles[0]), TOperator.Invoke(doubles[1]), TOperator.Invoke(doubles[2]), TOperator.Invoke(doubles[3]),
TOperator.Invoke(doubles[4]), TOperator.Invoke(doubles[5]), TOperator.Invoke(doubles[6]), TOperator.Invoke(doubles[7]));
Vector512.Create(ApplyScalar<TOperator>(doubles.GetLower()), ApplyScalar<TOperator>(doubles.GetUpper()));

/// <summary>Creates a span of <typeparamref name="TTo"/> from a span of <typeparamref name="TFrom"/> when they're the same type.</summary>
private static unsafe ReadOnlySpan<TTo> Rename<TFrom, TTo>(ReadOnlySpan<TFrom> span)
Expand Down Expand Up @@ -16075,7 +16068,7 @@ public static Vector512<T> Invoke(Vector512<T> x)

public static Vector128<float> Invoke(Vector128<float> x)
{
Vector128<uint> uxMasked = x.AsUInt32() & Vector128.Create(SignMask);
Vector128<uint> uxMasked = Vector128.Abs(x).AsUInt32();
if (Vector128.GreaterThanAny(uxMasked, Vector128.Create(MaxVectorizedValue)))
{
return ApplyScalar<CosOperatorSingle>(x);
Expand All @@ -16102,7 +16095,7 @@ public static Vector128<float> Invoke(Vector128<float> x)

public static Vector256<float> Invoke(Vector256<float> x)
{
Vector256<uint> uxMasked = x.AsUInt32() & Vector256.Create(0x7FFFFFFFu);
Vector256<uint> uxMasked = Vector256.Abs(x).AsUInt32();
if (Vector256.GreaterThanAny(uxMasked, Vector256.Create(MaxVectorizedValue)))
{
return ApplyScalar<CosOperatorSingle>(x);
Expand All @@ -16129,7 +16122,7 @@ public static Vector256<float> Invoke(Vector256<float> x)

public static Vector512<float> Invoke(Vector512<float> x)
{
Vector512<uint> uxMasked = x.AsUInt32() & Vector512.Create(SignMask);
Vector512<uint> uxMasked = Vector512.Abs(x).AsUInt32();
if (Vector512.GreaterThanAny(uxMasked, Vector512.Create(MaxVectorizedValue)))
{
return ApplyScalar<CosOperatorSingle>(x);
Expand Down Expand Up @@ -16178,7 +16171,7 @@ public static Vector512<float> Invoke(Vector512<float> x)

public static Vector128<double> Invoke(Vector128<double> x)
{
Vector128<ulong> uxMasked = x.AsUInt64() & Vector128.Create(SignMask);
Vector128<ulong> uxMasked = Vector128.Abs(x).AsUInt64();
if (Vector128.GreaterThanAny(uxMasked, Vector128.Create(MaxVectorizedValue)))
{
return ApplyScalar<CosOperatorDouble>(x);
Expand Down Expand Up @@ -16210,7 +16203,7 @@ public static Vector128<double> Invoke(Vector128<double> x)

public static Vector256<double> Invoke(Vector256<double> x)
{
Vector256<ulong> uxMasked = x.AsUInt64() & Vector256.Create(SignMask);
Vector256<ulong> uxMasked = Vector256.Abs(x).AsUInt64();
if (Vector256.GreaterThanAny(uxMasked, Vector256.Create(MaxVectorizedValue)))
{
return ApplyScalar<CosOperatorDouble>(x);
Expand Down Expand Up @@ -16242,7 +16235,7 @@ public static Vector256<double> Invoke(Vector256<double> x)

public static Vector512<double> Invoke(Vector512<double> x)
{
Vector512<ulong> uxMasked = x.AsUInt64() & Vector512.Create(SignMask);
Vector512<ulong> uxMasked = Vector512.Abs(x).AsUInt64();
if (Vector512.GreaterThanAny(uxMasked, Vector512.Create(MaxVectorizedValue)))
{
return ApplyScalar<CosOperatorDouble>(x);
Expand Down Expand Up @@ -16533,9 +16526,8 @@ public static Vector512<T> Invoke(Vector512<T> x)

public static Vector128<float> Invoke(Vector128<float> x)
{
Vector128<uint> ux = x.AsUInt32();
Vector128<uint> sign = ux & Vector128.Create(~SignMask);
Vector128<uint> uxMasked = ux & Vector128.Create(SignMask);
Vector128<uint> sign = x.AsUInt32() & Vector128.Create(~SignMask);
Vector128<uint> uxMasked = Vector128.Abs(x).AsUInt32();

if (Vector128.GreaterThanAny(uxMasked, Vector128.Create(MaxVectorizedValue)))
{
Expand Down Expand Up @@ -16563,9 +16555,8 @@ public static Vector128<float> Invoke(Vector128<float> x)

public static Vector256<float> Invoke(Vector256<float> x)
{
Vector256<uint> ux = x.AsUInt32();
Vector256<uint> sign = ux & Vector256.Create(~SignMask);
Vector256<uint> uxMasked = ux & Vector256.Create(SignMask);
Vector256<uint> sign = x.AsUInt32() & Vector256.Create(~SignMask);
Vector256<uint> uxMasked = Vector256.Abs(x).AsUInt32();

if (Vector256.GreaterThanAny(uxMasked, Vector256.Create(MaxVectorizedValue)))
{
Expand Down Expand Up @@ -16593,16 +16584,15 @@ public static Vector256<float> Invoke(Vector256<float> x)

public static Vector512<float> Invoke(Vector512<float> x)
{
Vector512<uint> ux = x.AsUInt32();
Vector512<uint> sign = ux & Vector512.Create(~SignMask);
Vector512<uint> uxMasked = ux & Vector512.Create(SignMask);
Vector512<uint> sign = x.AsUInt32() & Vector512.Create(~SignMask);
Vector512<uint> uxMasked = Vector512.Abs(x).AsUInt32();

if (Vector512.GreaterThanAny(uxMasked, Vector512.Create(MaxVectorizedValue)))
{
return ApplyScalar<SinOperatorSingle>(x);
}

Vector512<float> r = (ux & Vector512.Create(SignMask)).AsSingle();
Vector512<float> r = uxMasked.AsSingle();
Vector512<float> almHuge = Vector512.Create(AlmHuge);
Vector512<float> dn = (r * Vector512.Create(1 / float.Pi)) + almHuge;
Vector512<uint> odd = dn.AsUInt32() << 31;
Expand Down Expand Up @@ -16645,9 +16635,8 @@ public static Vector512<float> Invoke(Vector512<float> x)

public static Vector128<double> Invoke(Vector128<double> x)
{
Vector128<ulong> ux = x.AsUInt64();
Vector128<ulong> sign = ux & Vector128.Create(~SignMask);
Vector128<ulong> uxMasked = ux & Vector128.Create(SignMask);
Vector128<ulong> sign = x.AsUInt64() & Vector128.Create(~SignMask);
Vector128<ulong> uxMasked = Vector128.Abs(x).AsUInt64();

if (Vector128.GreaterThanAny(uxMasked, Vector128.Create(MaxVectorizedValue)))
{
Expand Down Expand Up @@ -16680,16 +16669,15 @@ public static Vector128<double> Invoke(Vector128<double> x)

public static Vector256<double> Invoke(Vector256<double> x)
{
Vector256<ulong> ux = x.AsUInt64();
Vector256<ulong> sign = ux & Vector256.Create(~SignMask);
Vector256<ulong> uxMasked = ux & Vector256.Create(SignMask);
Vector256<ulong> sign = x.AsUInt64() & Vector256.Create(~SignMask);
Vector256<ulong> uxMasked = Vector256.Abs(x).AsUInt64();

if (Vector256.GreaterThanAny(uxMasked, Vector256.Create(MaxVectorizedValue)))
{
return ApplyScalar<SinOperatorDouble>(x);
}

Vector256<double> r = (ux & Vector256.Create(SignMask)).AsDouble();
Vector256<double> r = uxMasked.AsDouble();
Vector256<double> almHuge = Vector256.Create(AlmHuge);
Vector256<double> dn = (r * Vector256.Create(1 / double.Pi)) + almHuge;
Vector256<ulong> odd = dn.AsUInt64() << 63;
Expand All @@ -16715,16 +16703,15 @@ public static Vector256<double> Invoke(Vector256<double> x)

public static Vector512<double> Invoke(Vector512<double> x)
{
Vector512<ulong> ux = x.AsUInt64();
Vector512<ulong> sign = ux & Vector512.Create(~SignMask);
Vector512<ulong> uxMasked = ux & Vector512.Create(SignMask);
Vector512<ulong> sign = x.AsUInt64() & Vector512.Create(~SignMask);
Vector512<ulong> uxMasked = Vector512.Abs(x).AsUInt64();

if (Vector512.GreaterThanAny(uxMasked, Vector512.Create(MaxVectorizedValue)))
{
return ApplyScalar<SinOperatorDouble>(x);
}

Vector512<double> r = (ux & Vector512.Create(SignMask)).AsDouble();
Vector512<double> r = uxMasked.AsDouble();
Vector512<double> almHuge = Vector512.Create(AlmHuge);
Vector512<double> dn = (r * Vector512.Create(1 / double.Pi)) + almHuge;
Vector512<ulong> odd = dn.AsUInt64() << 63;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,29 @@ namespace System.Numerics.Tensors.Tests
{
public class ConvertTests
{
[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsNotBuiltWithAggressiveTrimming))]
[Fact]
[SkipOnCoreClr("Depends heavily on folded type comparisons", RuntimeTestModes.JitMinOpts)]
public void ConvertTruncatingAndSaturating()
{
// A few cases. More exhaustive testing is done in the OuterLoop test.

ConvertTruncatingImpl<float, double>();
ConvertTruncatingImpl<double, float>();
ConvertTruncatingImpl<long, byte>();
ConvertTruncatingImpl<short, uint>();
ConvertTruncatingImpl<Half, int>();

ConvertSaturatingImpl<float, double>();
ConvertSaturatingImpl<double, float>();
ConvertSaturatingImpl<long, byte>();
ConvertSaturatingImpl<short, uint>();
ConvertSaturatingImpl<Half, int>();
}

[OuterLoop]
[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsNotBuiltWithAggressiveTrimming))]
[SkipOnCoreClr("Depends heavily on folded type comparisons", RuntimeTestModes.JitMinOpts)]
public void ConvertTruncatingAndSaturating_Outerloop()
{
MethodInfo convertTruncatingImpl = typeof(ConvertTests).GetMethod(nameof(ConvertTruncatingImpl), BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance);
Assert.NotNull(convertTruncatingImpl);
Expand Down Expand Up @@ -54,11 +74,8 @@ public void ConvertChecked()
{
// Conversions that never overflow. This isn't an exhaustive list; just a sampling.
ConvertCheckedImpl<byte, byte>();
ConvertCheckedImpl<byte, ushort>();
ConvertCheckedImpl<byte, short>();
ConvertCheckedImpl<byte, uint>();
ConvertCheckedImpl<byte, int>();
ConvertCheckedImpl<byte, ulong>();
ConvertCheckedImpl<byte, long>();
ConvertCheckedImpl<byte, float>();
ConvertCheckedImpl<Half, Half>();
Expand All @@ -78,12 +95,12 @@ private static void ConvertTruncatingImpl<TFrom, TTo>()
{
AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.ConvertTruncating<TFrom, TTo>(new TFrom[3], new TTo[2]));

Random rand = new(42);
foreach (int tensorLength in Helpers.TensorLengthsIncluding0)
{
using BoundedMemory<TFrom> source = BoundedMemory.Allocate<TFrom>(tensorLength);
using BoundedMemory<TTo> destination = BoundedMemory.Allocate<TTo>(tensorLength);

Random rand = new(42);
Span<TFrom> sourceSpan = source.Span;
for (int i = 0; i < tensorLength; i++)
{
Expand All @@ -110,12 +127,12 @@ private static void ConvertSaturatingImpl<TFrom, TTo>()
{
AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.ConvertSaturating<TFrom, TTo>(new TFrom[3], new TTo[2]));

Random rand = new(42);
foreach (int tensorLength in Helpers.TensorLengthsIncluding0)
{
using BoundedMemory<TFrom> source = BoundedMemory.Allocate<TFrom>(tensorLength);
using BoundedMemory<TTo> destination = BoundedMemory.Allocate<TTo>(tensorLength);

Random rand = new(42);
Span<TFrom> sourceSpan = source.Span;
for (int i = 0; i < tensorLength; i++)
{
Expand Down Expand Up @@ -209,29 +226,40 @@ private static bool IsEqualWithTolerance<T>(T expected, T actual, T? tolerance =
}
}

// The tests for some types have been marked as OuterLoop simply to decrease inner loop testing time.

public class DoubleGenericTensorPrimitives : GenericFloatingPointNumberTensorPrimitivesTests<double> { }
public class SingleGenericTensorPrimitives : GenericFloatingPointNumberTensorPrimitivesTests<float> { }
// Runs the generic floating-point tensor tests for Half.
// When a test does not supply an explicit tolerance, this override substitutes 0.001
// (instead of forwarding null to the base implementation) before delegating to the base comparison.
// NOTE(review): presumably needed because of Half's reduced precision — confirm.
public class HalfGenericTensorPrimitives : GenericFloatingPointNumberTensorPrimitivesTests<Half>
{
protected override void AssertEqualTolerance(Half expected, Half actual, Half? tolerance = null) =>
base.AssertEqualTolerance(expected, actual, tolerance ?? Half.CreateTruncating(0.001));
}

[OuterLoop]
public class NFloatGenericTensorPrimitives : GenericFloatingPointNumberTensorPrimitivesTests<NFloat> { }

[OuterLoop]
public class SByteGenericTensorPrimitives : GenericSignedIntegerTensorPrimitivesTests<sbyte> { }
public class Int16GenericTensorPrimitives : GenericSignedIntegerTensorPrimitivesTests<short> { }
[OuterLoop]
public class Int32GenericTensorPrimitives : GenericSignedIntegerTensorPrimitivesTests<int> { }
public class Int64GenericTensorPrimitives : GenericSignedIntegerTensorPrimitivesTests<long> { }
[OuterLoop]
public class IntPtrGenericTensorPrimitives : GenericSignedIntegerTensorPrimitivesTests<nint> { }
public class Int128GenericTensorPrimitives : GenericSignedIntegerTensorPrimitivesTests<Int128> { }

public class ByteGenericTensorPrimitives : GenericIntegerTensorPrimitivesTests<byte> { }
[OuterLoop]
public class UInt16GenericTensorPrimitives : GenericIntegerTensorPrimitivesTests<ushort> { }
[OuterLoop]
public class CharGenericTensorPrimitives : GenericIntegerTensorPrimitivesTests<char> { }
public class UInt32GenericTensorPrimitives : GenericIntegerTensorPrimitivesTests<uint> { }
[OuterLoop]
public class UInt64GenericTensorPrimitives : GenericIntegerTensorPrimitivesTests<ulong> { }

public class UIntPtrGenericTensorPrimitives : GenericIntegerTensorPrimitivesTests<nuint> { }
[OuterLoop]
public class UInt128GenericTensorPrimitives : GenericIntegerTensorPrimitivesTests<UInt128> { }

public unsafe abstract class GenericFloatingPointNumberTensorPrimitivesTests<T> : GenericNumberTensorPrimitivesTests<T>
Expand Down

0 comments on commit 81564a6

Please sign in to comment.