Skip to content

Commit

Permalink
opt
Browse files Browse the repository at this point in the history
  • Loading branch information
wjr-z committed Jan 11, 2024
1 parent c6015a4 commit 42db3e4
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 147 deletions.
46 changes: 46 additions & 0 deletions include/wjr/math/div.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,34 @@ WJR_INTRINSIC_CONSTEXPR20 T divmod_1(T *dst, const T *src, size_t n,
return fallback_divmod_1(dst, src, n, div2by1_divider<T>(div));
}

template <typename T, std::enable_if_t<std::is_same_v<T, uint64_t>, int> = 0>
WJR_CONSTEXPR_E T fallback_divexact_dbm1c(T *dst, const T *src, size_t n, T bd, T h) {
T a = 0, p0 = 0, p1 = 0, cf = 0;

for (size_t i = 0; i < n; i++) {
a = src[i];
p0 = mul(a, bd, p1);
cf = h < p0;
h = (h - p0);
dst[i] = h;
h = h - p1 - cf;
}

return h;
}

template <typename T, std::enable_if_t<std::is_same_v<T, uint64_t>, int> = 0>
WJR_CONSTEXPR_E void divexact_by3(T *dst, const T *src, size_t n) {
constexpr auto max = std::numeric_limits<T>::max();
(void)fallback_divexact_dbm1c<T>(dst, src, n, max / 3, 0);
}

template <typename T, std::enable_if_t<std::is_same_v<T, uint64_t>, int> = 0>
WJR_CONSTEXPR_E void divexact_by5(T *dst, const T *src, size_t n) {
constexpr auto max = std::numeric_limits<T>::max();
(void)fallback_divexact_dbm1c<T>(dst, src, n, max / 5, 0);
}

// reference : ftp://ftp.risc.uni-linz.ac.at/pub/techreports/1992/92-35.ps.gz
// TODO : asm_divexact_1 (Low priority)
template <typename T, std::enable_if_t<std::is_same_v<T, uint64_t>, int> = 0>
Expand Down Expand Up @@ -255,6 +283,16 @@ WJR_INTRINSIC_CONSTEXPR_E void divexact_1(T *dst, const T *src, size_t n,
return;
}

if (WJR_BUILTIN_CONSTANT_P(div.shift() == 0) && div.shift() == 0) {
if (WJR_BUILTIN_CONSTANT_P(div.divisor() == 3) && div.divisor() == 3) {
return divexact_by3(dst, src, n);
}

if (WJR_BUILTIN_CONSTANT_P(div.divisor() == 5) && div.divisor() == 5) {
return divexact_by5(dst, src, n);
}
}

return fallback_divexact_1(dst, src, n, div);
}

Expand All @@ -270,6 +308,14 @@ WJR_INTRINSIC_CONSTEXPR_E void divexact_1(T *dst, const T *src, size_t n,
return;
}

if (WJR_BUILTIN_CONSTANT_P(div == 3) && div == 3) {
return divexact_by3(dst, src, n);
}

if (WJR_BUILTIN_CONSTANT_P(div == 5) && div == 5) {
return divexact_by5(dst, src, n);
}

return fallback_divexact_1(dst, src, n, divexact1_divider<T>(div));
}

Expand Down
118 changes: 69 additions & 49 deletions include/wjr/math/mul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,14 +220,7 @@ WJR_INTRINSIC_CONSTEXPR_E T submul_1(T *dst, const T *src0, size_t n, T src1,
#endif
}

template <typename T>
WJR_INLINE_CONSTEXPR void basecase_mul_s(T *dst, const T *src0, size_t n, const T *src1,
size_t m) {
dst[n] = mul_1(dst, src0, n, src1[0], 0);
for (size_t i = 1; i < m; ++i) {
dst[i + n] = addmul_1(dst + i, src0, n, src1[i], 0);
}
}
// preview :

// native default threshold of toom-cook-2
// TODO : optimize threshold
Expand All @@ -238,7 +231,53 @@ WJR_INLINE_CONSTEXPR void basecase_mul_s(T *dst, const T *src0, size_t n, const
inline constexpr size_t toom22_mul_threshold = WJR_TOOM22_MUL_THRESHOLD;

template <typename T>
void mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m);
WJR_INLINE_CONSTEXPR void basecase_mul_s(T *dst, const T *src0, size_t n, const T *src1,
size_t m);

template <typename T>
void toom22_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m);

template <typename T>
void mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m) {
WJR_ASSERT(WJR_IS_SAME_OR_INCR_P(dst, n + m, src0, n));
WJR_ASSERT(WJR_IS_SAME_OR_INCR_P(dst, n + m, src1, m));

if (n < m) {
std::swap(n, m);
std::swap(src0, src1);
}

WJR_ASSUME(m >= 1);

do {
if (m <= toom22_mul_threshold) {
break;
}

if (m <= toom22_mul_threshold * 4) {
if (5 * m <= 4 * n) {
break;
}
}

if (5 * m <= 3 * n) {
break;
}

return toom22_mul_s(dst, src0, n, src1, m);
} while (0);

return basecase_mul_s(dst, src0, n, src1, m);
}

template <typename T>
WJR_INLINE_CONSTEXPR void basecase_mul_s(T *dst, const T *src0, size_t n, const T *src1,
size_t m) {
dst[n] = mul_1(dst, src0, n, src1[0], 0);
for (size_t i = 1; i < m; ++i) {
dst[i + n] = addmul_1(dst + i, src0, n, src1[i], 0);
}
}

template <typename T>
void toom22_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m) {
Expand Down Expand Up @@ -299,58 +338,39 @@ void toom22_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m) {
mul_s(p0, u0, l, v0, l);
mul_s(p2, u1, rn, v1, rm);

T cy = 0, cy2 = 0;
T cf = 0, cf2 = 0;

cy = addc_n(p2, p1, p2, l, 0u);
cy2 = cy + addc_n(p1, p0, p2, l, 0u);
cy += addc_s(p2, p2, l, p3, p3n, 0u);
cf = addc_n(p2, p1, p2, l, 0u);
cf2 = cf + addc_n(p1, p0, p2, l, 0u);
cf += addc_s(p2, p2, l, p3, p3n, 0u);

if (!f) {
cy -= subc_n(p1, p1, stk, l * 2, 0u);
cf -= subc_n(p1, p1, stk, l * 2, 0u);
} else {
cy += addc_n(p1, p1, stk, l * 2, 0u);
cf += addc_n(p1, p1, stk, l * 2, 0u);
}

cy2 = addc_1(p2, p2, rn + rm, cy2, 0u);
WJR_ASSERT(cy2 == 0);
cy = addc_1(p3, p3, rn + rm - l, cy, 0u);
WJR_ASSERT(cy == 0);
cf2 = addc_1(p2, p2, rn + rm, cf2, 0u);
WJR_ASSERT(cf2 == 0);
cf = addc_1(p3, p3, rn + rm - l, cf, 0u);
WJR_ASSERT(cf == 0);
}

// preview : mul n x m
// TODO : ...

template <typename T>
void mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m) {
WJR_ASSERT(WJR_IS_SAME_OR_INCR_P(dst, n + m, src0, n));
WJR_ASSERT(WJR_IS_SAME_OR_INCR_P(dst, n + m, src1, m));

if (n < m) {
std::swap(n, m);
std::swap(src0, src1);
}

WJR_ASSUME(m >= 1);

do {
if (m <= toom22_mul_threshold) {
break;
}

if (m <= toom22_mul_threshold * 4) {
if (5 * m <= 4 * n) {
break;
}
}
void toom33_mul_s(T *dst, const T *src0, size_t n, const T *src1, size_t m) {
WJR_ASSUME(n >= m);

if (5 * m <= 3 * n) {
break;
}
const size_t l = (n + 2) / 3;
const size_t rn = n - l * 2;
const size_t rm = m - l * 2;

return toom22_mul_s(dst, src0, n, src1, m);
} while (0);
const auto u0p = src0;
const auto u1p = src0 + l;
const auto u2p = src0 + l * 2;

return basecase_mul_s(dst, src0, n, src1, m);
const auto v0p = src1;
const auto v1p = src1 + l;
const auto v2p = src1 + l * 2;
}

} // namespace wjr
Expand Down
24 changes: 0 additions & 24 deletions include/wjr/math/shift.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,7 @@ WJR_INTRINSIC_CONSTEXPR T fallback_shld(T hi, T lo, unsigned int c) {

template <typename T>
WJR_INTRINSIC_CONSTEXPR_E T shld(T hi, T lo, unsigned int c) {
#if WJR_HAS_BUILTIN(ASM_SHLD)
if (is_constant_evaluated() || (WJR_BUILTIN_CONSTANT_P(hi == 0) && hi == 0) ||
(WJR_BUILTIN_CONSTANT_P(lo == 0) && lo == 0) ||
(WJR_BUILTIN_CONSTANT_P(c == 0) && c == 0) ||
(WJR_BUILTIN_CONSTANT_P(c) &&
(WJR_BUILTIN_CONSTANT_P(hi) || WJR_BUILTIN_CONSTANT_P(lo)))) {
return fallback_shld(hi, lo, c);
}

return asm_shld(hi, lo, c);
#else
return fallback_shld(hi, lo, c);
#endif
}

template <typename T>
Expand All @@ -40,19 +28,7 @@ WJR_INTRINSIC_CONSTEXPR T fallback_shrd(T lo, T hi, unsigned int c) {

template <typename T>
WJR_INTRINSIC_CONSTEXPR_E T shrd(T lo, T hi, unsigned int c) {
#if WJR_HAS_BUILTIN(ASM_SHRD)
if (is_constant_evaluated() || (WJR_BUILTIN_CONSTANT_P(hi == 0) && hi == 0) ||
(WJR_BUILTIN_CONSTANT_P(lo == 0) && lo == 0) ||
(WJR_BUILTIN_CONSTANT_P(c == 0) && c == 0) ||
(WJR_BUILTIN_CONSTANT_P(c) &&
(WJR_BUILTIN_CONSTANT_P(hi) || WJR_BUILTIN_CONSTANT_P(lo)))) {
return fallback_shrd(lo, hi, c);
}

return asm_shrd(lo, hi, c);
#else
return fallback_shrd(lo, hi, c);
#endif
}

template <typename T>
Expand Down
2 changes: 1 addition & 1 deletion include/wjr/x86/gen_addsub.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ WJR_INLINE U WJR_PP_CONCAT(asm_, WJR_PP_CONCAT(WJR_addcsubc, _n))(T *dst, const
"%[n] * 4]}\n\t" \
"lea{q (%[t0], %[n], 1), %[n]| %[n], [%[t0] + %[n]]}\n\t" \
"jmp{q *%[n]| %[n]}\n\t" \
".align 4\n\t" \
".align 8\n\t" \
".Lasm_" WJR_PP_STR(WJR_addcsubc) "_n_lookup%=:\n\t" \
".long .Lcase0%=-.Lasm_" WJR_PP_STR(WJR_addcsubc) "_n_lookup%=\n\t" \
".long .Lcase1%=-.Lasm_" WJR_PP_STR(WJR_addcsubc) "_n_lookup%=\n\t" \
Expand Down
6 changes: 3 additions & 3 deletions include/wjr/x86/mul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ WJR_INLINE uint64_t asm_mul_1(uint64_t *dst, const uint64_t *src0, size_t n,
"lea{q (%[t3], %[t0], 1), %[t0]| %[t0], [%[t0] + %[t3]]}\n\t"
"jmp{q *%[t0]| %[t0]}\n\t"

".align 4\n\t"
".align 8\n\t"
".Lasm_addmul_1_lookup%=:\n\t"
".long .Lcase0%=-.Lasm_addmul_1_lookup%=\n\t"
".long .Lcase1%=-.Lasm_addmul_1_lookup%=\n\t"
Expand Down Expand Up @@ -180,7 +180,7 @@ WJR_INLINE uint64_t asm_addmul_1(uint64_t *dst, const uint64_t *src0, size_t n,
"lea{q (%[t3], %[t0], 1), %[t0]| %[t0], [%[t0] + %[t3]]}\n\t"
"jmp{q *%[t0]| %[t0]}\n\t"

".align 4\n\t"
".align 8\n\t"
".Lasm_addmul_1_lookup%=:\n\t"
".long .Lcase0%=-.Lasm_addmul_1_lookup%=\n\t"
".long .Lcase1%=-.Lasm_addmul_1_lookup%=\n\t"
Expand Down Expand Up @@ -293,7 +293,7 @@ WJR_INLINE uint64_t asm_submul_1(uint64_t *dst, const uint64_t *src0, size_t n,
"lea{q (%[t3], %[t0], 1), %[t0]| %[t0], [%[t0] + %[t3]]}\n\t"
"jmp{q *%[t0]| %[t0]}\n\t"

".align 4\n\t"
".align 8\n\t"
".Lasm_submul_1_lookup%=:\n\t"
".long .Lcase0%=-.Lasm_submul_1_lookup%=\n\t"
".long .Lcase1%=-.Lasm_submul_1_lookup%=\n\t"
Expand Down
70 changes: 0 additions & 70 deletions include/wjr/x86/shift.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,76 +15,6 @@ WJR_INTRINSIC_CONSTEXPR_E T shld(T hi, T lo, unsigned int c);
template <typename T>
WJR_INTRINSIC_CONSTEXPR_E T shrd(T lo, T hi, unsigned int c);

#if WJR_HAS_FEATURE(INLINE_ASM) && \
(defined(WJR_COMPILER_CLANG) || defined(WJR_COMPILER_GCC))
#define WJR_HAS_BUILTIN_ASM_SHLD WJR_HAS_DEF
#endif

#if WJR_HAS_BUILTIN(ASM_SHLD)

template <typename T>
WJR_INTRINSIC_INLINE T asm_shld(T hi, T lo, unsigned int c) {
constexpr auto nd = std::numeric_limits<T>::digits;

#define WJR_REGISTER_BUILTIN_ASM_SHLD(args) \
WJR_PP_TRANSFORM_PUT(args, WJR_REGISTER_BUILTIN_ASM_SHLD_I_CALLER)
#define WJR_REGISTER_BUILTIN_ASM_SHLD_I(suffix, type) \
if constexpr (nd == std::numeric_limits<type>::digits) { \
asm("shld{" #suffix " %b2, %1, %0| %0, %1, %b2 }" \
: "+r"(hi) \
: "r"(lo), "ci"(c) \
: "cc"); \
return hi; \
} else
#define WJR_REGISTER_BUILTIN_ASM_SHLD_I_CALLER(args) WJR_REGISTER_BUILTIN_ASM_SHLD_I args

WJR_REGISTER_BUILTIN_ASM_SHLD(
((b, uint8_t), (w, uint16_t), (l, uint32_t), (q, uint64_t))) {
static_assert(nd <= 64, "not supported yet");
}

#undef WJR_REGISTER_BUILTIN_ASM_SHLD_I_CALLER
#undef WJR_REGISTER_BUILTIN_ASM_SHLD_I
#undef WJR_REGISTER_BUILTIN_ASM_SHLD
}

#endif // WJR_HAS_BUILTIN(ASM_SHLD)

#if WJR_HAS_FEATURE(INLINE_ASM) && \
(defined(WJR_COMPILER_CLANG) || defined(WJR_COMPILER_GCC))
#define WJR_HAS_BUILTIN_ASM_SHRD WJR_HAS_DEF
#endif

#if WJR_HAS_BUILTIN(ASM_SHRD)

template <typename T>
WJR_INTRINSIC_INLINE T asm_shrd(T lo, T hi, unsigned int c) {
constexpr auto nd = std::numeric_limits<T>::digits;

#define WJR_REGISTER_BUILTIN_ASM_SHRD(args) \
WJR_PP_TRANSFORM_PUT(args, WJR_REGISTER_BUILTIN_ASM_SHRD_I_CALLER)
#define WJR_REGISTER_BUILTIN_ASM_SHRD_I(suffix, type) \
if constexpr (nd == std::numeric_limits<type>::digits) { \
asm("shrd{" #suffix " %b2, %1, %0| %0, %1, %b2 }" \
: "+r"(lo) \
: "r"(hi), "ci"(c) \
: "cc"); \
return lo; \
} else
#define WJR_REGISTER_BUILTIN_ASM_SHRD_I_CALLER(args) WJR_REGISTER_BUILTIN_ASM_SHRD_I args

WJR_REGISTER_BUILTIN_ASM_SHRD(
((b, uint8_t), (w, uint16_t), (l, uint32_t), (q, uint64_t))) {
static_assert(nd <= 64, "not supported yet");
}

#undef WJR_REGISTER_BUILTIN_ASM_SHRD_I_CALLER
#undef WJR_REGISTER_BUILTIN_ASM_SHRD_I
#undef WJR_REGISTER_BUILTIN_ASM_SHRD
}

#endif // WJR_HAS_BUILTIN(ASM_SHRD)

#if WJR_HAS_SIMD(SSE2) && WJR_HAS_SIMD(SIMD)
#define WJR_HAS_BUILTIN_LSHIFT_N WJR_HAS_DEF
#define WJR_HAS_BUILTIN_RSHIFT_N WJR_HAS_DEF
Expand Down

0 comments on commit 42db3e4

Please sign in to comment.