Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve AddImpl and SubtractImpl with Avx512 #41

Merged
merged 1 commit into from
Dec 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 59 additions & 45 deletions src/Nethermind.Int256/UInt256.cs
Original file line number Diff line number Diff line change
Expand Up @@ -406,24 +406,32 @@ public static bool AddImpl(in UInt256 a, in UInt256 b, out UInt256 res)
{
if (Avx2.IsSupported)
{
var av = Unsafe.As<UInt256,Vector256<ulong>>(ref Unsafe.AsRef(in a));
var bv = Unsafe.As<UInt256,Vector256<ulong>>(ref Unsafe.AsRef(in b));
Vector256<ulong> av = Unsafe.As<UInt256, Vector256<ulong>>(ref Unsafe.AsRef(in a));
Vector256<ulong> bv = Unsafe.As<UInt256, Vector256<ulong>>(ref Unsafe.AsRef(in b));

var result = Avx2.Add(av, bv);

var carryFromBothHighBits = Avx2.And(av, bv);
var eitherHighBit = Avx2.Or(av, bv);
var highBitNotInResult = Avx2.AndNot(result, eitherHighBit);
Vector256<ulong> result = Avx2.Add(av, bv);
Vector256<ulong> vCarry;
if (Avx512F.VL.IsSupported)
{
vCarry = Avx512F.VL.CompareLessThan(result, av);
}
else
{
// Workaround for missing Vector256.CompareLessThan
Vector256<ulong> carryFromBothHighBits = Avx2.And(av, bv);
Vector256<ulong> eitherHighBit = Avx2.Or(av, bv);
Vector256<ulong> highBitNotInResult = Avx2.AndNot(result, eitherHighBit);

// Set high bits where carry occurs
var vCarry = Avx2.Or(carryFromBothHighBits, highBitNotInResult);
// Set high bits where carry occurs
vCarry = Avx2.Or(carryFromBothHighBits, highBitNotInResult);
}
// Move carry from Vector space to int
var carry = Avx.MoveMask(Unsafe.As<Vector256<ulong>, Vector256<double>>(ref vCarry));
int carry = Avx.MoveMask(vCarry.AsDouble());

// All bits set will cascade another carry when carry is added to it
var vCascade = Avx2.CompareEqual(result, Vector256<ulong>.AllBitsSet);
Vector256<ulong> vCascade = Avx2.CompareEqual(result, Vector256<ulong>.AllBitsSet);
// Move cascade from Vector space to int
var cascade = Avx.MoveMask(Unsafe.As<Vector256<ulong>, Vector256<double>>(ref vCascade));
int cascade = Avx.MoveMask(Unsafe.As<Vector256<ulong>, Vector256<double>>(ref vCascade));

// Use ints to work out the Vector cross lane cascades
// Move carry to next bit and add cascade
Expand All @@ -434,12 +442,12 @@ public static bool AddImpl(in UInt256 a, in UInt256 b, out UInt256 res)
cascade &= 0x0f;

// Lookup the carries to broadcast to the Vectors
var cascadedCarries = Unsafe.Add(ref Unsafe.As<byte, Vector256<ulong>>(ref MemoryMarshal.GetReference(s_broadcastLookup)), cascade);
Vector256<ulong> cascadedCarries = Unsafe.Add(ref Unsafe.As<byte, Vector256<ulong>>(ref MemoryMarshal.GetReference(s_broadcastLookup)), cascade);

// Mark res as initalized so we can use it as left said of ref assignment
// Mark res as initialized so we can use it as the left side of a ref assignment
Unsafe.SkipInit(out res);
// Add the cascadedCarries to the result
Unsafe.As<UInt256,Vector256<ulong>>(ref res) = Avx2.Add(result, cascadedCarries);
Unsafe.As<UInt256, Vector256<ulong>>(ref res) = Avx2.Add(result, cascadedCarries);

return (carry & 0b1_0000) != 0;
}
Expand All @@ -458,7 +466,6 @@ public static bool AddImpl(in UInt256 a, in UInt256 b, out UInt256 res)
// Debug.Assert((BigInteger)res == ((BigInteger)a + (BigInteger)b) % ((BigInteger)1 << 256));
// #endif
}

public void Add(in UInt256 a, out UInt256 res) => Add(this, a, out res);

/// <summary>
Expand Down Expand Up @@ -665,7 +672,7 @@ private static void Udivrem(ref ulong quot, ref ulong u, int length, in UInt256
int uLen = 0;
for (int i = length - 1; i >= 0; i--)
{
if (Unsafe.Add(ref u,i) != 0)
if (Unsafe.Add(ref u, i) != 0)
{
uLen = i + 1;
break;
Expand Down Expand Up @@ -730,13 +737,13 @@ private static void Udivrem(ref ulong quot, ref ulong u, int length, in UInt256
goto r3;
}

r3:
r3:
rem2 = Rsh(un[2], shift) | Lsh(un[3], 64 - shift);
r2:
r2:
rem1 = Rsh(un[1], shift) | Lsh(un[2], 64 - shift);
r1:
r1:
rem0 = Rsh(un[0], shift) | Lsh(un[1], 64 - shift);
r0:
r0:

rem = new UInt256(rem0, rem1, rem2, rem3);
}
Expand Down Expand Up @@ -879,25 +886,32 @@ private static bool SubtractImpl(in UInt256 a, in UInt256 b, out UInt256 res)
{
if (Avx2.IsSupported)
{
var av = Unsafe.As<UInt256, Vector256<ulong>>(ref Unsafe.AsRef(in a));
var bv = Unsafe.As<UInt256, Vector256<ulong>>(ref Unsafe.AsRef(in b));

var result = Avx2.Subtract(av, bv);
// Invert top bits as Avx2.CompareGreaterThan is only available for longs, not unsigned
var resultSigned = Avx2.Xor(result, Vector256.Create<ulong>(0x8000_0000_0000_0000));
var avSigned = Avx2.Xor(av, Vector256.Create<ulong>(0x8000_0000_0000_0000));
Vector256<ulong> av = Unsafe.As<UInt256, Vector256<ulong>>(ref Unsafe.AsRef(in a));
Vector256<ulong> bv = Unsafe.As<UInt256, Vector256<ulong>>(ref Unsafe.AsRef(in b));

// Which vectors need to borrow from the next
var vBorrow = Avx2.CompareGreaterThan(Unsafe.As<Vector256<ulong>, Vector256<long>>(ref resultSigned),
Unsafe.As<Vector256<ulong>, Vector256<long>>(ref avSigned));
Vector256<ulong> result = Avx2.Subtract(av, bv);
Vector256<ulong> vBorrow;
if (Avx512F.VL.IsSupported)
{
vBorrow = Avx512F.VL.CompareGreaterThan(result, av);
}
else
{
// Invert top bits as Avx2.CompareGreaterThan is only available for longs, not unsigned
Vector256<ulong> resultSigned = Avx2.Xor(result, Vector256.Create<ulong>(0x8000_0000_0000_0000));
Vector256<ulong> avSigned = Avx2.Xor(av, Vector256.Create<ulong>(0x8000_0000_0000_0000));

// Which vectors need to borrow from the next
vBorrow = Avx2.CompareGreaterThan(Unsafe.As<Vector256<ulong>, Vector256<long>>(ref resultSigned),
Unsafe.As<Vector256<ulong>, Vector256<long>>(ref avSigned)).AsUInt64();
}
// Move borrow from Vector space to int
var borrow = Avx.MoveMask(Unsafe.As<Vector256<long>, Vector256<double>>(ref vBorrow));
int borrow = Avx.MoveMask(vBorrow.AsDouble());

// All zeros will cascade another borrow when borrow is subtracted from it
var vCascade = Avx2.CompareEqual(result, Vector256<ulong>.Zero);
Vector256<ulong> vCascade = Avx2.CompareEqual(result, Vector256<ulong>.Zero);
// Move cascade from Vector space to int
var cascade = Avx.MoveMask(Unsafe.As<Vector256<ulong>, Vector256<double>>(ref vCascade));
int cascade = Avx.MoveMask(vCascade.AsDouble());

// Use ints to work out the Vector cross lane cascades
// Move borrow to next bit and add cascade
Expand All @@ -908,9 +922,9 @@ private static bool SubtractImpl(in UInt256 a, in UInt256 b, out UInt256 res)
cascade &= 0x0f;

// Lookup the borrows to broadcast to the Vectors
var cascadedBorrows = Unsafe.Add(ref Unsafe.As<byte, Vector256<ulong>>(ref MemoryMarshal.GetReference(s_broadcastLookup)), cascade);
Vector256<ulong> cascadedBorrows = Unsafe.Add(ref Unsafe.As<byte, Vector256<ulong>>(ref MemoryMarshal.GetReference(s_broadcastLookup)), cascade);

// Mark res as initalized so we can use it as left said of ref assignment
// Mark res as initialized so we can use it as the left side of a ref assignment
Unsafe.SkipInit(out res);
// Subtract the cascadedBorrows from the result
Unsafe.As<UInt256, Vector256<ulong>>(ref res) = Avx2.Subtract(result, cascadedBorrows);
Expand Down Expand Up @@ -1315,15 +1329,15 @@ public static void Lsh(in UInt256 x, int n, out UInt256 res)
a = Rsh(res.u0, 64 - n);
z0 = Lsh(res.u0, n);

sh64:
sh64:
b = Rsh(res.u1, 64 - n);
z1 = Lsh(res.u1, n) | a;

sh128:
sh128:
a = Rsh(res.u2, 64 - n);
z2 = Lsh(res.u2, n) | b;

sh192:
sh192:
z3 = Lsh(res.u3, n) | a;

res = new UInt256(z0, z1, z2, z3);
Expand Down Expand Up @@ -1425,15 +1439,15 @@ public static void Rsh(in UInt256 x, int n, out UInt256 res)
a = Lsh(res.u3, 64 - n);
z3 = Rsh(res.u3, n);

sh64:
sh64:
b = Lsh(res.u2, 64 - n);
z2 = Rsh(res.u2, n) | a;

sh128:
sh128:
a = Lsh(res.u1, 64 - n);
z1 = Rsh(res.u1, n) | b;

sh192:
sh192:
z0 = Rsh(res.u0, n) | a;

res = new UInt256(z0, z1, z2, z3);
Expand Down Expand Up @@ -1923,13 +1937,13 @@ public static bool TryParse(in ReadOnlySpan<char> value, NumberStyles style, IFo
public TypeCode GetTypeCode() => TypeCode.Object;
public bool ToBoolean(IFormatProvider? provider) => !IsZero;
public byte ToByte(IFormatProvider? provider) => System.Convert.ToByte(ToDecimal(provider), provider);
public char ToChar(IFormatProvider? provider) => System.Convert.ToChar(ToDecimal(provider), provider);
public DateTime ToDateTime(IFormatProvider? provider) => System.Convert.ToDateTime(ToDecimal(provider), provider);
public char ToChar(IFormatProvider? provider) => System.Convert.ToChar(ToDecimal(provider), provider);
public DateTime ToDateTime(IFormatProvider? provider) => System.Convert.ToDateTime(ToDecimal(provider), provider);
public decimal ToDecimal(IFormatProvider? provider) => (decimal)this;
public double ToDouble(IFormatProvider? provider) => (double)this;
public short ToInt16(IFormatProvider? provider) => System.Convert.ToInt16(ToDecimal(provider), provider);
public int ToInt32(IFormatProvider? provider) => System.Convert.ToInt32(ToDecimal(provider), provider);
public long ToInt64(IFormatProvider? provider) => System.Convert.ToInt64(ToDecimal(provider), provider);
public long ToInt64(IFormatProvider? provider) => System.Convert.ToInt64(ToDecimal(provider), provider);
public sbyte ToSByte(IFormatProvider? provider) => System.Convert.ToSByte(ToDecimal(provider), provider);
public float ToSingle(IFormatProvider? provider) => System.Convert.ToSingle(ToDouble(provider), provider);
public string ToString(IFormatProvider? provider) => ((BigInteger)this).ToString(provider);
Expand Down
Loading