Skip to content

Commit

Permalink
Adding vectorized implementations of Exp to Vector64/128/256/512 (#97114
Browse files Browse the repository at this point in the history
)

* Adding vectorized implementations of Exp to Vector64/128/256/512

* Accelerate TensorPrimitives.Exp for double

* Ensure the right allowedVariance is used for the vectorized exp tests

* Ensure V128/256/512 defers to the next smaller vector size by operating on the lower/upper halves

* Ensure the right allowedVariance amounts are used for the vectorized Exp(float) tests

* Ensure we call Exp and that the methods are properly inlined

* Skip the Exp test for Vector128/256/512 on Mono due to #97176
  • Loading branch information
tannergooding committed Jan 19, 2024
1 parent 59a38f1 commit c53d221
Show file tree
Hide file tree
Showing 12 changed files with 971 additions and 48 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -1426,6 +1426,40 @@ public static bool EqualsAny<T>(Vector128<T> left, Vector128<T> right)
|| Vector64.EqualsAny(left._upper, right._upper);
}

/// <inheritdoc cref="Vector64.Exp(Vector64{double})" />
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector128<double> Exp(Vector128<double> vector)
{
if (IsHardwareAccelerated)
{
return VectorMath.ExpDouble<Vector128<double>, Vector128<long>, Vector128<ulong>>(vector);
}
else
{
return Create(
Vector64.Exp(vector._lower),
Vector64.Exp(vector._upper)
);
}
}

/// <inheritdoc cref="Vector64.Exp(Vector64{float})" />
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector128<float> Exp(Vector128<float> vector)
{
if (IsHardwareAccelerated)
{
return VectorMath.ExpSingle<Vector128<float>, Vector128<uint>, Vector128<double>, Vector128<ulong>>(vector);
}
else
{
return Create(
Vector64.Exp(vector._lower),
Vector64.Exp(vector._upper)
);
}
}

/// <summary>Extracts the most significant bit from each element in a vector.</summary>
/// <typeparam name="T">The type of the elements in the vector.</typeparam>
/// <param name="vector">The vector whose elements should have their most significant bit extracted.</param>
Expand Down Expand Up @@ -1782,6 +1816,7 @@ internal static Vector128<ushort> LoadUnsafe(ref char source, nuint elementOffse
LoadUnsafe(ref Unsafe.As<char, ushort>(ref source), elementOffset);

/// <inheritdoc cref="Vector64.Log(Vector64{double})" />
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector128<double> Log(Vector128<double> vector)
{
if (IsHardwareAccelerated)
Expand All @@ -1798,6 +1833,7 @@ public static Vector128<double> Log(Vector128<double> vector)
}

/// <inheritdoc cref="Vector64.Log(Vector64{float})" />
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector128<float> Log(Vector128<float> vector)
{
if (IsHardwareAccelerated)
Expand All @@ -1814,6 +1850,7 @@ public static Vector128<float> Log(Vector128<float> vector)
}

/// <inheritdoc cref="Vector64.Log2(Vector64{double})" />
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector128<double> Log2(Vector128<double> vector)
{
if (IsHardwareAccelerated)
Expand All @@ -1830,6 +1867,7 @@ public static Vector128<double> Log2(Vector128<double> vector)
}

/// <inheritdoc cref="Vector64.Log2(Vector64{float})" />
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector128<float> Log2(Vector128<float> vector)
{
if (IsHardwareAccelerated)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1402,6 +1402,40 @@ public static bool EqualsAny<T>(Vector256<T> left, Vector256<T> right)
|| Vector128.EqualsAny(left._upper, right._upper);
}

/// <inheritdoc cref="Vector128.Exp(Vector128{double})" />
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector256<double> Exp(Vector256<double> vector)
{
if (IsHardwareAccelerated)
{
return VectorMath.ExpDouble<Vector256<double>, Vector256<long>, Vector256<ulong>>(vector);
}
else
{
return Create(
Vector128.Exp(vector._lower),
Vector128.Exp(vector._upper)
);
}
}

/// <inheritdoc cref="Vector128.Exp(Vector128{float})" />
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector256<float> Exp(Vector256<float> vector)
{
if (IsHardwareAccelerated)
{
return VectorMath.ExpSingle<Vector256<float>, Vector256<uint>, Vector256<double>, Vector256<ulong>>(vector);
}
else
{
return Create(
Vector128.Exp(vector._lower),
Vector128.Exp(vector._upper)
);
}
}

/// <summary>Extracts the most significant bit from each element in a vector.</summary>
/// <param name="vector">The vector whose elements should have their most significant bit extracted.</param>
/// <typeparam name="T">The type of the elements in the vector.</typeparam>
Expand Down Expand Up @@ -1756,6 +1790,7 @@ internal static Vector256<ushort> LoadUnsafe(ref char source, nuint elementOffse
LoadUnsafe(ref Unsafe.As<char, ushort>(ref source), elementOffset);

/// <inheritdoc cref="Vector128.Log(Vector128{double})" />
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector256<double> Log(Vector256<double> vector)
{
if (IsHardwareAccelerated)
Expand All @@ -1772,6 +1807,7 @@ public static Vector256<double> Log(Vector256<double> vector)
}

/// <inheritdoc cref="Vector128.Log(Vector128{float})" />
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector256<float> Log(Vector256<float> vector)
{
if (IsHardwareAccelerated)
Expand All @@ -1788,6 +1824,7 @@ public static Vector256<float> Log(Vector256<float> vector)
}

/// <inheritdoc cref="Vector128.Log2(Vector128{double})" />
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector256<double> Log2(Vector256<double> vector)
{
if (IsHardwareAccelerated)
Expand All @@ -1804,6 +1841,7 @@ public static Vector256<double> Log2(Vector256<double> vector)
}

/// <inheritdoc cref="Vector128.Log2(Vector128{float})" />
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector256<float> Log2(Vector256<float> vector)
{
if (IsHardwareAccelerated)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1453,6 +1453,40 @@ public static bool EqualsAny<T>(Vector512<T> left, Vector512<T> right)
|| Vector256.EqualsAny(left._upper, right._upper);
}

/// <inheritdoc cref="Vector256.Exp(Vector256{double})" />
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector512<double> Exp(Vector512<double> vector)
{
if (IsHardwareAccelerated)
{
return VectorMath.ExpDouble<Vector512<double>, Vector512<long>, Vector512<ulong>>(vector);
}
else
{
return Create(
Vector256.Exp(vector._lower),
Vector256.Exp(vector._upper)
);
}
}

/// <inheritdoc cref="Vector256.Exp(Vector256{float})" />
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector512<float> Exp(Vector512<float> vector)
{
if (IsHardwareAccelerated)
{
return VectorMath.ExpSingle<Vector512<float>, Vector512<uint>, Vector512<double>, Vector512<ulong>>(vector);
}
else
{
return Create(
Vector256.Exp(vector._lower),
Vector256.Exp(vector._upper)
);
}
}

/// <summary>Extracts the most significant bit from each element in a vector.</summary>
/// <param name="vector">The vector whose elements should have their most significant bit extracted.</param>
/// <typeparam name="T">The type of the elements in the vector.</typeparam>
Expand Down Expand Up @@ -1807,6 +1841,7 @@ internal static Vector512<ushort> LoadUnsafe(ref char source, nuint elementOffse
LoadUnsafe(ref Unsafe.As<char, ushort>(ref source), elementOffset);

/// <inheritdoc cref="Vector256.Log(Vector256{double})" />
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector512<double> Log(Vector512<double> vector)
{
if (IsHardwareAccelerated)
Expand All @@ -1823,6 +1858,7 @@ public static Vector512<double> Log(Vector512<double> vector)
}

/// <inheritdoc cref="Vector256.Log(Vector256{float})" />
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector512<float> Log(Vector512<float> vector)
{
if (IsHardwareAccelerated)
Expand All @@ -1839,6 +1875,7 @@ public static Vector512<float> Log(Vector512<float> vector)
}

/// <inheritdoc cref="Vector256.Log2(Vector256{double})" />
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector512<double> Log2(Vector512<double> vector)
{
if (IsHardwareAccelerated)
Expand All @@ -1855,6 +1892,7 @@ public static Vector512<double> Log2(Vector512<double> vector)
}

/// <inheritdoc cref="Vector256.Log2(Vector256{float})" />
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector512<float> Log2(Vector512<float> vector)
{
if (IsHardwareAccelerated)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,52 @@ public static bool EqualsAny<T>(Vector64<T> left, Vector64<T> right)
return false;
}

internal static Vector64<T> Exp<T>(Vector64<T> vector)
where T : IExponentialFunctions<T>
{
Unsafe.SkipInit(out Vector64<T> result);

for (int index = 0; index < Vector64<T>.Count; index++)
{
T value = T.Exp(vector.GetElement(index));
result.SetElementUnsafe(index, value);
}

return result;
}

/// <summary>Computes the exp of each element in a vector.</summary>
/// <param name="vector">The vector that will have its Exp computed.</param>
/// <returns>A vector whose elements are the exp of the elements in <paramref name="vector" />.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector64<double> Exp(Vector64<double> vector)
{
if (IsHardwareAccelerated)
{
return VectorMath.ExpDouble<Vector64<double>, Vector64<long>, Vector64<ulong>>(vector);
}
else
{
return Exp<double>(vector);
}
}

/// <summary>Computes the exp of each element in a vector.</summary>
/// <param name="vector">The vector that will have its exp computed.</param>
/// <returns>A vector whose elements are the exp of the elements in <paramref name="vector" />.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector64<float> Exp(Vector64<float> vector)
{
if (IsHardwareAccelerated)
{
return VectorMath.ExpSingle<Vector64<float>, Vector64<uint>, Vector64<double>, Vector64<ulong>>(vector);
}
else
{
return Exp<float>(vector);
}
}

/// <summary>Extracts the most significant bit from each element in a vector.</summary>
/// <typeparam name="T">The type of the elements in the vector.</typeparam>
/// <param name="vector">The vector whose elements should have their most significant bit extracted.</param>
Expand Down Expand Up @@ -1588,6 +1634,7 @@ internal static Vector64<T> Log<T>(Vector64<T> vector)
/// <summary>Computes the log of each element in a vector.</summary>
/// <param name="vector">The vector that will have its log computed.</param>
/// <returns>A vector whose elements are the log of the elements in <paramref name="vector" />.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector64<double> Log(Vector64<double> vector)
{
if (IsHardwareAccelerated)
Expand All @@ -1603,6 +1650,7 @@ public static Vector64<double> Log(Vector64<double> vector)
/// <summary>Computes the log of each element in a vector.</summary>
/// <param name="vector">The vector that will have its log computed.</param>
/// <returns>A vector whose elements are the log of the elements in <paramref name="vector" />.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector64<float> Log(Vector64<float> vector)
{
if (IsHardwareAccelerated)
Expand Down Expand Up @@ -1632,6 +1680,7 @@ internal static Vector64<T> Log2<T>(Vector64<T> vector)
/// <summary>Computes the log2 of each element in a vector.</summary>
/// <param name="vector">The vector that will have its log2 computed.</param>
/// <returns>A vector whose elements are the log2 of the elements in <paramref name="vector" />.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector64<double> Log2(Vector64<double> vector)
{
if (IsHardwareAccelerated)
Expand All @@ -1647,6 +1696,7 @@ public static Vector64<double> Log2(Vector64<double> vector)
/// <summary>Computes the log2 of each element in a vector.</summary>
/// <param name="vector">The vector that will have its log2 computed.</param>
/// <returns>A vector whose elements are the log2 of the elements in <paramref name="vector" />.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector64<float> Log2(Vector64<float> vector)
{
if (IsHardwareAccelerated)
Expand Down
Loading

0 comments on commit c53d221

Please sign in to comment.