Skip to content

Commit

Permalink
Preliminary Implementation of Toom Cook-3
Browse files Browse the repository at this point in the history
  • Loading branch information
wjr-z committed Jan 18, 2024
1 parent 2d9078c commit 61a9789
Showing 1 changed file with 100 additions and 96 deletions.
196 changes: 100 additions & 96 deletions include/wjr/math/mul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define WJR_MATH_MUL_HPP__

#include <wjr/math/add.hpp>
#include <wjr/math/compare.hpp>
#include <wjr/math/shift.hpp>
#include <wjr/math/sub.hpp>

Expand Down Expand Up @@ -231,7 +232,7 @@ template <typename T>
void mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m);

enum class __mul_s_mode {
toom22 = 0,
toom22,
toom33,
toom24,
// ...
Expand All @@ -241,12 +242,14 @@ template <__mul_s_mode mode, typename T>
void __mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *stk);

template <typename T>
WJR_INLINE_CONSTEXPR void basecase_mul_s(T *dst, const T *src0, size_t n, const T *src1,
size_t m);
void basecase_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m);

template <typename T>
void toom22_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *stk);

template <typename T>
void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *stk);

template <typename T>
void mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m) {
WJR_ASSERT(WJR_IS_SAME_OR_INCR_P(dst, n + m, src0, n));
Expand Down Expand Up @@ -282,37 +285,34 @@ void mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m) {
}

template <__mul_s_mode mode, typename T>
void __mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *stk) {
WJR_INTRINSIC_INLINE void __mul_s(T *dst, const T *src0, size_t n, const T *src1,
size_t m, T *stk) {
WJR_ASSERT(WJR_IS_SAME_OR_INCR_P(dst, n + m, src0, n));
WJR_ASSERT(WJR_IS_SAME_OR_INCR_P(dst, n + m, src1, m));
WJR_ASSERT(n >= m);
WJR_ASSERT(m != 0);

// TODO
do {
if (m <= toom22_mul_threshold) {
if (m <= toom22_mul_threshold || 2 * m <= n) {
break;
}

if (m <= toom22_mul_threshold * 4) {
if (5 * m <= 4 * n) {
do {
if (m <= 100 || 3 * m <= 2 * n) {
break;
}
}

if (5 * m <= 3 * n) {
break;
}


return toom33_mul_s(dst, src0, n, src1, m, stk);
} while (0);
return toom22_mul_s(dst, src0, n, src1, m, stk);
} while (0);

return basecase_mul_s(dst, src0, n, src1, m);
}

template <typename T>
WJR_INLINE_CONSTEXPR void basecase_mul_s(T *dst, const T *src0, size_t n, const T *src1,
size_t m) {
void basecase_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m) {
dst[n] = mul_1(dst, src0, n, src1[0]);
for (size_t i = 1; i < m; ++i) {
dst[i + n] = addmul_1(dst + i, src0, n, src1[i]);
Expand Down Expand Up @@ -401,7 +401,7 @@ WJR_CONSTEXPR_E void divexact_by3(T *dst, const T *src, size_t n);

// TODO : toom33 and toom24
template <typename T>
void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m) {
void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *stk) {
WJR_ASSUME(n >= m);

const size_t l = (n + 2) / 3;
Expand All @@ -416,169 +416,173 @@ void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m) {
const auto v1p = src1 + l;
const auto v2p = src1 + l * 2;

unique_stack_ptr ptr(math_details::stack_alloc, sizeof(T) * ((l + 1) * 12));
T *stk = (T *)ptr.get();

auto w0p = stk;
auto w1p = stk + (l + 1) * 2;
auto w2p = stk + (l + 1) * 4;
auto w3p = stk + (l + 1) * 6;
auto w4p = stk + (l + 1) * 8;
auto w5p = stk + (l + 1) * 10;

T cf0, cf1, cf2, cf3, cf4, cf5;
bool neg0, neg1, neg2, neg3, neg4, neg5;
stk += (l + 1) * 12;

T cf0 = 0, cf1 = 0, cf2 = 0, cf3 = 0, cf4 = 0, cf5 = 0;
bool neg0 = 0, neg1 = 0, neg2 = 0, neg3 = 0, neg4 = 0, neg5 = 0;

// W0 = U0 + U2 (non-negative)
// W0 = U0 + U2 : (non-negative)
cf0 = addc_s(w0p, u0p, l, u2p, rn, 0u);
// W4 = V0 + V2 (non-negative)
// W4 = V0 + V2 : (non-negative)
cf4 = addc_s(w4p, v0p, l, v2p, rm, 0u);

// W3 = W0 - U1
// W3 = W0 - U1 : u(-1)
if (cf0) {
neg3 = 0;
cf3 = cf0 - subc_n(w3p, stk, u1p, l, 0u);
cf3 = cf0 - subc_n(w3p, w0p, u1p, l, 0u);
WJR_ASSERT(cf3 <= 1);
} else {
cf3 = 0;
ptrdiff_t p = abs_subc_n(w3p, stk, u1p, l);
ptrdiff_t p = abs_subc_n(w3p, w0p, u1p, l);
neg3 = p < 0;
}

// W2 = W4 - V1
// W2 = W4 - V1 : v(-1)
if (cf4) {
neg2 = 0;
cf2 = cf4 - subc_n(w2p, w4p, v1p, l, 0u);
WJR_ASSERT(cf2 <= 1);
} else {
cf2 = 0;
ptrdiff_t p = abs_subc_n(w2p, w4p, v1p, l);
neg2 = p < 0;
}

// W0 = W0 + U1 (non-negative)
// W0 = W0 + U1 : (non-negative) u(1)
cf0 += addc_n(w0p, w0p, u1p, l, 0u);
WJR_ASSERT(cf0 <= 2);

// W4 = W4 + V1 (non-negative)
// W4 = W4 + V1 : (non-negative) v(1)
cf4 += addc_n(w4p, w4p, v1p, l, 0u);
WJR_ASSERT(cf4 <= 2);

// W1 = W3 * W2
// W1 = W3 * W2 : r(-1)
w3p[l] = cf3;
w2p[l] = cf2;
cf1 = 0;
neg1 = neg3 ^ neg2;
mul_s(w1p, w3p, l + 1, w2p, l + 1);
__mul_s<__mul_s_mode::toom33>(w1p, w3p, l + 1, w2p, l + 1, stk);

// W2 = W0 * W4 (non-negative)
// W2 = W0 * W4 : (non-negative) r(1)
w0p[l] = cf0;
w4p[l] = cf4;
cf2 = 0;
neg2 = 0;
mul_s(w2p, w0p, l + 1, w4p, l + 1);
__mul_s<__mul_s_mode::toom33>(w2p, w0p, l + 1, w4p, l + 1, stk);

// W0 = (W0 + U2) << 1 - U0 (non-negative)
// W0 = (W0 + U2) << 1 - U0 : (non-negative) u(2)
cf0 += addc_s(w0p, w0p, l, u2p, rn, 0u);
cf0 += lshift_n(w0p, w0p, l, 1u);
WJR_ASSERT(cf0 <= 3);
cf0 += cf0 + lshift_n(w0p, w0p, l, 1u);
neg0 = 0;
cf0 -= subc_n(w0p, w0p, u0p, l, 0u);
WJR_ASSERT(cf0 <= 6);

// W4 = (W4 + V2) << 1 - V0 (non-negative)
// W4 = (W4 + V2) << 1 - V0 : (non-negative) v(2)
cf4 += addc_s(w4p, w4p, l, v2p, rm, 0u);
cf4 += lshift_n(w4p, w4p, l, 1u);
cf4 += cf4 + lshift_n(w4p, w4p, l, 1u);
neg4 = 0;
cf4 -= subc_n(w4p, w4p, v0p, l, 0u);
WJR_ASSERT(cf4 <= 6);

// W3 = W0 * W4 (non-negative)
// W3 = W0 * W4 : (non-negative) r(2)
w0p[l] = cf0;
w4p[l] = cf4;
cf3 = 0;
neg3 = 0;
mul_s(w3p, w0p, l + 1, w4p, l + 1);
__mul_s<__mul_s_mode::toom33>(w3p, w0p, l + 1, w4p, l + 1, stk);

// W0 = U0 * V0 (non-negative)
w0p = dst;
w4p = dst + l * 4;

// W0 = U0 * V0 : (non-negative) r(0) = r0
cf0 = 0;
neg0 = 0;
mul_s(w0p, u0p, l, v0p, l);
__mul_s<__mul_s_mode::toom33>(w0p, u0p, l, v0p, l, stk);

// W4 = U2 * V2 (non-negative)
// W4 = U2 * V2 : (non-negative) r(inf) = r4
cf4 = 0;
neg4 = 0;
mul_s(w4p, u2p, rn, v2p, rm);
__mul_s<__mul_s_mode::toom33>(w4p, u2p, rn, v2p, rm, stk);

// W3 = (W3 - W1) / 3
// W3 = (W3 - W1) / 3 : (non-negative) (r(2) - r(-1)) / 3
{
if (neg3 == neg1) {
ptrdiff_t p = abs_subc_n(w3p, w3p, w1p, (l + 1) * 2);
neg3 ^= p < 0;
cf3 = subc_n(w3p, w3p, w1p, l * 2 + 1, 0u);
} else {
cf3 = addc_n(w3p, w3p, w1p, (l + 1) * 2, 0u);
cf3 = addc_n(w3p, w3p, w1p, l * 2 + 1, 0u);
}

cf3 = 0;
divexact_by3(w3p, w3p, (l + 1) * 2);
WJR_ASSERT(cf3 == 0);
divexact_by3(w3p, w3p, l * 2 + 1);
}

// W1 = (W2 - W1) >> 1
// W1 = (W2 - W1) >> 1 : (non-negative) (r(1) - r(-1)) / 2
{
if (neg2 == neg1) {
ptrdiff_t p = abs_subc_n(w1p, w2p, w1p, (l + 1) * 2);
neg1 ^= p < 0;
cf1 = subc_n(w1p, w2p, w1p, l * 2 + 1, 0u);
} else {
neg1 ^= 1;
cf1 = addc_n(w1p, w2p, w1p, (l + 1) * 2, 0u);
neg1 = 0;
cf1 = addc_n(w1p, w2p, w1p, l * 2 + 1, 0u);
}

rshift_n(w1p, w1p, (l + 1) * 2, 1u);
if (cf1) {
w1p[(l + 1) * 2 - 1] |= (cf1) << 63;
cf1 >>= 1;
}
WJR_ASSERT(cf1 == 0);
rshift_n(w1p, w1p, l * 2 + 1, 1u);
}

// W2 = W2 - W0 (non-negative)
{ cf2 -= subc_n(w2p, w2p, w0p, (l + 1) * 2); }
// W2 = W2 - W0 : (non-negative) r(1) - r(0)
cf2 -= subc_s(w2p, w2p, l * 2 + 1, w0p, l * 2, 0u);
WJR_ASSERT(cf2 == 0);

// W3 = (W3 - W2) >> 1 - W4 << 1
// W3 = ((W3 - W2) >> 1) - (W4 << 1) : (non-negative) r3
{
if (neg3 == neg2) {
if (cf2) {
neg3 ^= 1;
cf3 = cf2 - subc_n(w3p, w2p, w3p, (l + 1) * 2, 0u);
} else {
cf3 = 0;
ptrdiff_t p = abs_subc_n(w3p, w3p, w2p, (l + 1) * 2);
neg3 ^= p < 0;
}
} else {
cf3 = cf2 + addc_n(w3p, w3p, w2p, (l + 1) * 2, 0u);
}

(void)rshift_n(w3p, w3p, (l + 1) * 2, 1u);
if (cf3) {
w3p[(l + 1) * 2 - 1] |= cf3 << 63;
cf3 >>= 1;
}
cf3 = subc_n(w3p, w3p, w2p, l * 2 + 1, 0u);
WJR_ASSERT(cf3 == 0);

// TODO :
(void)rshift_n(w3p, w3p, l + rn + 1, 1u);

// neg5 = 0;
// cf5 = lshift_n(w5p, w4p, (l + 1) * 2, 1u);
neg5 = 0;
cf5 = lshift_n(w5p, w4p, rn + rm, 1u);

if (neg3 == neg5) {
} else {
cf3 += cf5;
cf3 += addc_n(w3p, w3p, w5p, (l + 1) * 2, 0u);
}
cf3 = subc_n(w3p, w3p, w5p, rn + rm, 0u);
cf3 =
subc_1(w3p + rn + rm, w3p + rn + rm, (l + rn + 1) - (rn + rm), cf3 + cf5, 0u);
WJR_ASSERT(cf3 == 0);
}

// W2 = W2 - W1 (non-negative)
{
if (neg2 == neg1) {
cf2 -= subc_n(w2p, w2p, w1p, (l + 1) * 2, 0u);
} else {
cf2 += addc_n(w2p, w2p, w1p, (l + 1) * 2, 0u);
}
// W2 = W2 - W1 : (non-negative) r(1) / 2 - r(0) + r(-1) / 2
cf2 -= subc_n(w2p, w2p, w1p, l * 2 + 1, 0u);
WJR_ASSERT(cf2 == 0);

// W3 = W4 * x + W3 : r4 * x + r3
cf3 = addc_s(w4p, w4p, rn + rm, w3p + l, rn + 1, 0u);

// W1 = W2 * x + W1 :
cf2 = addc_s(w2p, w2p, l * 2, w1p + l, l + 1, 0u);

// W1 = W1 - W3 : // r2 * x + r1
cf1 = subc_n(w1p, w1p, w3p, l, 0u);
cf1 = cf3 + subc_n(w2p, w2p, w4p, rn + rm, cf1);
cf2 += w2p[l * 2];
if (l * 2 != rn + rm) {
cf1 = cf2 - subc_1(w2p + rn + rm, w2p + rn + rm, (l * 2) - (rn + rm), cf1, 0u);
} else {
cf1 = cf2 - cf1;
}

// W3 = W4 * x + W3
// W = W3*x^3+ W1*x + W0
std::copy(w2p, w2p + l * 2, dst + l * 2);
cf0 = addc_n(w0p + l, w0p + l, w1p, l, 0u);
cf0 = addc_1(dst + l * 2, dst + l * 2, l, 0u, cf0);
cf0 = addc_n(dst + l * 3, dst + l * 3, w3p, l, cf0);
addc_1(dst + l * 4, dst + l * 4, rn + rm, cf0 + cf1, 0u);
}

} // namespace wjr
Expand Down

0 comments on commit 61a9789

Please sign in to comment.