From c4115a9bf5620ad5d7791800e3759be5b4f56b81 Mon Sep 17 00:00:00 2001 From: wjr <1966336874@qq.com> Date: Thu, 18 Jan 2024 13:47:37 +0800 Subject: [PATCH] optimize toom33 and mul_s --- include/wjr/math/mul.hpp | 121 ++++++++++++++++++--------------------- 1 file changed, 57 insertions(+), 64 deletions(-) diff --git a/include/wjr/math/mul.hpp b/include/wjr/math/mul.hpp index de70ea19..81c1095f 100644 --- a/include/wjr/math/mul.hpp +++ b/include/wjr/math/mul.hpp @@ -226,12 +226,17 @@ WJR_INTRINSIC_CONSTEXPR_E T submul_1(T *dst, const T *src0, size_t n, T src1) { #define WJR_TOOM22_MUL_THRESHOLD 26 #endif +#ifndef WJR_TOOM33_MUL_THRESHOLD +#define WJR_TOOM33_MUL_THRESHOLD 64 +#endif + inline constexpr size_t toom22_mul_threshold = WJR_TOOM22_MUL_THRESHOLD; +inline constexpr size_t toom33_mul_threshold = WJR_TOOM33_MUL_THRESHOLD; template void mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m); -enum class __mul_s_mode { +enum class __mul_s_mode : unsigned int { toom22, toom33, toom24, @@ -239,7 +244,8 @@ enum class __mul_s_mode { }; template <__mul_s_mode mode, typename T> -void __mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *stk); +WJR_INTRINSIC_INLINE void __mul_s_impl(T *dst, const T *src0, size_t n, const T *src1, + size_t m, T *stk); template void basecase_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m); @@ -263,52 +269,55 @@ void mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m) { WJR_ASSERT(m != 0); do { - if (m <= toom22_mul_threshold) { + if (WJR_UNLIKELY(m > toom22_mul_threshold && 2 * m > n)) { break; } - if (m <= toom22_mul_threshold * 4) { - if (5 * m <= 4 * n) { - break; - } - } + return basecase_mul_s(dst, src0, n, src1, m); + } while (0); - if (5 * m <= 3 * n) { + do { + if (WJR_UNLIKELY(m > toom33_mul_threshold && 3 * m > 2 * n)) { break; } - unique_stack_ptr stk(math_details::stack_alloc, sizeof(T) * n * 2); - return toom22_mul_s(dst, src0, n, src1, m, static_cast(stk.get())); + unique_stack_ptr ptr(math_details::stack_alloc, sizeof(T) * (n * 2 + 1)); + T *stk = static_cast(ptr.get()); + return toom22_mul_s(dst, src0, n, src1, m, stk); } while (0); - return basecase_mul_s(dst, src0, n, src1, m); + unique_stack_ptr ptr(math_details::stack_alloc, sizeof(T) * (4 * n + 24)); + T *stk = static_cast(ptr.get()); + return toom33_mul_s(dst, src0, n, src1, m, stk); } template <__mul_s_mode mode, typename T> -WJR_INTRINSIC_INLINE void __mul_s(T *dst, const T *src0, size_t n, const T *src1, - size_t m, T *stk) { +WJR_INTRINSIC_INLINE void __mul_s_impl(T *dst, const T *src0, size_t n, const T *src1, + size_t m, T *stk) { WJR_ASSERT(WJR_IS_SAME_OR_INCR_P(dst, n + m, src0, n)); WJR_ASSERT(WJR_IS_SAME_OR_INCR_P(dst, n + m, src1, m)); WJR_ASSERT(n >= m); WJR_ASSERT(m != 0); - // TODO do { - if (m <= toom22_mul_threshold || 2 * m <= n) { + if (WJR_UNLIKELY(m > toom22_mul_threshold && 2 * m > n)) { break; } - do { - if (m <= 100 || 3 * m <= 2 * n) { + return basecase_mul_s(dst, src0, n, src1, m); + } while (0); + + do { + if constexpr (mode >= __mul_s_mode::toom33) { + if (WJR_UNLIKELY(m > toom33_mul_threshold && 3 * m > 2 * n)) { break; } - - return toom33_mul_s(dst, src0, n, src1, m, stk); - } while (0); + } + return toom22_mul_s(dst, src0, n, src1, m, stk); } while (0); - return basecase_mul_s(dst, src0, n, src1, m); + return toom33_mul_s(dst, src0, n, src1, m, stk); } template @@ -372,11 +381,11 @@ void toom22_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s break; } - __mul_s<__mul_s_mode::toom22>(wp, p0, l, p1, l, stk); + __mul_s_impl<__mul_s_mode::toom22>(wp, p0, l, p1, l, stk); } while (0); - __mul_s<__mul_s_mode::toom22>(p0, u0, l, v0, l, stk); - __mul_s<__mul_s_mode::toom22>(p2, u1, rn, v1, rm, stk); + __mul_s_impl<__mul_s_mode::toom22>(p0, u0, l, v0, l, stk); + __mul_s_impl<__mul_s_mode::toom22>(p2, u1, rn, v1, rm, stk); T cf = 0, cf2 = 0; @@ -390,9 +399,8 @@ void toom22_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s cf += addc_n(p1, p1, wp, l * 2, 0u); } - cf2 = addc_1(p2, p2, rn + rm, cf2, 0u); - WJR_ASSERT(cf2 == 0); - cf = addc_1(p3, p3, p3n, cf, 0u); + cf2 = addc_1(p2, p2, l, cf2, 0u); + cf = addc_1(p3, p3, p3n, cf, cf2); WJR_ASSERT(cf == 0); } @@ -416,17 +424,17 @@ void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s const auto v1p = src1 + l; const auto v2p = src1 + l * 2; - auto w0p = stk; - auto w1p = stk + (l + 1) * 2; - auto w2p = stk + (l + 1) * 4; - auto w3p = stk + (l + 1) * 6; - auto w4p = stk + (l + 1) * 8; - auto w5p = stk + (l + 1) * 10; + auto w0p = dst; + auto w1p = stk; + auto w2p = stk + (l + 1) * 2; + auto w3p = stk + (l + 1) * 4; + auto w4p = dst + l * 4; + auto w5p = stk + (l + 1) * 6; - stk += (l + 1) * 12; + stk += (l + 1) * 8; T cf0 = 0, cf1 = 0, cf2 = 0, cf3 = 0, cf4 = 0, cf5 = 0; - bool neg0 = 0, neg1 = 0, neg2 = 0, neg3 = 0, neg4 = 0, neg5 = 0; + bool neg2 = 0, neg3 = 0; // W0 = U0 + U2 : (non-negative) cf0 = addc_s(w0p, u0p, l, u2p, rn, 0u); @@ -466,55 +474,42 @@ void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s // W1 = W3 * W2 : r(-1) w3p[l] = cf3; w2p[l] = cf2; - cf1 = 0; - neg1 = neg3 ^ neg2; - __mul_s<__mul_s_mode::toom33>(w1p, w3p, l + 1, w2p, l + 1, stk); + neg2 ^= neg3; + __mul_s_impl<__mul_s_mode::toom33>(w1p, w3p, l + 1, w2p, l + 1, stk); // W2 = W0 * W4 : (non-negative) r(1) w0p[l] = cf0; w4p[l] = cf4; - cf2 = 0; - neg2 = 0; - __mul_s<__mul_s_mode::toom33>(w2p, w0p, l + 1, w4p, l + 1, stk); + __mul_s_impl<__mul_s_mode::toom33>(w2p, w0p, l + 1, w4p, l + 1, stk); // W0 = (W0 + U2) << 1 - U0 : (non-negative) u(2) cf0 += addc_s(w0p, w0p, l, u2p, rn, 0u); WJR_ASSERT(cf0 <= 3); cf0 += cf0 + lshift_n(w0p, w0p, l, 1u); - neg0 = 0; cf0 -= subc_n(w0p, w0p, u0p, l, 0u); WJR_ASSERT(cf0 <= 6); // W4 = (W4 + V2) << 1 - V0 : (non-negative) v(2) cf4 += addc_s(w4p, w4p, l, v2p, rm, 0u); + WJR_ASSERT(cf4 <= 3); cf4 += cf4 + lshift_n(w4p, w4p, l, 1u); - neg4 = 0; cf4 -= subc_n(w4p, w4p, v0p, l, 0u); WJR_ASSERT(cf4 <= 6); // W3 = W0 * W4 : (non-negative) r(2) w0p[l] = cf0; w4p[l] = cf4; - cf3 = 0; - neg3 = 0; - __mul_s<__mul_s_mode::toom33>(w3p, w0p, l + 1, w4p, l + 1, stk); - - w0p = dst; - w4p = dst + l * 4; + __mul_s_impl<__mul_s_mode::toom33>(w3p, w0p, l + 1, w4p, l + 1, stk); // W0 = U0 * V0 : (non-negative) r(0) = r0 - cf0 = 0; - neg0 = 0; - __mul_s<__mul_s_mode::toom33>(w0p, u0p, l, v0p, l, stk); + __mul_s_impl<__mul_s_mode::toom33>(w0p, u0p, l, v0p, l, stk); // W4 = U2 * V2 : (non-negative) r(inf) = r4 - cf4 = 0; - neg4 = 0; - __mul_s<__mul_s_mode::toom33>(w4p, u2p, rn, v2p, rm, stk); + __mul_s_impl<__mul_s_mode::toom33>(w4p, u2p, rn, v2p, rm, stk); // W3 = (W3 - W1) / 3 : (non-negative) (r(2) - r(-1)) / 3 { - if (neg3 == neg1) { + if (!neg2) { cf3 = subc_n(w3p, w3p, w1p, l * 2 + 1, 0u); } else { cf3 = addc_n(w3p, w3p, w1p, l * 2 + 1, 0u); @@ -526,10 +521,9 @@ void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s // W1 = (W2 - W1) >> 1 : (non-negative) (r(1) - r(-1)) / 2 { - if (neg2 == neg1) { + if (!neg2) { cf1 = subc_n(w1p, w2p, w1p, l * 2 + 1, 0u); } else { - neg1 = 0; cf1 = addc_n(w1p, w2p, w1p, l * 2 + 1, 0u); } @@ -538,7 +532,7 @@ void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s } // W2 = W2 - W0 : (non-negative) r(1) - r(0) - cf2 -= subc_s(w2p, w2p, l * 2 + 1, w0p, l * 2, 0u); + cf2 = subc_s(w2p, w2p, l * 2 + 1, w0p, l * 2, 0u); WJR_ASSERT(cf2 == 0); // W3 = ((W3 - W2) >> 1) - (W4 << 1) : (non-negative) r3 @@ -548,7 +542,6 @@ void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s (void)rshift_n(w3p, w3p, l + rn + 1, 1u); - neg5 = 0; cf5 = lshift_n(w5p, w4p, rn + rm, 1u); cf3 = subc_n(w3p, w3p, w5p, rn + rm, 0u); @@ -569,16 +562,16 @@ void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s // W1 = W1 - W3 : // r2 * x + r1 cf1 = subc_n(w1p, w1p, w3p, l, 0u); - cf1 = cf3 + subc_n(w2p, w2p, w4p, rn + rm, cf1); + cf1 = cf3 + subc_n(dst + l * 2, w2p, w4p, rn + rm, cf1); cf2 += w2p[l * 2]; if (l * 2 != rn + rm) { - cf1 = cf2 - subc_1(w2p + rn + rm, w2p + rn + rm, (l * 2) - (rn + rm), cf1, 0u); + cf1 = cf2 - + subc_1(dst + l * 2 + rn + rm, w2p + rn + rm, (l * 2) - (rn + rm), cf1, 0u); } else { cf1 = cf2 - cf1; } // W = W3*x^3+ W1*x + W0 - std::copy(w2p, w2p + l * 2, dst + l * 2); cf0 = addc_n(w0p + l, w0p + l, w1p, l, 0u); cf0 = addc_1(dst + l * 2, dst + l * 2, l, 0u, cf0); cf0 = addc_n(dst + l * 3, dst + l * 3, w3p, l, cf0);