Skip to content

Commit

Permalink
upd
Browse files Browse the repository at this point in the history
  • Loading branch information
wjr-z committed Jan 25, 2024
1 parent 68fc3b1 commit 050547c
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 64 deletions.
4 changes: 3 additions & 1 deletion include/wjr/math/add.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ WJR_INTRINSIC_CONSTEXPR_E T addc(T a, T b, type_identity_t<U> c_in, U &c_out) {
return fallback_addc(a, b, c_in, c_out);
#else
constexpr auto is_constant_or_zero = [](const auto &x) -> int {
return WJR_BUILTIN_CONSTANT_P(x) ? 1 : 0;
return WJR_BUILTIN_CONSTANT_P(x == 0) && x == 0 ? 2
: WJR_BUILTIN_CONSTANT_P(x) ? 1
: 0;
};

// The compiler should be able to optimize the judgment condition of if when enabling
Expand Down
135 changes: 74 additions & 61 deletions include/wjr/math/mul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,14 +299,14 @@ template <typename T>
void toom32_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *stk);

// l = max(ceil(n / 4), ceil(m / 2))
// stk usage : 6 * l + 3
// recursive stk max usage : 8 * l + 67
// stk usage : 6 * l
// recursive stk max usage : 8 * l + 67 ?
template <typename T>
void toom42_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *stk);

// l = ceil(n / 3)
// stk usage : 6 * l + 3
// recursive stk max usage : 9 * l + 288
// stk usage : 6 * l
// recursive stk max usage : 9 * l + 288 ?
template <typename T>
void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *stk);

Expand Down Expand Up @@ -788,86 +788,103 @@ template <typename T, std::enable_if_t<std::is_same_v<T, uint64_t>, int> = 0>
WJR_CONSTEXPR_E void divexact_by3(T *dst, const T *src, size_t n);

template <typename T>
void toom_interpolation_5p_s(T *dst, T *w1p, size_t l, size_t rn, size_t rm, bool neg2) {
struct toom_interpolation_5p_struct {
bool neg1;
T cf1;
T cf2;
T cf3;
};

template <typename T>
void toom_interpolation_5p_s(T *dst, T *w1p, size_t l, size_t rn, size_t rm,
toom_interpolation_5p_struct<T> &&flag) {
// r(0) r(-1) r(1) r(2) r(inf)

const size_t maxr = std::max(rn, rm);

auto w0p = dst;
auto w2p = w1p + (2 * l + 1);
auto w3p = w1p + (2 * l + 1) * 2;
auto w2p = w1p + l * 2;
auto w3p = w1p + l * 4;
auto w4p = dst + l * 4;

T cf0 = 0, cf1 = 0, cf2 = 0, cf3 = 0;
auto [neg1, cf1, cf2, cf3] = flag;

T cf;

// W3 = (W3 - W1) / 3 : (non-negative) (r(2) - r(-1)) / 3
{
if (!neg2) {
cf3 = subc_n(w3p, w3p, w1p, l * 2 + 1, 0u);
if (!flag.neg1) {
cf3 -= cf1 + subc_n(w3p, w3p, w1p, l * 2, 0u);
} else {
cf3 = addc_n(w3p, w3p, w1p, l * 2 + 1, 0u);
cf3 += cf1 + addc_n(w3p, w3p, w1p, l * 2, 0u);
}

WJR_ASSERT(cf3 == 0);
divexact_by3(w3p, w3p, l * 2 + 1);
cf3 /= 3;
divexact_by3(w3p, w3p, l * 2);
}

// W1 = (W2 - W1) >> 1 : (non-negative) (r(1) - r(-1)) / 2
{
if (!neg2) {
cf1 = subc_n(w1p, w2p, w1p, l * 2 + 1, 0u);
if (!flag.neg1) {
cf1 = cf2 - cf1 - subc_n(w1p, w2p, w1p, l * 2, 0u);
} else {
cf1 = addc_n(w1p, w2p, w1p, l * 2 + 1, 0u);
cf1 = cf2 + cf1 + addc_n(w1p, w2p, w1p, l * 2, 0u);
}

WJR_ASSERT(cf1 == 0);
rshift_n(w1p, w1p, l * 2 + 1, 1u);
(void)rshift_n(w1p, w1p, l * 2, 1u);
if (cf1) {
w1p[l * 2 - 1] |= (cf1 & 1) << 63;
}
cf1 >>= 1;
}

// W2 = W2 - W0 : (non-negative) r(1) - r(0)
cf2 = subc_s(w2p, w2p, l * 2 + 1, w0p, l * 2, 0u);
WJR_ASSERT(cf2 == 0);
cf2 -= subc_n(w2p, w2p, w0p, l * 2, 0u);

// W3 = ((W3 - W2) >> 1) - (W4 << 1) : (non-negative) r3
{
cf3 = subc_n(w3p, w3p, w2p, l * 2 + 1, 0u);
WJR_ASSERT(cf3 == 0);
cf3 -= cf2 + subc_n(w3p, w3p, w2p, l * 2, 0u);

(void)rshift_n(w3p, w3p, l + maxr + 1, 1u);

T cf5 = lshift_n(dst + l * 2, w4p, rn + rm, 1u);
(void)rshift_n(w3p, w3p, l + maxr, 1u);
if (maxr != l) {
cf3 = w3p[l + maxr];
}
if (cf3) {
w3p[l + maxr - 1] |= (cf3 & 1) << 63;
}
cf3 >>= 1;

cf3 = subc_n(w3p, w3p, dst + l * 2, rn + rm, 0u);
cf3 = subc_1(w3p + rn + rm, w3p + rn + rm, (l + maxr + 1) - (rn + rm), cf5, cf3);
WJR_ASSERT(cf3 == 0);
cf = lshift_n(dst + l * 2, w4p, rn + rm, 1u);
cf += subc_n(w3p, w3p, dst + l * 2, rn + rm, 0u);
if ((l + maxr) != (rn + rm)) {
cf3 -= subc_1(w3p + rn + rm, w3p + rn + rm, (l + maxr) - (rn + rm), cf, 0u);
} else {
cf3 -= cf;
}
}

// W2 = W2 - W1 : (non-negative) r(1) / 2 - r(0) + r(-1) / 2
cf2 -= subc_n(w2p, w2p, w1p, l * 2 + 1, 0u);
WJR_ASSERT(cf2 == 0);
cf2 -= cf1 + subc_n(w2p, w2p, w1p, l * 2, 0u);

// W3 = W4 * x + W3 : r4 * x + r3
cf3 = addc_s(w4p, w4p, rn + rm, w3p + l, maxr + 1, 0u);
cf = addc_n(w4p, w4p, w3p + l, maxr, 0u);
cf = addc_1(w4p + maxr, w4p + maxr, rn + rm - maxr, cf3, cf);
WJR_ASSERT(cf == 0);

// W1 = W2 * x + W1 :
cf2 = addc_s(w2p, w2p, l * 2, w1p + l, l + 1, 0u);
cf = addc_s(w2p, w2p, l * 2, w1p + l, l, 0u);
cf2 += addc_1(w2p + l, w2p + l, l, cf1, cf);

// W1 = W1 - W3 : // r2 * x + r1
cf1 = subc_n(w1p, w1p, w3p, l, 0u);
cf1 = cf3 + subc_n(dst + l * 2, w2p, w4p, rn + rm, cf1);
cf2 += w2p[l * 2];
if (l * 2 != rn + rm) {
cf1 = cf2 - subc_1(dst + (l * 2) + (rn + rm), w2p + (rn + rm),
(l * 2) - (rn + rm), cf1, 0u);
} else {
cf1 = cf2 - cf1;
}
cf = subc_n(w1p, w1p, w3p, l, 0u);
cf2 -= subc_s(dst + l * 2, w2p, l * 2, w4p, rn + rm, cf);

// W = W3*x^3+ W1*x + W0
cf0 = addc_n(w0p + l, w0p + l, w1p, l, 0u);
cf0 = addc_1(dst + l * 2, dst + l * 2, l, 0u, cf0);
cf0 = addc_n(dst + l * 3, dst + l * 3, w3p, l, cf0);
cf0 = addc_1(dst + l * 4, dst + l * 4, rn + rm, cf1, cf0);
WJR_ASSERT(cf0 == 0);
(void)(cf0);
cf = addc_n(dst + l, dst + l, w1p, l, 0u);
cf = addc_1(dst + l * 2, dst + l * 2, l, 0u, cf);
cf = addc_n(dst + l * 3, dst + l * 3, w3p, l, cf);
cf = addc_1(dst + l * 4, dst + l * 4, rn + rm, cf2, cf);
WJR_ASSERT(cf == 0);
}

template <typename T>
Expand All @@ -893,11 +910,11 @@ void toom42_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s

auto w0p = dst;
auto w1p = stk;
auto w2p = stk + (2 * l + 1);
auto w3p = stk + (2 * l + 1) * 2;
auto w2p = stk + l * 2;
auto w3p = stk + l * 4;
auto w4p = dst + l * 4;

stk += 3 * (2 * l + 1);
stk += l * 6;

T cf0 = 0, cf1 = 0, cf2 = 0, cf3 = 0, cf4 = 0;
bool neg2 = 0, neg3 = 0;
Expand Down Expand Up @@ -943,7 +960,6 @@ void toom42_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s
if (WJR_UNLIKELY(cf3 != 0)) {
cf1 += addc_n(w1p + l, w1p + l, w2p, l, 0u);
}
w1p[l * 2] = cf1;

// W2 = W0 * W4 : (non-negative) r(1)
__rec_mul_n<__rec_mul_mode::toom22>(w2p, w0p, w4p, l, stk);
Expand All @@ -958,7 +974,6 @@ void toom42_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s
if (WJR_UNLIKELY(cf4 != 0)) {
cf2 += addc_n(w2p + l, w2p + l, w0p, l, 0u);
}
w2p[l * 2] = cf2;

// W0 = U0 +(U1 +(U2 +U3<<1)<<1)<<1 : (non-negative) u(2)
{
Expand Down Expand Up @@ -998,7 +1013,6 @@ void toom42_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s
cf3 += addmul_1(w3p + l, w0p, l, 2u);
}
}
w3p[l * 2] = cf3;

// W0 = U0 * V0 : (non-negative) r(0) = r0
__rec_mul_n<__rec_mul_mode::toom22>(w0p, u0p, v0p, l, stk);
Expand All @@ -1010,7 +1024,8 @@ void toom42_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s
__rec_mul_s<__rec_mul_mode::toom22>(w4p, v1p, rm, u3p, rn, stk);
}

return toom_interpolation_5p_s(dst, w1p, l, rn, rm, neg2);
return toom_interpolation_5p_s(dst, w1p, l, rn, rm,
toom_interpolation_5p_struct<T>{neg2, cf1, cf2, cf3});
}

template <typename T>
Expand All @@ -1032,11 +1047,11 @@ void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s

auto w0p = dst;
auto w1p = stk;
auto w2p = stk + (2 * l + 1);
auto w3p = stk + (2 * l + 1) * 2;
auto w2p = stk + l * 2;
auto w3p = stk + l * 4;
auto w4p = dst + l * 4;

stk += 3 * (2 * l + 1);
stk += l * 6;

T cf0 = 0, cf1 = 0, cf2 = 0, cf3 = 0, cf4 = 0;
bool neg2 = 0, neg3 = 0;
Expand Down Expand Up @@ -1082,7 +1097,6 @@ void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s
if (WJR_UNLIKELY(cf3 != 0)) {
cf1 += addc_n(w1p + l, w1p + l, w2p, l, 0u);
}
w1p[l * 2] = cf1;

// W2 = W0 * W4 : (non-negative) r(1)
__rec_mul_n<__rec_mul_mode::toom33>(w2p, w0p, w4p, l, stk);
Expand All @@ -1101,7 +1115,6 @@ void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s
cf2 += addmul_1(w2p + l, w0p, l, 2u);
}
}
w2p[l * 2] = cf2;

// W0 = (W0 + U2) << 1 - U0 : (non-negative) u(2)
cf0 += addc_s(w0p, w0p, l, u2p, rn, 0u);
Expand Down Expand Up @@ -1134,15 +1147,15 @@ void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m, T *s
cf3 += addmul_1(w3p + l, w0p, l, cf4);
}
}
w3p[l * 2] = cf3;

// W0 = U0 * V0 : (non-negative) r(0) = r0
__rec_mul_n<__rec_mul_mode::toom33>(w0p, u0p, v0p, l, stk);

// W4 = U2 * V2 : (non-negative) r(inf) = r4
__rec_mul_s<__rec_mul_mode::toom33>(w4p, u2p, rn, v2p, rm, stk);

return toom_interpolation_5p_s(dst, w1p, l, rn, rm, neg2);
return toom_interpolation_5p_s(dst, w1p, l, rn, rm,
toom_interpolation_5p_struct<T>{neg2, cf1, cf2, cf3});
}

} // namespace wjr
Expand Down
6 changes: 4 additions & 2 deletions include/wjr/math/sub.hpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#ifndef WJR_MATH_SUB_HPP__
#define WJR_MATH_SUB_HPP__

#include <wjr/math/clz.hpp>
#include <wjr/math/bit.hpp>
#include <wjr/math/clz.hpp>
#include <wjr/math/neg.hpp>
#include <wjr/math/replace.hpp>

Expand Down Expand Up @@ -65,7 +65,9 @@ WJR_INTRINSIC_CONSTEXPR_E T subc(T a, T b, type_identity_t<U> c_in, U &c_out) {
return fallback_subc(a, b, c_in, c_out);
#else
constexpr auto is_constant_or_zero = [](const auto &x) -> int {
return WJR_BUILTIN_CONSTANT_P(x) ? 1 : 0;
return WJR_BUILTIN_CONSTANT_P(x == 0) && x == 0 ? 2
: WJR_BUILTIN_CONSTANT_P(x) ? 1
: 0;
};

// The compiler should be able to optimize the judgment condition of if when enabling
Expand Down

0 comments on commit 050547c

Please sign in to comment.