Skip to content

Commit

Permalink
optimize toom33 and mul_s
Browse files Browse the repository at this point in the history
  • Loading branch information
wjr-z committed Jan 18, 2024
1 parent 61a9789 commit c4115a9
Showing 1 changed file with 57 additions and 64 deletions.
121 changes: 57 additions & 64 deletions include/wjr/math/mul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,20 +226,26 @@ WJR_INTRINSIC_CONSTEXPR_E T submul_1(T *dst, const T *src0, size_t n, T src1) {
#define WJR_TOOM22_MUL_THRESHOLD 26
#endif

#ifndef WJR_TOOM33_MUL_THRESHOLD
#define WJR_TOOM33_MUL_THRESHOLD 64
#endif

inline constexpr size_t toom22_mul_threshold = WJR_TOOM22_MUL_THRESHOLD;
inline constexpr size_t toom33_mul_threshold = WJR_TOOM33_MUL_THRESHOLD;

template <typename T>
void mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m);

enum class __mul_s_mode {
enum class __mul_s_mode : unsigned int {
toom22,
toom33,
toom24,
// ...
};

template <__mul_s_mode mode, typename T>
void __mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *stk);
WJR_INTRINSIC_INLINE void __mul_s_impl(T *dst, const T *src0, size_t n, const T *src1,
size_t m, T *stk);

template <typename T>
void basecase_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m);
Expand All @@ -263,52 +269,55 @@ void mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m) {
WJR_ASSERT(m != 0);

do {
if (m <= toom22_mul_threshold) {
if (WJR_UNLIKELY(m > toom22_mul_threshold && 2 * m > n)) {
break;
}

if (m <= toom22_mul_threshold * 4) {
if (5 * m <= 4 * n) {
break;
}
}
return basecase_mul_s(dst, src0, n, src1, m);
} while (0);

if (5 * m <= 3 * n) {
do {
if (WJR_UNLIKELY(m > toom33_mul_threshold && 3 * m > 2 * n)) {
break;
}

unique_stack_ptr stk(math_details::stack_alloc, sizeof(T) * n * 2);
return toom22_mul_s(dst, src0, n, src1, m, static_cast<T *>(stk.get()));
unique_stack_ptr ptr(math_details::stack_alloc, sizeof(T) * (n * 2 + 1));
T *stk = static_cast<T *>(ptr.get());
return toom22_mul_s(dst, src0, n, src1, m, stk);
} while (0);

return basecase_mul_s(dst, src0, n, src1, m);
unique_stack_ptr ptr(math_details::stack_alloc, sizeof(T) * (4 * n + 24));
T *stk = static_cast<T *>(ptr.get());
return toom33_mul_s(dst, src0, n, src1, m, stk);
}

template <__mul_s_mode mode, typename T>
WJR_INTRINSIC_INLINE void __mul_s(T *dst, const T *src0, size_t n, const T *src1,
size_t m, T *stk) {
WJR_INTRINSIC_INLINE void __mul_s_impl(T *dst, const T *src0, size_t n, const T *src1,
size_t m, T *stk) {
WJR_ASSERT(WJR_IS_SAME_OR_INCR_P(dst, n + m, src0, n));
WJR_ASSERT(WJR_IS_SAME_OR_INCR_P(dst, n + m, src1, m));
WJR_ASSERT(n >= m);
WJR_ASSERT(m != 0);

// TODO
do {
if (m <= toom22_mul_threshold || 2 * m <= n) {
if (WJR_UNLIKELY(m > toom22_mul_threshold && 2 * m > n)) {
break;
}

do {
if (m <= 100 || 3 * m <= 2 * n) {
return basecase_mul_s(dst, src0, n, src1, m);
} while (0);

do {
if constexpr (mode >= __mul_s_mode::toom33) {
if (WJR_UNLIKELY(m > toom33_mul_threshold && 3 * m > 2 * n)) {
break;
}

return toom33_mul_s(dst, src0, n, src1, m, stk);
} while (0);
}

return toom22_mul_s(dst, src0, n, src1, m, stk);
} while (0);

return basecase_mul_s(dst, src0, n, src1, m);
return toom33_mul_s(dst, src0, n, src1, m, stk);
}

template <typename T>
Expand Down Expand Up @@ -372,11 +381,11 @@ void toom22_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s
break;
}

__mul_s<__mul_s_mode::toom22>(wp, p0, l, p1, l, stk);
__mul_s_impl<__mul_s_mode::toom22>(wp, p0, l, p1, l, stk);
} while (0);

__mul_s<__mul_s_mode::toom22>(p0, u0, l, v0, l, stk);
__mul_s<__mul_s_mode::toom22>(p2, u1, rn, v1, rm, stk);
__mul_s_impl<__mul_s_mode::toom22>(p0, u0, l, v0, l, stk);
__mul_s_impl<__mul_s_mode::toom22>(p2, u1, rn, v1, rm, stk);

T cf = 0, cf2 = 0;

Expand All @@ -390,9 +399,8 @@ void toom22_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s
cf += addc_n(p1, p1, wp, l * 2, 0u);
}

cf2 = addc_1(p2, p2, rn + rm, cf2, 0u);
WJR_ASSERT(cf2 == 0);
cf = addc_1(p3, p3, p3n, cf, 0u);
cf2 = addc_1(p2, p2, l, cf2, 0u);
cf = addc_1(p3, p3, p3n, cf, cf2);
WJR_ASSERT(cf == 0);
}

Expand All @@ -416,17 +424,17 @@ void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s
const auto v1p = src1 + l;
const auto v2p = src1 + l * 2;

auto w0p = stk;
auto w1p = stk + (l + 1) * 2;
auto w2p = stk + (l + 1) * 4;
auto w3p = stk + (l + 1) * 6;
auto w4p = stk + (l + 1) * 8;
auto w5p = stk + (l + 1) * 10;
auto w0p = dst;
auto w1p = stk;
auto w2p = stk + (l + 1) * 2;
auto w3p = stk + (l + 1) * 4;
auto w4p = dst + l * 4;
auto w5p = stk + (l + 1) * 6;

stk += (l + 1) * 12;
stk += (l + 1) * 8;

T cf0 = 0, cf1 = 0, cf2 = 0, cf3 = 0, cf4 = 0, cf5 = 0;
bool neg0 = 0, neg1 = 0, neg2 = 0, neg3 = 0, neg4 = 0, neg5 = 0;
bool neg2 = 0, neg3 = 0;

// W0 = U0 + U2 : (non-negative)
cf0 = addc_s(w0p, u0p, l, u2p, rn, 0u);
Expand Down Expand Up @@ -466,55 +474,42 @@ void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s
// W1 = W3 * W2 : r(-1)
w3p[l] = cf3;
w2p[l] = cf2;
cf1 = 0;
neg1 = neg3 ^ neg2;
__mul_s<__mul_s_mode::toom33>(w1p, w3p, l + 1, w2p, l + 1, stk);
neg2 ^= neg3;
__mul_s_impl<__mul_s_mode::toom33>(w1p, w3p, l + 1, w2p, l + 1, stk);

// W2 = W0 * W4 : (non-negative) r(1)
w0p[l] = cf0;
w4p[l] = cf4;
cf2 = 0;
neg2 = 0;
__mul_s<__mul_s_mode::toom33>(w2p, w0p, l + 1, w4p, l + 1, stk);
__mul_s_impl<__mul_s_mode::toom33>(w2p, w0p, l + 1, w4p, l + 1, stk);

// W0 = (W0 + U2) << 1 - U0 : (non-negative) u(2)
cf0 += addc_s(w0p, w0p, l, u2p, rn, 0u);
WJR_ASSERT(cf0 <= 3);
cf0 += cf0 + lshift_n(w0p, w0p, l, 1u);
neg0 = 0;
cf0 -= subc_n(w0p, w0p, u0p, l, 0u);
WJR_ASSERT(cf0 <= 6);

// W4 = (W4 + V2) << 1 - V0 : (non-negative) v(2)
cf4 += addc_s(w4p, w4p, l, v2p, rm, 0u);
WJR_ASSERT(cf4 <= 3);
cf4 += cf4 + lshift_n(w4p, w4p, l, 1u);
neg4 = 0;
cf4 -= subc_n(w4p, w4p, v0p, l, 0u);
WJR_ASSERT(cf4 <= 6);

// W3 = W0 * W4 : (non-negative) r(2)
w0p[l] = cf0;
w4p[l] = cf4;
cf3 = 0;
neg3 = 0;
__mul_s<__mul_s_mode::toom33>(w3p, w0p, l + 1, w4p, l + 1, stk);

w0p = dst;
w4p = dst + l * 4;
__mul_s_impl<__mul_s_mode::toom33>(w3p, w0p, l + 1, w4p, l + 1, stk);

// W0 = U0 * V0 : (non-negative) r(0) = r0
cf0 = 0;
neg0 = 0;
__mul_s<__mul_s_mode::toom33>(w0p, u0p, l, v0p, l, stk);
__mul_s_impl<__mul_s_mode::toom33>(w0p, u0p, l, v0p, l, stk);

// W4 = U2 * V2 : (non-negative) r(inf) = r4
cf4 = 0;
neg4 = 0;
__mul_s<__mul_s_mode::toom33>(w4p, u2p, rn, v2p, rm, stk);
__mul_s_impl<__mul_s_mode::toom33>(w4p, u2p, rn, v2p, rm, stk);

// W3 = (W3 - W1) / 3 : (non-negative) (r(2) - r(-1)) / 3
{
if (neg3 == neg1) {
if (!neg2) {
cf3 = subc_n(w3p, w3p, w1p, l * 2 + 1, 0u);
} else {
cf3 = addc_n(w3p, w3p, w1p, l * 2 + 1, 0u);
Expand All @@ -526,10 +521,9 @@ void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s

// W1 = (W2 - W1) >> 1 : (non-negative) (r(1) - r(-1)) / 2
{
if (neg2 == neg1) {
if (!neg2) {
cf1 = subc_n(w1p, w2p, w1p, l * 2 + 1, 0u);
} else {
neg1 = 0;
cf1 = addc_n(w1p, w2p, w1p, l * 2 + 1, 0u);
}

Expand All @@ -538,7 +532,7 @@ void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s
}

// W2 = W2 - W0 : (non-negative) r(1) - r(0)
cf2 -= subc_s(w2p, w2p, l * 2 + 1, w0p, l * 2, 0u);
cf2 = subc_s(w2p, w2p, l * 2 + 1, w0p, l * 2, 0u);
WJR_ASSERT(cf2 == 0);

// W3 = ((W3 - W2) >> 1) - (W4 << 1) : (non-negative) r3
Expand All @@ -548,7 +542,6 @@ void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s

(void)rshift_n(w3p, w3p, l + rn + 1, 1u);

neg5 = 0;
cf5 = lshift_n(w5p, w4p, rn + rm, 1u);

cf3 = subc_n(w3p, w3p, w5p, rn + rm, 0u);
Expand All @@ -569,16 +562,16 @@ void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s

// W1 = W1 - W3 : // r2 * x + r1
cf1 = subc_n(w1p, w1p, w3p, l, 0u);
cf1 = cf3 + subc_n(w2p, w2p, w4p, rn + rm, cf1);
cf1 = cf3 + subc_n(dst + l * 2, w2p, w4p, rn + rm, cf1);
cf2 += w2p[l * 2];
if (l * 2 != rn + rm) {
cf1 = cf2 - subc_1(w2p + rn + rm, w2p + rn + rm, (l * 2) - (rn + rm), cf1, 0u);
cf1 = cf2 -
subc_1(dst + l * 2 + rn + rm, w2p + rn + rm, (l * 2) - (rn + rm), cf1, 0u);
} else {
cf1 = cf2 - cf1;
}

// W = W3*x^3+ W1*x + W0
std::copy(w2p, w2p + l * 2, dst + l * 2);
cf0 = addc_n(w0p + l, w0p + l, w1p, l, 0u);
cf0 = addc_1(dst + l * 2, dst + l * 2, l, 0u, cf0);
cf0 = addc_n(dst + l * 3, dst + l * 3, w3p, l, cf0);
Expand Down

0 comments on commit c4115a9

Please sign in to comment.