diff --git a/include/wjr/math/mul.hpp b/include/wjr/math/mul.hpp index add785a8..c7de1fa1 100644 --- a/include/wjr/math/mul.hpp +++ b/include/wjr/math/mul.hpp @@ -261,7 +261,7 @@ WJR_INTRINSIC_CONSTEXPR_E T submul_1(T *dst, const T *src, size_t n, #endif #ifndef WJR_TOOM33_MUL_THRESHOLD -#define WJR_TOOM33_MUL_THRESHOLD 64 +#define WJR_TOOM33_MUL_THRESHOLD 80 #endif inline constexpr size_t toom22_mul_threshold = WJR_TOOM22_MUL_THRESHOLD; @@ -270,16 +270,18 @@ inline constexpr size_t toom33_mul_threshold = WJR_TOOM33_MUL_THRESHOLD; template void mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m); -enum class __mul_s_mode : unsigned int { +enum class __rec_mul_mode : unsigned int { toom22, toom33, - toom24, - // ... }; -template <__mul_s_mode mode, typename T> -WJR_INTRINSIC_INLINE void __mul_s_impl(T *dst, const T *src0, size_t n, const T *src1, - size_t m, T *stk); +template <__rec_mul_mode mode, typename T> +WJR_INTRINSIC_INLINE void __rec_mul_s(T *dst, const T *src0, size_t n, const T *src1, + size_t m, T *stk); + +template <__rec_mul_mode mode, typename T> +WJR_INTRINSIC_INLINE void __rec_mul_n(T *dst, const T *src0, size_t n, const T *src1, + size_t m, T *stk); template void basecase_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m); @@ -287,6 +289,12 @@ void basecase_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m); template void toom22_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *stk); +template +void toom32_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *stk); + +template +void toom42_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *stk); + template void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *stk); @@ -295,63 +303,262 @@ void mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m) { WJR_ASSERT(WJR_IS_SAME_OR_INCR_P(dst, n + m, src0, n)); WJR_ASSERT(WJR_IS_SAME_OR_INCR_P(dst, n + m, src1, m)); - if (n < m) { - std::swap(n, m); - std::swap(src0, src1); - } + WJR_ASSERT(n >= m); + WJR_ASSERT(m >= 1); - WJR_ASSERT(m != 0); + if (n < toom22_mul_threshold) { + return basecase_mul_s(dst, src0, n, src1, m); + } - do { - if (WJR_UNLIKELY(m > toom22_mul_threshold && 2 * m > n)) { - break; + if (m < toom22_mul_threshold) { + if (n >= m + toom22_mul_threshold / 4) { + return basecase_mul_s(dst, src0, n, src1, m); } + // n < m + toom22_mul_threshold / 4 <= toom22_mul_threshold - 1 + + // toom22_mul_threshold / 4. Only need to allocate temporary memory once. + constexpr auto maxn = (toom22_mul_threshold - 1) + (toom22_mul_threshold / 4); + constexpr auto maxl = (maxn + 1) / 2; + T stk[2 * maxl]; + return toom22_mul_s(dst, src0, n, src1, m, stk); + } - return basecase_mul_s(dst, src0, n, src1, m); - } while (0); + if (m < toom33_mul_threshold) { + unique_stack_ptr ptr(math_details::stack_alloc, sizeof(T) * (2 * n + 128)); + T *stk = static_cast(ptr.get()); + if (n >= 3 * m) { + unique_stack_ptr tmpp(math_details::stack_alloc, sizeof(T) * (4 * m)); + T *tmp = static_cast(tmpp.get()); - do { - if (WJR_UNLIKELY(m > toom33_mul_threshold && 3 * m > 2 * n)) { - break; + toom42_mul_s(dst, src0, 2 * m, src1, m, stk); + n -= 2 * m; + src1 += 2 * m; + dst += 2 * m; + + T cf = 0; + + while (n >= 3 * m) { + toom42_mul_s(tmp, src0, 2 * m, src1, m, stk); + n -= 2 * m; + src1 += 2 * m; + + cf = addc_n(dst, dst, tmp, m, cf); + std::copy(tmp + m, tmp + 3 * m, dst + m); + cf = addc_1(dst + m, dst + m, 2 * m, 0u, cf); + + dst += 2 * m; + } + + if (4 * n < 5 * m) { + toom22_mul_s(tmp, src0, n, src1, m, stk); + } else if (4 * n < 7 * m) { + toom32_mul_s(tmp, src0, n, src1, m, stk); + } else { + toom42_mul_s(tmp, src0, n, src1, m, stk); + } + + cf = addc_n(dst, dst, tmp, m, cf); + std::copy(tmp + m, tmp + m + n, dst + m); + cf = addc_1(dst + m, dst + m, n, 0u, cf); + WJR_ASSERT(cf == 0); + } else { + if (4 * n < 5 * m) { + toom22_mul_s(dst, src0, n, src1, m, stk); + } else if (4 * n < 7 * m) { + toom32_mul_s(dst, src0, n, src1, m, stk); + } else { + toom42_mul_s(dst, src0, n, src1, m, stk); + } } - unique_stack_ptr ptr(math_details::stack_alloc, sizeof(T) * (n * 2 + 1)); - T *stk = static_cast(ptr.get()); - return toom22_mul_s(dst, src0, n, src1, m, stk); - } while (0); + return; + } - unique_stack_ptr ptr(math_details::stack_alloc, sizeof(T) * (4 * n + 240)); + unique_stack_ptr ptr(math_details::stack_alloc, sizeof(T) * (3 * n + 256)); T *stk = static_cast(ptr.get()); - return toom33_mul_s(dst, src0, n, src1, m, stk); + if (n >= 3 * m) { + unique_stack_ptr tmpp(math_details::stack_alloc, sizeof(T) * (4 * m)); + T *tmp = static_cast(tmpp.get()); + + toom42_mul_s(dst, src0, 2 * m, src1, m, stk); + n -= 2 * m; + src1 += 2 * m; + dst += 2 * m; + + T cf = 0; + + while (n >= 3 * m) { + toom42_mul_s(tmp, src0, 2 * m, src1, m, stk); + n -= 2 * m; + src1 += 2 * m; + + cf = addc_n(dst, dst, tmp, m, cf); + std::copy(tmp + m, tmp + 3 * m, dst + m); + cf = addc_1(dst + m, dst + m, 2 * m, 0u, cf); + + dst += 2 * m; + } + + if (4 * n < 5 * m) { + toom33_mul_s(tmp, src0, n, src1, m, stk); + } else if (4 * n < 7 * m) { + toom32_mul_s(tmp, src0, n, src1, m, stk); + } else { + toom42_mul_s(tmp, src0, n, src1, m, stk); + } + + cf = addc_n(dst, dst, tmp, m, cf); + std::copy(tmp + m, tmp + m + n, dst + m); + cf = addc_1(dst + m, dst + m, n, 0u, cf); + WJR_ASSERT(cf == 0); + } else { + if (4 * n < 5 * m) { + toom33_mul_s(dst, src0, n, src1, m, stk); + } else if (4 * n < 7 * m) { + toom32_mul_s(dst, src0, n, src1, m, stk); + } else { + toom42_mul_s(dst, src0, n, src1, m, stk); + } + } + + return; } -template <__mul_s_mode mode, typename T> -WJR_INTRINSIC_INLINE void __mul_s_impl(T *dst, const T *src0, size_t n, const T *src1, - size_t m, T *stk) { +template <__rec_mul_mode mode, typename T> +WJR_INTRINSIC_INLINE void __rec_mul_s(T *dst, const T *src0, size_t n, const T *src1, + size_t m, T *stk) { WJR_ASSERT(WJR_IS_SAME_OR_INCR_P(dst, n + m, src0, n)); WJR_ASSERT(WJR_IS_SAME_OR_INCR_P(dst, n + m, src1, m)); + WJR_ASSERT(n >= m); - WJR_ASSERT(m != 0); + WJR_ASSERT(m >= 1); - do { - if (WJR_UNLIKELY(m > toom22_mul_threshold && 2 * m > n)) { - break; + if (n < toom22_mul_threshold) { + return basecase_mul_s(dst, src0, n, src1, m); + } + + if (m < toom22_mul_threshold) { + if (n >= m + toom22_mul_threshold / 4) { + return basecase_mul_s(dst, src0, n, src1, m); } + return toom22_mul_s(dst, src0, n, src1, m, stk); + } - return basecase_mul_s(dst, src0, n, src1, m); - } while (0); + if (mode <= __rec_mul_mode::toom22 || m < toom33_mul_threshold) { + if (n >= 3 * m) { + T *tmp = stk; + stk += 4 * m; - do { - if constexpr (mode >= __mul_s_mode::toom33) { - if (WJR_UNLIKELY(m > toom33_mul_threshold && 3 * m > 2 * n)) { - break; + toom42_mul_s(dst, src0, 2 * m, src1, m, stk); + n -= 2 * m; + src1 += 2 * m; + dst += 2 * m; + + T cf = 0; + + while (n >= 3 * m) { + toom42_mul_s(tmp, src0, 2 * m, src1, m, stk); + n -= 2 * m; + src1 += 2 * m; + + cf = addc_n(dst, dst, tmp, m, cf); + std::copy(tmp + m, tmp + 3 * m, dst + m); + cf = addc_1(dst + m, dst + m, 2 * m, 0u, cf); + + dst += 2 * m; + } + + if (4 * n < 5 * m) { + toom22_mul_s(tmp, src0, n, src1, m, stk); + } else if (4 * n < 7 * m) { + toom32_mul_s(tmp, src0, n, src1, m, stk); + } else { + toom42_mul_s(tmp, src0, n, src1, m, stk); + } + + cf = addc_n(dst, dst, tmp, m, cf); + std::copy(tmp + m, tmp + m + n, dst + m); + cf = addc_1(dst + m, dst + m, n, 0u, cf); + WJR_ASSERT(cf == 0); + } else { + if (4 * n < 5 * m) { + toom22_mul_s(dst, src0, n, src1, m, stk); + } else if (4 * n < 7 * m) { + toom32_mul_s(dst, src0, n, src1, m, stk); + } else { + toom42_mul_s(dst, src0, n, src1, m, stk); } } - return toom22_mul_s(dst, src0, n, src1, m, stk); - } while (0); + return; + } + + if (n >= 3 * m) { + T *tmp = stk; + stk += 4 * m; + + toom42_mul_s(dst, src0, 2 * m, src1, m, stk); + n -= 2 * m; + src1 += 2 * m; + dst += 2 * m; - return toom33_mul_s(dst, src0, n, src1, m, stk); + T cf = 0; + + while (n >= 3 * m) { + toom42_mul_s(tmp, src0, 2 * m, src1, m, stk); + n -= 2 * m; + src1 += 2 * m; + + cf = addc_n(dst, dst, tmp, m, cf); + std::copy(tmp + m, tmp + 3 * m, dst + m); + cf = addc_1(dst + m, dst + m, 2 * m, 0u, cf); + + dst += 2 * m; + } + + if (4 * n < 5 * m) { + toom33_mul_s(tmp, src0, n, src1, m, stk); + } else if (4 * n < 7 * m) { + toom32_mul_s(tmp, src0, n, src1, m, stk); + } else { + toom42_mul_s(tmp, src0, n, src1, m, stk); + } + + cf = addc_n(dst, dst, tmp, m, cf); + std::copy(tmp + m, tmp + m + n, dst + m); + cf = addc_1(dst + m, dst + m, n, 0u, cf); + WJR_ASSERT(cf == 0); + } else { + if (4 * n < 5 * m) { + toom33_mul_s(dst, src0, n, src1, m, stk); + } else if (4 * n < 7 * m) { + toom32_mul_s(dst, src0, n, src1, m, stk); + } else { + toom42_mul_s(dst, src0, n, src1, m, stk); + } + } + + return; +} + +template <__rec_mul_mode mode, typename T> +WJR_INTRINSIC_INLINE void __rec_mul_n(T *dst, const T *src0, const T *src1, size_t n, + T *stk) { + WJR_ASSERT(WJR_IS_SAME_OR_INCR_P(dst, n * 2, src0, n)); + WJR_ASSERT(WJR_IS_SAME_OR_INCR_P(dst, n * 2, src1, n)); + + WJR_ASSERT(n >= 1); + + if (n < toom22_mul_threshold) { + return basecase_mul_s(dst, src0, n, src1, n); + } + + if (n < toom33_mul_threshold) { + toom22_mul_s(dst, src0, n, src1, n, stk); + return; + } + + toom33_mul_s(dst, src0, n, src1, n, stk); + return; } template @@ -369,16 +576,14 @@ void toom22_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s // (u0 - u1) * (v0 - v1) - WJR_ASSUME(n >= m); - - const size_t rn = n >> 1; - const size_t l = n - rn; + WJR_ASSERT(n >= m); + WJR_ASSERT(2 * m > n); - WJR_ASSUME(l >= rn); - WJR_ASSUME(l - rn <= 1); + WJR_ASSUME(n >= m); + const size_t l = (n + 1) / 2; + const size_t rn = n - l; const size_t rm = m - l; - WJR_ASSUME(l >= rm); auto u0 = src0; auto u1 = src0 + l; @@ -415,11 +620,11 @@ void toom22_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s break; } - __mul_s_impl<__mul_s_mode::toom22>(wp, p0, l, p1, l, stk); + __rec_mul_n<__rec_mul_mode::toom22>(wp, p0, p1, l, stk); } while (0); - __mul_s_impl<__mul_s_mode::toom22>(p0, u0, l, v0, l, stk); - __mul_s_impl<__mul_s_mode::toom22>(p2, u1, rn, v1, rm, stk); + __rec_mul_n<__rec_mul_mode::toom22>(p0, u0, v0, l, stk); + __rec_mul_s<__rec_mul_mode::toom22>(p2, u1, rn, v1, rm, stk); T cf = 0, cf2 = 0; @@ -438,12 +643,355 @@ void toom22_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s WJR_ASSERT(cf == 0); } +template +void toom32_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *stk) { + WJR_ASSERT(3 * m <= 2 * n); + WJR_ASSERT(3 * m > n); + + const size_t l = (n + 2) / 3; + const size_t rn = n - l * 2; + const size_t rm = m - l; + + const auto u0p = src0; + const auto u1p = src0 + l; + const auto u2p = src0 + l * 2; + + const auto v0p = src1; + const auto v1p = src1 + l; + + auto w0p = dst; + auto w1p = stk; + auto w2p = stk + (2 * l + 1); + auto w3p = dst + l * 3; + + stk += 2 * (2 * l + 1); + + T cf0 = 0, cf1 = 0, cf2 = 0, cf3 = 0; + bool neg0 = 0, neg3 = 0; + + // W0 = U0 + U2 : (non-negative) + cf0 = addc_s(w0p, u0p, l, u2p, rn, 0u); + // W3 = V0 + V1 : (non-negative) v(1) + cf3 = addc_s(w3p, v0p, l, v1p, rm, 0u); + // W2 = W0 + U1 : (non-negative) u(1) + cf2 = cf0 + addc_n(w2p, w0p, u1p, l, 0u); + WJR_ASSERT(cf2 <= 2); + + // W1 = W2 * W3 : (non-negative) r(1) = r1 + __rec_mul_n<__rec_mul_mode::toom22>(w1p, w2p, w3p, l, stk); + cf1 = cf2 * cf3; + if (WJR_UNLIKELY(cf2 != 0)) { + if (cf2 == 1) { + cf1 += addc_n(w1p + l, w1p + l, w3p, l, 0u); + } else { + cf1 += addmul_1(w1p + l, w3p, l, cf2); + } + } + if (WJR_UNLIKELY(cf3 != 0)) { + cf1 += addc_n(w1p + l, w1p + l, w2p, l, 0u); + } + w1p[l * 2] = cf1; + + // W0 = W0 - U1 : u(-1) + if (cf0) { + cf0 -= subc_n(w0p, w0p, u1p, l, 0u); + } else { + ptrdiff_t p = abs_subc_n(w0p, w0p, u1p, l); + neg0 = p < 0; + } + WJR_ASSERT(cf0 <= 1); + + // W3 = V0 - V1 : v(-1) + { + ptrdiff_t p = abs_subc_s(w3p, v0p, l, v1p, rm); + neg3 = p < 0; + } + + // W2 = W0 * W3 : r(-1) = r2 + neg0 ^= neg3; + __rec_mul_n<__rec_mul_mode::toom22>(w2p, w0p, w3p, l, stk); + cf2 = 0; + if (WJR_UNLIKELY(cf0 != 0)) { + cf2 += addc_n(w2p + l, w2p + l, w3p, l, 0u); + } + w2p[l * 2] = cf2; + + // W0 = U0 * V0 : (non-negative) r(0) = r0 + __rec_mul_n<__rec_mul_mode::toom22>(w0p, u0p, v0p, l, stk); + + // W3 = U2 * V1 : (non-negative) r(inf) = r3 + __rec_mul_s<__rec_mul_mode::toom22>(w3p, u2p, rn, v1p, rm, stk); + + // W1 = (W1 - W2) >> 1 : (non-negative) (r(1) - r(-1)) / 2 + { + if (!neg0) { + cf1 = subc_n(w1p, w1p, w2p, l * 2 + 1, 0u); + } else { + cf1 = addc_n(w1p, w1p, w2p, l * 2 + 1, 0u); + } + WJR_ASSERT(cf1 == 0); + + rshift_n(w1p, w1p, l * 2 + 1, 1u); + } + + // W2 = (W1 + W2) - W0 : (non-negative) r2 + { + if (!neg0) { + cf2 = addc_n(w2p, w1p, w2p, l * 2 + 1, 0u); + } else { + cf2 = subc_n(w2p, w1p, w2p, l * 2 + 1, 0u); + } + + WJR_ASSERT(cf2 == 0); + cf2 -= subc_s(w2p, w2p, l * 2 + 1, w0p, l * 2, 0u); + WJR_ASSERT(cf2 == 0); + } + + // W1 = W1 - W3 : (non-negative) r1 + cf1 = subc_s(w1p, w1p, l * 2 + 1, w3p, rn + rm, 0u); + WJR_ASSERT(cf1 == 0); + cf1 = w1p[l * 2]; + + // W = W3*x^3+W2*x^2+W1*x+W0 + cf0 = addc_n(w0p + l, w0p + l, w1p, l, 0u); + cf0 = addc_n(dst + l * 2, w1p + l, w2p, l, cf0); + cf0 = addc_s(w3p, w3p, rn + rm, w2p + l, rn + 1, cf0); + WJR_ASSERT(cf0 == 0); + cf0 = addc_1(w3p, w3p, rn + rm, cf1, 0u); + WJR_ASSERT(cf0 == 0); + (void)(cf0); +} + template , int> = 0> WJR_CONSTEXPR_E void divexact_by3(T *dst, const T *src, size_t n); -// TODO : toom33 and toom24 +template +void toom_interpolation_5p_s(T *dst, T *w1p, size_t l, size_t rn, size_t rm, + bool neg2) { + auto w0p = dst; + auto w2p = w1p + (2 * l + 1); + auto w3p = w1p + (2 * l + 1) * 2; + auto w4p = dst + l * 4; + + T cf0 = 0, cf1 = 0, cf2 = 0, cf3 = 0; + + // W3 = (W3 - W1) / 3 : (non-negative) (r(2) - r(-1)) / 3 + { + if (!neg2) { + cf3 = subc_n(w3p, w3p, w1p, l * 2 + 1, 0u); + } else { + cf3 = addc_n(w3p, w3p, w1p, l * 2 + 1, 0u); + } + + WJR_ASSERT(cf3 == 0); + divexact_by3(w3p, w3p, l * 2 + 1); + } + + // W1 = (W2 - W1) >> 1 : (non-negative) (r(1) - r(-1)) / 2 + { + if (!neg2) { + cf1 = subc_n(w1p, w2p, w1p, l * 2 + 1, 0u); + } else { + cf1 = addc_n(w1p, w2p, w1p, l * 2 + 1, 0u); + } + + WJR_ASSERT(cf1 == 0); + rshift_n(w1p, w1p, l * 2 + 1, 1u); + } + + // W2 = W2 - W0 : (non-negative) r(1) - r(0) + cf2 = subc_s(w2p, w2p, l * 2 + 1, w0p, l * 2, 0u); + WJR_ASSERT(cf2 == 0); + + // W3 = ((W3 - W2) >> 1) - (W4 << 1) : (non-negative) r3 + { + cf3 = subc_n(w3p, w3p, w2p, l * 2 + 1, 0u); + WJR_ASSERT(cf3 == 0); + + if (rn != l) { + WJR_ASSERT(w3p[l * 2] == 0); + } + + (void)rshift_n(w3p, w3p, l + rn + 1, 1u); + + T cf5 = lshift_n(dst + l * 2, w4p, rn + rm, 1u); + + cf3 = subc_n(w3p, w3p, dst + l * 2, rn + rm, 0u); + cf3 = subc_1(w3p + rn + rm, w3p + rn + rm, (l + rn + 1) - (rn + rm), cf5, cf3); + WJR_ASSERT(cf3 == 0); + } + + // W2 = W2 - W1 : (non-negative) r(1) / 2 - r(0) + r(-1) / 2 + cf2 -= subc_n(w2p, w2p, w1p, l * 2 + 1, 0u); + WJR_ASSERT(cf2 == 0); + + // W3 = W4 * x + W3 : r4 * x + r3 + cf3 = addc_s(w4p, w4p, rn + rm, w3p + l, rn + 1, 0u); + + // W1 = W2 * x + W1 : + cf2 = addc_s(w2p, w2p, l * 2, w1p + l, l + 1, 0u); + + // W1 = W1 - W3 : // r2 * x + r1 + cf1 = subc_n(w1p, w1p, w3p, l, 0u); + cf1 = cf3 + subc_n(dst + l * 2, w2p, w4p, rn + rm, cf1); + cf2 += w2p[l * 2]; + if (l * 2 != rn + rm) { + cf1 = cf2 - + subc_1(dst + l * 2 + rn + rm, w2p + rn + rm, (l * 2) - (rn + rm), cf1, 0u); + } else { + cf1 = cf2 - cf1; + } + + // W = W3*x^3+ W1*x + W0 + cf0 = addc_n(w0p + l, w0p + l, w1p, l, 0u); + cf0 = addc_1(dst + l * 2, dst + l * 2, l, 0u, cf0); + cf0 = addc_n(dst + l * 3, dst + l * 3, w3p, l, cf0); + cf0 = addc_1(dst + l * 4, dst + l * 4, rn + rm, cf1, cf0); + WJR_ASSERT(cf0 == 0); + (void)(cf0); +} + +template +void toom42_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *stk) { + WJR_ASSERT(2 * m <= n); + WJR_ASSERT(4 * m > n); + + WJR_ASSUME(n >= m); + + const size_t l = (n + 3) / 4; + const size_t rn = n - l * 3; + const size_t rm = m - l; + WJR_ASSERT(rm <= l); + + const auto u0p = src0; + const auto u1p = src0 + l; + const auto u2p = src0 + l * 2; + const auto u3p = src0 + l * 3; + + const auto v0p = src1; + const auto v1p = src1 + l; + + auto w0p = dst; + auto w1p = stk; + auto w2p = stk + (2 * l + 1); + auto w3p = stk + (2 * l + 1) * 2; + auto w4p = dst + l * 4; + + stk += 3 * (2 * l + 1); + + T cf0 = 0, cf1 = 0, cf2 = 0, cf3 = 0, cf4 = 0; + bool neg2 = 0, neg3 = 0; + + // W0 = U0 + U2 : (non-negative) + cf0 = addc_n(w0p, u0p, u2p, l, 0u); + // W4 = V0 + V1 : (non-negative) v(1) + cf4 = addc_s(w4p, v0p, l, v1p, rm, 0u); + // W1 = U1 + U3 : (non-negative) + cf1 = addc_s(w1p, u1p, l, u3p, rn, 0u); + + // W3 = W0 - W1 : u(-1) + if (cf0 != cf1) { + cf3 = 1; + if (cf0) { + WJR_ASSERT(cf0 == 1); + cf3 -= subc_n(w3p, w0p, w1p, l, 0u); + } else { + WJR_ASSERT(cf1 == 1); + neg3 = 1; + cf3 -= -subc_n(w3p, w1p, w0p, l, 0u); + } + } else { + ptrdiff_t p = abs_subc_n(w3p, w0p, w1p, l); + neg3 = p < 0; + } + WJR_ASSERT(cf3 <= 1); + + // W2 = V0 - V1 : v(-1) + { + ptrdiff_t p = abs_subc_s(w2p, v0p, l, v1p, rm); + neg2 = p < 0; + } + + // W0 = W0 + W1 : (non-negative) u(1) + cf0 += cf1 + addc_n(w0p, w0p, w1p, l, 0u); + WJR_ASSERT(cf0 <= 3); + + // W1 = W3 * W2 : r(-1) + neg2 ^= neg3; + __rec_mul_n<__rec_mul_mode::toom22>(w1p, w3p, w2p, l, stk); + cf1 = 0; + if (WJR_UNLIKELY(cf3 != 0)) { + cf1 += addc_n(w1p + l, w1p + l, w2p, l, 0u); + } + w1p[l * 2] = cf1; + + // W2 = W0 * W4 : (non-negative) r(1) + __rec_mul_n<__rec_mul_mode::toom22>(w2p, w0p, w4p, l, stk); + cf2 = cf0 * cf4; + if (WJR_UNLIKELY(cf0 != 0)) { + if (cf0 == 1) { + cf2 += addc_n(w2p + l, w2p + l, w4p, l, 0u); + } else { + cf2 += addmul_1(w2p + l, w4p, l, cf0); + } + } + if (WJR_UNLIKELY(cf4 != 0)) { + cf2 += addc_n(w2p + l, w2p + l, w0p, l, 0u); + } + w2p[l * 2] = cf2; + + // W0 = U0 +(U1 +(U2 +U3<<1)<<1)<<1 : (non-negative) u(2) + { + cf0 = lshift_n(w0p, u3p, rn, 1u); + cf0 += addc_n(w0p, w0p, u2p, rn, 0u); + if (l != rn) { + cf0 = addc_1(w0p + rn, w0p + rn, l - rn, cf0, 0u); + } + + cf0 += cf0 + lshift_n(w0p, w0p, l, 1u); + cf0 += addc_n(w0p, w0p, u1p, l, 0u); + cf0 += cf0 + lshift_n(w0p, w0p, n, 1u); + cf0 += addc_n(w0p, w0p, u0p, l, 0u); + WJR_ASSERT(cf0 <= 14); + } + + // W4 = W4 + V1 : (non-negative) v(2) + cf4 += addc_s(w4p, w4p, l, v1p, rm, 0u); + WJR_ASSERT(cf4 <= 2); + + // W3 = W0 * W4 : (non-negative) r(2) + __rec_mul_n<__rec_mul_mode::toom22>(w3p, w0p, w4p, l, stk); + cf3 = cf0 * cf4; + if (WJR_UNLIKELY(cf0 != 0)) { + if (cf0 == 1) { + cf3 += addc_n(w3p + l, w3p + l, w4p, l, 0u); + } else { + cf3 += addmul_1(w3p + l, w4p, l, cf0); + } + } + if (WJR_UNLIKELY(cf4 != 0)) { + if (cf4 == 1) { + cf3 += addc_n(w3p + l, w3p + l, w0p, l, 0u); + } else { + cf3 += addmul_1(w3p + l, w0p, l, 2u); + } + } + w3p[l * 2] = cf3; + + // W0 = U0 * V0 : (non-negative) r(0) = r0 + __rec_mul_n<__rec_mul_mode::toom22>(w0p, u0p, v0p, l, stk); + + // W4 = U3 * V1 : (non-negative) r(inf) = r4 + __rec_mul_s<__rec_mul_mode::toom22>(w4p, u3p, rn, v1p, rm, stk); + + return toom_interpolation_5p_s(dst, w1p, l, rn, rm, neg2); +} + template void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *stk) { + WJR_ASSERT(n >= m); + WJR_ASSERT(3 * m > 2 * n); + WJR_ASSUME(n >= m); const size_t l = (n + 2) / 3; @@ -463,11 +1011,10 @@ void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s auto w2p = stk + (2 * l + 1); auto w3p = stk + (2 * l + 1) * 2; auto w4p = dst + l * 4; - auto w5p = stk + (2 * l + 1) * 3; - stk += 4 * (2 * l + 1); + stk += 3 * (2 * l + 1); - T cf0 = 0, cf1 = 0, cf2 = 0, cf3 = 0, cf4 = 0, cf5 = 0; + T cf0 = 0, cf1 = 0, cf2 = 0, cf3 = 0, cf4 = 0; bool neg2 = 0, neg3 = 0; // W0 = U0 + U2 : (non-negative) @@ -503,7 +1050,7 @@ void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s // W1 = W3 * W2 : r(-1) neg2 ^= neg3; - __mul_s_impl<__mul_s_mode::toom33>(w1p, w3p, l, w2p, l, stk); + __rec_mul_n<__rec_mul_mode::toom33>(w1p, w3p, w2p, l, stk); cf1 = cf2 && cf3; if (WJR_UNLIKELY(cf2 != 0)) { cf1 += addc_n(w1p + l, w1p + l, w3p, l, 0u); @@ -514,7 +1061,7 @@ void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s w1p[l * 2] = cf1; // W2 = W0 * W4 : (non-negative) r(1) - __mul_s_impl<__mul_s_mode::toom33>(w2p, w0p, l, w4p, l, stk); + __rec_mul_n<__rec_mul_mode::toom33>(w2p, w0p, w4p, l, stk); cf2 = cf0 * cf4; if (WJR_UNLIKELY(cf0 != 0)) { if (cf0 == 1) { @@ -547,7 +1094,7 @@ void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s WJR_ASSERT(cf4 <= 6); // W3 = W0 * W4 : (non-negative) r(2) - __mul_s_impl<__mul_s_mode::toom33>(w3p, w0p, l, w4p, l, stk); + __rec_mul_n<__rec_mul_mode::toom33>(w3p, w0p, w4p, l, stk); cf3 = cf0 * cf4; if (WJR_UNLIKELY(cf0 != 0)) { if (cf0 == 1) { @@ -566,80 +1113,12 @@ void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s w3p[l * 2] = cf3; // W0 = U0 * V0 : (non-negative) r(0) = r0 - __mul_s_impl<__mul_s_mode::toom33>(w0p, u0p, l, v0p, l, stk); + __rec_mul_n<__rec_mul_mode::toom33>(w0p, u0p, v0p, l, stk); // W4 = U2 * V2 : (non-negative) r(inf) = r4 - __mul_s_impl<__mul_s_mode::toom33>(w4p, u2p, rn, v2p, rm, stk); - - // W3 = (W3 - W1) / 3 : (non-negative) (r(2) - r(-1)) / 3 - { - if (!neg2) { - cf3 = subc_n(w3p, w3p, w1p, l * 2 + 1, 0u); - } else { - cf3 = addc_n(w3p, w3p, w1p, l * 2 + 1, 0u); - } - - WJR_ASSERT(cf3 == 0); - divexact_by3(w3p, w3p, l * 2 + 1); - } + __rec_mul_s<__rec_mul_mode::toom33>(w4p, u2p, rn, v2p, rm, stk); - // W1 = (W2 - W1) >> 1 : (non-negative) (r(1) - r(-1)) / 2 - { - if (!neg2) { - cf1 = subc_n(w1p, w2p, w1p, l * 2 + 1, 0u); - } else { - cf1 = addc_n(w1p, w2p, w1p, l * 2 + 1, 0u); - } - - WJR_ASSERT(cf1 == 0); - rshift_n(w1p, w1p, l * 2 + 1, 1u); - } - - // W2 = W2 - W0 : (non-negative) r(1) - r(0) - cf2 = subc_s(w2p, w2p, l * 2 + 1, w0p, l * 2, 0u); - WJR_ASSERT(cf2 == 0); - - // W3 = ((W3 - W2) >> 1) - (W4 << 1) : (non-negative) r3 - { - cf3 = subc_n(w3p, w3p, w2p, l * 2 + 1, 0u); - WJR_ASSERT(cf3 == 0); - - (void)rshift_n(w3p, w3p, l + rn + 1, 1u); - - cf5 = lshift_n(w5p, w4p, rn + rm, 1u); - - cf3 = subc_n(w3p, w3p, w5p, rn + rm, 0u); - cf3 = - subc_1(w3p + rn + rm, w3p + rn + rm, (l + rn + 1) - (rn + rm), cf3 + cf5, 0u); - WJR_ASSERT(cf3 == 0); - } - - // W2 = W2 - W1 : (non-negative) r(1) / 2 - r(0) + r(-1) / 2 - cf2 -= subc_n(w2p, w2p, w1p, l * 2 + 1, 0u); - WJR_ASSERT(cf2 == 0); - - // W3 = W4 * x + W3 : r4 * x + r3 - cf3 = addc_s(w4p, w4p, rn + rm, w3p + l, rn + 1, 0u); - - // W1 = W2 * x + W1 : - cf2 = addc_s(w2p, w2p, l * 2, w1p + l, l + 1, 0u); - - // W1 = W1 - W3 : // r2 * x + r1 - cf1 = subc_n(w1p, w1p, w3p, l, 0u); - cf1 = cf3 + subc_n(dst + l * 2, w2p, w4p, rn + rm, cf1); - cf2 += w2p[l * 2]; - if (l * 2 != rn + rm) { - cf1 = cf2 - - subc_1(dst + l * 2 + rn + rm, w2p + rn + rm, (l * 2) - (rn + rm), cf1, 0u); - } else { - cf1 = cf2 - cf1; - } - - // W = W3*x^3+ W1*x + W0 - cf0 = addc_n(w0p + l, w0p + l, w1p, l, 0u); - cf0 = addc_1(dst + l * 2, dst + l * 2, l, 0u, cf0); - cf0 = addc_n(dst + l * 3, dst + l * 3, w3p, l, cf0); - addc_1(dst + l * 4, dst + l * 4, rn + rm, cf0 + cf1, 0u); + return toom_interpolation_5p_s(dst, w1p, l, rn, rm, neg2); } } // namespace wjr diff --git a/include/wjr/preprocessor/preview.hpp b/include/wjr/preprocessor/preview.hpp index c4a6223c..262315f8 100644 --- a/include/wjr/preprocessor/preview.hpp +++ b/include/wjr/preprocessor/preview.hpp @@ -32,12 +32,10 @@ // use abort instead of assert when NDEBUG is defined #if defined(NDEBUG) #define WJR_ASSERT_NOMESSAGE_I(expr) \ - if (!WJR_UNLIKELY(expr)) { \ - std::abort(); \ + if (WJR_UNLIKELY(!(expr))) { \ WJR_UNREACHABLE(); \ } #define WJR_ASSERT_MESSAGE_I(expr) \ - std::abort(); \ WJR_UNREACHABLE(); #else #define WJR_ASSERT_NOMESSAGE_I(expr) assert(expr) @@ -114,8 +112,10 @@ #define WJR_IS_SEPARATE_P(p, pn, q, qn) (!WJR_IS_OVERLAP_P(p, pn, q, qn)) #define WJR_IS_SAME_OR_SEPARATE_P(p, pn, q, qn) \ (p == q || WJR_IS_SEPARATE_P(p, pn, q, qn)) -#define WJR_IS_SAME_OR_INCR_P(p, pn, q, qn) (((p) <= (q)) || WJR_IS_SEPARATE_P(p, pn, q, qn)) -#define WJR_IS_SAME_OR_DECR_P(p, pn, q, qn) (((p) >= (q)) || WJR_IS_SEPARATE_P(p, pn, q, qn)) +#define WJR_IS_SAME_OR_INCR_P(p, pn, q, qn) \ + (((p) <= (q)) || WJR_IS_SEPARATE_P(p, pn, q, qn)) +#define WJR_IS_SAME_OR_DECR_P(p, pn, q, qn) \ + (((p) >= (q)) || WJR_IS_SEPARATE_P(p, pn, q, qn)) #define WJR_ASM_PIC_JMPL(LABEL, TABLE) ".long " #LABEL "-" #TABLE #define WJR_ASM_NOPIC_JMPL(LABEL) ".quad " #LABEL