Skip to content

Commit

Permalink
Add TensorPrimitives.ConvertTruncating/Saturating/Checked (#97572)
Browse files Browse the repository at this point in the history
* Add TensorPrimitives.ConvertTruncating/Saturating/Checked

* Fix auto-indentation

* Add comment

* Fix failures
  • Loading branch information
stephentoub authored Jan 27, 2024
1 parent 9bffb0f commit 4fc943c
Show file tree
Hide file tree
Showing 5 changed files with 1,516 additions and 461 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ public static void BitwiseOr<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T>
public static void BitwiseOr<T>(System.ReadOnlySpan<T> x, T y, System.Span<T> destination) where T : System.Numerics.IBitwiseOperators<T, T, T> { }
public static void Cbrt<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.IRootFunctions<T> { }
public static void Ceiling<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.IFloatingPoint<T> { }
public static void ConvertChecked<TFrom, TTo>(System.ReadOnlySpan<TFrom> source, System.Span<TTo> destination) where TFrom : System.Numerics.INumberBase<TFrom> where TTo : System.Numerics.INumberBase<TTo> { }
public static void ConvertSaturating<TFrom, TTo>(System.ReadOnlySpan<TFrom> source, System.Span<TTo> destination) where TFrom : System.Numerics.INumberBase<TFrom> where TTo : System.Numerics.INumberBase<TTo> { }
public static void ConvertTruncating<TFrom, TTo>(System.ReadOnlySpan<TFrom> source, System.Span<TTo> destination) where TFrom : System.Numerics.INumberBase<TFrom> where TTo : System.Numerics.INumberBase<TTo> { }
public static void ConvertToHalf(System.ReadOnlySpan<float> source, System.Span<System.Half> destination) { }
public static void ConvertToSingle(System.ReadOnlySpan<System.Half> source, System.Span<float> destination) { }
public static void CopySign<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> sign, System.Span<T> destination) where T : System.Numerics.INumber<T> { }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public static unsafe partial class TensorPrimitives
{
private static void InvokeSpanIntoSpan<TSingleUnaryOperator>(
ReadOnlySpan<float> x, Span<float> destination)
where TSingleUnaryOperator : struct, IUnaryOperator<float> =>
where TSingleUnaryOperator : struct, IUnaryOperator<float, float> =>
InvokeSpanIntoSpan<float, TSingleUnaryOperator>(x, destination);

private static void InvokeSpanSpanIntoSpan<TSingleBinaryOperator>(
Expand All @@ -58,7 +58,7 @@ private static void InvokeSpanScalarIntoSpan<TSingleBinaryOperator>(

private static unsafe void InvokeSpanScalarIntoSpan<TSingleTransformOperator, TSingleBinaryOperator>(
ReadOnlySpan<float> x, float y, Span<float> destination)
where TSingleTransformOperator : struct, IUnaryOperator<float>
where TSingleTransformOperator : struct, IUnaryOperator<float, float>
where TSingleBinaryOperator : struct, IBinaryOperator<float> =>
InvokeSpanScalarIntoSpan<float, TSingleTransformOperator, TSingleBinaryOperator>(x, y, destination);

Expand All @@ -79,7 +79,7 @@ private static void InvokeSpanScalarSpanIntoSpan<TSingleTernaryOperator>(

private static unsafe float Aggregate<TSingleTransformOperator, TSingleAggregationOperator>(
ReadOnlySpan<float> x)
where TSingleTransformOperator : struct, IUnaryOperator<float>
where TSingleTransformOperator : struct, IUnaryOperator<float, float>
where TSingleAggregationOperator : struct, IAggregationOperator<float> =>
Aggregate<float, TSingleTransformOperator, TSingleAggregationOperator>(x);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Runtime.CompilerServices;

namespace System.Numerics.Tensors
{
/// <summary>Performs primitive tensor operations over spans of memory.</summary>
Expand Down Expand Up @@ -488,6 +490,249 @@ public static void Ceiling<T>(ReadOnlySpan<T> x, Span<T> destination)
where T : IFloatingPoint<T> =>
InvokeSpanIntoSpan<T, CeilingOperator<T>>(x, destination);

/// <summary>
/// Copies <paramref name="source"/> to <paramref name="destination"/>, converting each <typeparamref name="TFrom"/>
/// value to a <typeparamref name="TTo"/> value.
/// </summary>
/// <param name="source">The source span from which to copy values.</param>
/// <param name="destination">The destination span into which the converted values should be written.</param>
/// <exception cref="ArgumentException">Destination is too short.</exception>
/// <remarks>
/// <para>
/// This method effectively computes <c><paramref name="destination" />[i] = TTo.CreateChecked(<paramref name="source"/>[i])</c>.
/// </para>
/// </remarks>
public static void ConvertChecked<TFrom, TTo>(ReadOnlySpan<TFrom> source, Span<TTo> destination)
where TFrom : INumberBase<TFrom>
where TTo : INumberBase<TTo>
{
if (!TryConvertUniversal(source, destination))
{
InvokeSpanIntoSpan<TFrom, TTo, ConvertCheckedFallbackOperator<TFrom, TTo>>(source, destination);
}
}

/// <summary>
/// Copies <paramref name="source"/> to <paramref name="destination"/>, converting each <typeparamref name="TFrom"/>
/// value to a <typeparamref name="TTo"/> value.
/// </summary>
/// <param name="source">The source span from which to copy values.</param>
/// <param name="destination">The destination span into which the converted values should be written.</param>
/// <exception cref="ArgumentException">Destination is too short.</exception>
/// <remarks>
/// <para>
/// This method effectively computes <c><paramref name="destination" />[i] = TTo.CreateSaturating(<paramref name="source"/>[i])</c>.
/// </para>
/// </remarks>
public static void ConvertSaturating<TFrom, TTo>(ReadOnlySpan<TFrom> source, Span<TTo> destination)
where TFrom : INumberBase<TFrom>
where TTo : INumberBase<TTo>
{
if (!TryConvertUniversal(source, destination))
{
InvokeSpanIntoSpan<TFrom, TTo, ConvertSaturatingFallbackOperator<TFrom, TTo>>(source, destination);
}
}

/// <summary>
/// Copies <paramref name="source"/> to <paramref name="destination"/>, converting each <typeparamref name="TFrom"/>
/// value to a <typeparamref name="TTo"/> value.
/// </summary>
/// <param name="source">The source span from which to copy values.</param>
/// <param name="destination">The destination span into which the converted values should be written.</param>
/// <exception cref="ArgumentException">Destination is too short.</exception>
/// <remarks>
/// <para>
/// This method effectively computes <c><paramref name="destination" />[i] = TTo.CreateTruncating(<paramref name="source"/>[i])</c>.
/// </para>
/// </remarks>
public static void ConvertTruncating<TFrom, TTo>(ReadOnlySpan<TFrom> source, Span<TTo> destination)
where TFrom : INumberBase<TFrom>
where TTo : INumberBase<TTo>
{
if (TryConvertUniversal(source, destination))
{
return;
}

if (((typeof(TFrom) == typeof(byte) || typeof(TFrom) == typeof(sbyte)) && (typeof(TTo) == typeof(byte) || typeof(TTo) == typeof(sbyte))) ||
((typeof(TFrom) == typeof(ushort) || typeof(TFrom) == typeof(short)) && (typeof(TTo) == typeof(ushort) || typeof(TTo) == typeof(short))) ||
((IsUInt32Like<TFrom>() || IsInt32Like<TFrom>()) && (IsUInt32Like<TTo>() || IsInt32Like<TTo>())) ||
((IsUInt64Like<TFrom>() || IsInt64Like<TFrom>()) && (IsUInt64Like<TTo>() || IsInt64Like<TTo>())))
{
source.CopyTo(Rename<TTo, TFrom>(destination));
return;
}

if (typeof(TFrom) == typeof(float) && IsUInt32Like<TTo>())
{
InvokeSpanIntoSpan<float, uint, ConvertSingleToUInt32>(Rename<TFrom, float>(source), Rename<TTo, uint>(destination));
return;
}

if (typeof(TFrom) == typeof(float) && IsInt32Like<TTo>())
{
InvokeSpanIntoSpan<float, int, ConvertSingleToInt32>(Rename<TFrom, float>(source), Rename<TTo, int>(destination));
return;
}

if (typeof(TFrom) == typeof(double) && IsUInt64Like<TTo>())
{
InvokeSpanIntoSpan<double, ulong, ConvertDoubleToUInt64>(Rename<TFrom, double>(source), Rename<TTo, ulong>(destination));
return;
}

if (typeof(TFrom) == typeof(double) && IsInt64Like<TTo>())
{
InvokeSpanIntoSpan<double, long, ConvertDoubleToInt64>(Rename<TFrom, double>(source), Rename<TTo, long>(destination));
return;
}

if (typeof(TFrom) == typeof(ushort) && typeof(TTo) == typeof(byte))
{
InvokeSpanIntoSpan_2to1<ushort, byte, NarrowUInt16ToByteOperator>(Rename<TFrom, ushort>(source), Rename<TTo, byte>(destination));
return;
}

if (typeof(TFrom) == typeof(short) && typeof(TTo) == typeof(sbyte))
{
InvokeSpanIntoSpan_2to1<short, sbyte, NarrowInt16ToSByteOperator>(Rename<TFrom, short>(source), Rename<TTo, sbyte>(destination));
return;
}

if (IsUInt32Like<TFrom>() && typeof(TTo) == typeof(ushort))
{
InvokeSpanIntoSpan_2to1<uint, ushort, NarrowUInt32ToUInt16Operator>(Rename<TFrom, uint>(source), Rename<TTo, ushort>(destination));
return;
}

if (IsInt32Like<TFrom>() && typeof(TTo) == typeof(short))
{
InvokeSpanIntoSpan_2to1<int, short, NarrowInt32ToInt16Operator>(Rename<TFrom, int>(source), Rename<TTo, short>(destination));
return;
}

if (IsUInt64Like<TFrom>() && IsUInt32Like<TTo>())
{
InvokeSpanIntoSpan_2to1<ulong, uint, NarrowUInt64ToUInt32Operator>(Rename<TFrom, ulong>(source), Rename<TTo, uint>(destination));
return;
}

if (IsInt64Like<TFrom>() && IsInt32Like<TTo>())
{
InvokeSpanIntoSpan_2to1<long, int, NarrowInt64ToInt32Operator>(Rename<TFrom, long>(source), Rename<TTo, int>(destination));
return;
}

InvokeSpanIntoSpan<TFrom, TTo, ConvertTruncatingFallbackOperator<TFrom, TTo>>(source, destination);
}

/// <summary>Performs conversions that are the same regardless of checked, truncating, or saturation.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)] // at most one of the branches will be kept
private static bool TryConvertUniversal<TFrom, TTo>(ReadOnlySpan<TFrom> source, Span<TTo> destination)
where TFrom : INumberBase<TFrom>
where TTo : INumberBase<TTo>
{
if (typeof(TFrom) == typeof(TTo))
{
if (source.Length > destination.Length)
{
ThrowHelper.ThrowArgument_DestinationTooShort();
}

ValidateInputOutputSpanNonOverlapping(source, Rename<TTo, TFrom>(destination));

source.CopyTo(Rename<TTo, TFrom>(destination));
return true;
}

if (IsInt32Like<TFrom>() && typeof(TTo) == typeof(float))
{
InvokeSpanIntoSpan<int, float, ConvertInt32ToSingle>(Rename<TFrom, int>(source), Rename<TTo, float>(destination));
return true;
}

if (IsUInt32Like<TFrom>() && typeof(TTo) == typeof(float))
{
InvokeSpanIntoSpan<uint, float, ConvertUInt32ToSingle>(Rename<TFrom, uint>(source), Rename<TTo, float>(destination));
return true;
}

if (IsInt64Like<TFrom>() && typeof(TTo) == typeof(double))
{
InvokeSpanIntoSpan<long, double, ConvertInt64ToDouble>(Rename<TFrom, long>(source), Rename<TTo, double>(destination));
return true;
}

if (IsUInt64Like<TFrom>() && typeof(TTo) == typeof(double))
{
InvokeSpanIntoSpan<ulong, double, ConvertUInt64ToDouble>(Rename<TFrom, ulong>(source), Rename<TTo, double>(destination));
return true;
}

if (typeof(TFrom) == typeof(float) && typeof(TTo) == typeof(Half))
{
ConvertToHalf(Rename<TFrom, float>(source), Rename<TTo, Half>(destination));
return true;
}

if (typeof(TFrom) == typeof(Half) && typeof(TTo) == typeof(float))
{
ConvertToSingle(Rename<TFrom, Half>(source), Rename<TTo, float>(destination));
return true;
}

if (typeof(TFrom) == typeof(float) && typeof(TTo) == typeof(double))
{
InvokeSpanIntoSpan_1to2<float, double, WidenSingleToDoubleOperator>(Rename<TFrom, float>(source), Rename<TTo, double>(destination));
return true;
}

if (typeof(TFrom) == typeof(double) && typeof(TTo) == typeof(float))
{
InvokeSpanIntoSpan_2to1<double, float, NarrowDoubleToSingleOperator>(Rename<TFrom, double>(source), Rename<TTo, float>(destination));
return true;
}

if (typeof(TFrom) == typeof(byte) && typeof(TTo) == typeof(ushort))
{
InvokeSpanIntoSpan_1to2<byte, ushort, WidenByteToUInt16Operator>(Rename<TFrom, byte>(source), Rename<TTo, ushort>(destination));
return true;
}

if (typeof(TFrom) == typeof(sbyte) && typeof(TTo) == typeof(short))
{
InvokeSpanIntoSpan_1to2<sbyte, short, WidenSByteToInt16Operator>(Rename<TFrom, sbyte>(source), Rename<TTo, short>(destination));
return true;
}

if (typeof(TFrom) == typeof(ushort) && IsUInt32Like<TTo>())
{
InvokeSpanIntoSpan_1to2<ushort, uint, WidenUInt16ToUInt32Operator>(Rename<TFrom, ushort>(source), Rename<TTo, uint>(destination));
return true;
}

if (typeof(TFrom) == typeof(short) && IsInt32Like<TTo>())
{
InvokeSpanIntoSpan_1to2<short, int, WidenInt16ToInt32Operator>(Rename<TFrom, short>(source), Rename<TTo, int>(destination));
return true;
}

if (IsUInt32Like<TTo>() && IsUInt64Like<TTo>())
{
InvokeSpanIntoSpan_1to2<uint, ulong, WidenUInt32ToUInt64Operator>(Rename<TFrom, uint>(source), Rename<TTo, ulong>(destination));
return true;
}

if (IsInt32Like<TFrom>() && IsInt64Like<TTo>())
{
InvokeSpanIntoSpan_1to2<int, long, WidenInt32ToInt64Operator>(Rename<TFrom, int>(source), Rename<TTo, long>(destination));
return true;
}

return false;
}

/// <summary>Computes the element-wise result of copying the sign from one number to another number in the specified tensors.</summary>
/// <param name="x">The first tensor, represented as a span.</param>
/// <param name="sign">The second tensor, represented as a span.</param>
Expand Down Expand Up @@ -963,15 +1208,14 @@ public static void Ieee754Remainder<T>(T x, ReadOnlySpan<T> y, Span<T> destinati
public static void ILogB<T>(ReadOnlySpan<T> x, Span<int> destination)
where T : IFloatingPointIeee754<T>
{
if (x.Length > destination.Length)
if (typeof(T) == typeof(double))
{
ThrowHelper.ThrowArgument_DestinationTooShort();
// Special-case double as the only vectorizable floating-point type whose size != sizeof(int).
InvokeSpanIntoSpan_2to1<double, int, ILogBDoubleOperator>(Rename<T, double>(x), destination);
}

// TODO: Vectorize
for (int i = 0; i < x.Length; i++)
else
{
destination[i] = T.ILogB(x[i]);
InvokeSpanIntoSpan<T, int, ILogBOperator<T>>(x, destination);
}
}

Expand Down
Loading

0 comments on commit 4fc943c

Please sign in to comment.