From b4e1355bf6f955ded84751d0e740b8a3c27b1987 Mon Sep 17 00:00:00 2001 From: Ben Adams Date: Wed, 25 Dec 2024 20:13:29 +0000 Subject: [PATCH] Improve AddImpl and SubtractImpl with Avx512 --- src/Nethermind.Int256/UInt256.cs | 104 ++++++++++++++++++------------- 1 file changed, 59 insertions(+), 45 deletions(-) diff --git a/src/Nethermind.Int256/UInt256.cs b/src/Nethermind.Int256/UInt256.cs index 5661359..56feb04 100644 --- a/src/Nethermind.Int256/UInt256.cs +++ b/src/Nethermind.Int256/UInt256.cs @@ -406,24 +406,32 @@ public static bool AddImpl(in UInt256 a, in UInt256 b, out UInt256 res) { if (Avx2.IsSupported) { - var av = Unsafe.As>(ref Unsafe.AsRef(in a)); - var bv = Unsafe.As>(ref Unsafe.AsRef(in b)); + Vector256 av = Unsafe.As>(ref Unsafe.AsRef(in a)); + Vector256 bv = Unsafe.As>(ref Unsafe.AsRef(in b)); - var result = Avx2.Add(av, bv); - - var carryFromBothHighBits = Avx2.And(av, bv); - var eitherHighBit = Avx2.Or(av, bv); - var highBitNotInResult = Avx2.AndNot(result, eitherHighBit); + Vector256 result = Avx2.Add(av, bv); + Vector256 vCarry; + if (Avx512F.VL.IsSupported) + { + vCarry = Avx512F.VL.CompareLessThan(result, av); + } + else + { + // Work around for missing Vector256.CompareLessThan + Vector256 carryFromBothHighBits = Avx2.And(av, bv); + Vector256 eitherHighBit = Avx2.Or(av, bv); + Vector256 highBitNotInResult = Avx2.AndNot(result, eitherHighBit); - // Set high bits where carry occurs - var vCarry = Avx2.Or(carryFromBothHighBits, highBitNotInResult); + // Set high bits where carry occurs + vCarry = Avx2.Or(carryFromBothHighBits, highBitNotInResult); + } // Move carry from Vector space to int - var carry = Avx.MoveMask(Unsafe.As, Vector256>(ref vCarry)); + int carry = Avx.MoveMask(vCarry.AsDouble()); // All bits set will cascade another carry when carry is added to it - var vCascade = Avx2.CompareEqual(result, Vector256.AllBitsSet); + Vector256 vCascade = Avx2.CompareEqual(result, Vector256.AllBitsSet); // Move cascade from Vector space to int - var cascade = Avx.MoveMask(Unsafe.As, Vector256>(ref vCascade)); + int cascade = Avx.MoveMask(Unsafe.As, Vector256>(ref vCascade)); // Use ints to work out the Vector cross lane cascades // Move carry to next bit and add cascade @@ -434,12 +442,12 @@ public static bool AddImpl(in UInt256 a, in UInt256 b, out UInt256 res) cascade &= 0x0f; // Lookup the carries to broadcast to the Vectors - var cascadedCarries = Unsafe.Add(ref Unsafe.As>(ref MemoryMarshal.GetReference(s_broadcastLookup)), cascade); + Vector256 cascadedCarries = Unsafe.Add(ref Unsafe.As>(ref MemoryMarshal.GetReference(s_broadcastLookup)), cascade); - // Mark res as initalized so we can use it as left said of ref assignment + // Mark res as initialized so we can use it as left said of ref assignment Unsafe.SkipInit(out res); // Add the cascadedCarries to the result - Unsafe.As>(ref res) = Avx2.Add(result, cascadedCarries); + Unsafe.As>(ref res) = Avx2.Add(result, cascadedCarries); return (carry & 0b1_0000) != 0; } @@ -458,7 +466,6 @@ public static bool AddImpl(in UInt256 a, in UInt256 b, out UInt256 res) // Debug.Assert((BigInteger)res == ((BigInteger)a + (BigInteger)b) % ((BigInteger)1 << 256)); // #endif } - public void Add(in UInt256 a, out UInt256 res) => Add(this, a, out res); /// @@ -665,7 +672,7 @@ private static void Udivrem(ref ulong quot, ref ulong u, int length, in UInt256 int uLen = 0; for (int i = length - 1; i >= 0; i--) { - if (Unsafe.Add(ref u,i) != 0) + if (Unsafe.Add(ref u, i) != 0) { uLen = i + 1; break; @@ -730,13 +737,13 @@ private static void Udivrem(ref ulong quot, ref ulong u, int length, in UInt256 goto r3; } - r3: + r3: rem2 = Rsh(un[2], shift) | Lsh(un[3], 64 - shift); - r2: + r2: rem1 = Rsh(un[1], shift) | Lsh(un[2], 64 - shift); - r1: + r1: rem0 = Rsh(un[0], shift) | Lsh(un[1], 64 - shift); - r0: + r0: rem = new UInt256(rem0, rem1, rem2, rem3); } @@ -879,25 +886,32 @@ private static bool SubtractImpl(in UInt256 a, in UInt256 b, out UInt256 res) { if (Avx2.IsSupported) { - var av = Unsafe.As>(ref Unsafe.AsRef(in a)); - var bv = Unsafe.As>(ref Unsafe.AsRef(in b)); - - var result = Avx2.Subtract(av, bv); - // Invert top bits as Avx2.CompareGreaterThan is only available for longs, not unsigned - var resultSigned = Avx2.Xor(result, Vector256.Create(0x8000_0000_0000_0000)); - var avSigned = Avx2.Xor(av, Vector256.Create(0x8000_0000_0000_0000)); + Vector256 av = Unsafe.As>(ref Unsafe.AsRef(in a)); + Vector256 bv = Unsafe.As>(ref Unsafe.AsRef(in b)); - // Which vectors need to borrow from the next - var vBorrow = Avx2.CompareGreaterThan(Unsafe.As, Vector256>(ref resultSigned), - Unsafe.As, Vector256>(ref avSigned)); + Vector256 result = Avx2.Subtract(av, bv); + Vector256 vBorrow; + if (Avx512F.VL.IsSupported) + { + vBorrow = Avx512F.VL.CompareGreaterThan(result, av); + } + else + { + // Invert top bits as Avx2.CompareGreaterThan is only available for longs, not unsigned + Vector256 resultSigned = Avx2.Xor(result, Vector256.Create(0x8000_0000_0000_0000)); + Vector256 avSigned = Avx2.Xor(av, Vector256.Create(0x8000_0000_0000_0000)); + // Which vectors need to borrow from the next + vBorrow = Avx2.CompareGreaterThan(Unsafe.As, Vector256>(ref resultSigned), + Unsafe.As, Vector256>(ref avSigned)).AsUInt64(); + } // Move borrow from Vector space to int - var borrow = Avx.MoveMask(Unsafe.As, Vector256>(ref vBorrow)); + int borrow = Avx.MoveMask(vBorrow.AsDouble()); // All zeros will cascade another borrow when borrow is subtracted from it - var vCascade = Avx2.CompareEqual(result, Vector256.Zero); + Vector256 vCascade = Avx2.CompareEqual(result, Vector256.Zero); // Move cascade from Vector space to int - var cascade = Avx.MoveMask(Unsafe.As, Vector256>(ref vCascade)); + int cascade = Avx.MoveMask(vCascade.AsDouble()); // Use ints to work out the Vector cross lane cascades // Move borrow to next bit and add cascade @@ -908,9 +922,9 @@ private static bool SubtractImpl(in UInt256 a, in UInt256 b, out UInt256 res) cascade &= 0x0f; // Lookup the borrows to broadcast to the Vectors - var cascadedBorrows = Unsafe.Add(ref Unsafe.As>(ref MemoryMarshal.GetReference(s_broadcastLookup)), cascade); + Vector256 cascadedBorrows = Unsafe.Add(ref Unsafe.As>(ref MemoryMarshal.GetReference(s_broadcastLookup)), cascade); - // Mark res as initalized so we can use it as left said of ref assignment + // Mark res as initialized so we can use it as left said of ref assignment Unsafe.SkipInit(out res); // Subtract the cascadedBorrows from the result Unsafe.As>(ref res) = Avx2.Subtract(result, cascadedBorrows); @@ -1315,15 +1329,15 @@ public static void Lsh(in UInt256 x, int n, out UInt256 res) a = Rsh(res.u0, 64 - n); z0 = Lsh(res.u0, n); - sh64: + sh64: b = Rsh(res.u1, 64 - n); z1 = Lsh(res.u1, n) | a; - sh128: + sh128: a = Rsh(res.u2, 64 - n); z2 = Lsh(res.u2, n) | b; - sh192: + sh192: z3 = Lsh(res.u3, n) | a; res = new UInt256(z0, z1, z2, z3); @@ -1425,15 +1439,15 @@ public static void Rsh(in UInt256 x, int n, out UInt256 res) a = Lsh(res.u3, 64 - n); z3 = Rsh(res.u3, n); - sh64: + sh64: b = Lsh(res.u2, 64 - n); z2 = Rsh(res.u2, n) | a; - sh128: + sh128: a = Lsh(res.u1, 64 - n); z1 = Rsh(res.u1, n) | b; - sh192: + sh192: z0 = Rsh(res.u0, n) | a; res = new UInt256(z0, z1, z2, z3); @@ -1923,13 +1937,13 @@ public static bool TryParse(in ReadOnlySpan value, NumberStyles style, IFo public TypeCode GetTypeCode() => TypeCode.Object; public bool ToBoolean(IFormatProvider? provider) => !IsZero; public byte ToByte(IFormatProvider? provider) => System.Convert.ToByte(ToDecimal(provider), provider); - public char ToChar(IFormatProvider? provider) => System.Convert.ToChar(ToDecimal(provider), provider); - public DateTime ToDateTime(IFormatProvider? provider) => System.Convert.ToDateTime(ToDecimal(provider), provider); + public char ToChar(IFormatProvider? provider) => System.Convert.ToChar(ToDecimal(provider), provider); + public DateTime ToDateTime(IFormatProvider? provider) => System.Convert.ToDateTime(ToDecimal(provider), provider); public decimal ToDecimal(IFormatProvider? provider) => (decimal)this; public double ToDouble(IFormatProvider? provider) => (double)this; public short ToInt16(IFormatProvider? provider) => System.Convert.ToInt16(ToDecimal(provider), provider); public int ToInt32(IFormatProvider? provider) => System.Convert.ToInt32(ToDecimal(provider), provider); - public long ToInt64(IFormatProvider? provider) => System.Convert.ToInt64(ToDecimal(provider), provider); + public long ToInt64(IFormatProvider? provider) => System.Convert.ToInt64(ToDecimal(provider), provider); public sbyte ToSByte(IFormatProvider? provider) => System.Convert.ToSByte(ToDecimal(provider), provider); public float ToSingle(IFormatProvider? provider) => System.Convert.ToSingle(ToDouble(provider), provider); public string ToString(IFormatProvider? provider) => ((BigInteger)this).ToString(provider);