Skip to content

Commit

Permalink
update div_qr_s
Browse files Browse the repository at this point in the history
  • Loading branch information
wjr-z committed Feb 15, 2024
1 parent eec2a3e commit 8ea7070
Show file tree
Hide file tree
Showing 20 changed files with 299 additions and 156 deletions.
119 changes: 119 additions & 0 deletions include/wjr/assert.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#ifndef WJR_ASSERT_HPP__
#define WJR_ASSERT_HPP__

#include <cstdio>
#include <type_traits>
#include <utility>

#include <wjr/preprocessor.hpp>

namespace wjr {

// ASSERT_LEVEL : 0 ~ 3
// 0 : Release (defined(NDEBUG) && ! defined(WJR_DEBUG_LEVEL))
// 1 : Some simple runtime checks, such as boundary checks (default)
// 2 : Most runtime checks
// 3 : Maximize runtime checks
#if defined WJR_DEBUG_LEVEL
#define WJR_ASSERT_LEVEL WJR_DEBUG_LEVEL
#elif defined(NDEBUG)
#define WJR_DEBUG_LEVEL 0
#else
#define WJR_DEBUG_LEVEL 1
#endif //

template <typename Handler, typename... Args,
std::enable_if_t<std::is_invocable_v<Handler, Args...>, int> = 0>
void __assert_fail(Handler handler, Args &&...args) {
handler(std::forward<Args>(args)...);
std::abort();
}

template <typename Handler, typename... Args,
std::enable_if_t<std::is_invocable_v<Handler, Args...>, int> = 0>
WJR_COLD WJR_NOINLINE void __assert_fail(const char *expr, const char *file, int line,
Handler handler, Args &&...args) {
(void)fprintf(stderr, "Assertion failed: %s", expr);
if ((file != nullptr) && (file[0] != '\0')) {
(void)fprintf(stderr, ", file %s", file);
}
if (line != -1) {
(void)fprintf(stderr, ", line %d", line);
}
(void)fprintf(stderr, "\n");
__assert_fail(handler, std::forward<Args>(args)...);
}

struct __assert_t {
void operator()() const {}

template <typename... Args>
void operator()(const char *fmt, Args &&...args) {
(void)fprintf(stderr, "Additional message: ");
(void)fprintf(stderr, fmt, std::forward<Args>(args)...);
(void)fprintf(stderr, "\n");
}
};

inline constexpr __assert_t __assert{};

#define WJR_ASSERT_NOMESSAGE_FAIL(handler, exprstr) \
::wjr::__assert_fail(exprstr, WJR_FILE, WJR_LINE, handler)
#define WJR_ASSERT_MESSAGE_FAIL(handler, exprstr, ...) \
::wjr::__assert_fail(exprstr, WJR_FILE, WJR_LINE, handler, __VA_ARGS__)

#define WJR_ASSERT_CHECK_I_NOMESSAGE(handler, expr) \
do { \
if (WJR_UNLIKELY(!(expr))) { \
WJR_ASSERT_NOMESSAGE_FAIL(handler, #expr); \
} \
} while (0)
#define WJR_ASSERT_CHECK_I_MESSAGE(handler, expr, ...) \
do { \
if (WJR_UNLIKELY(!(expr))) { \
WJR_ASSERT_MESSAGE_FAIL(handler, #expr, __VA_ARGS__); \
} \
} while (0)

#define WJR_ASSERT_CHECK_I(...) \
WJR_ASSERT_CHECK_I_N(WJR_PP_ARGS_LEN(__VA_ARGS__), __VA_ARGS__)
#define WJR_ASSERT_CHECK_I_N(N, ...) \
WJR_PP_BOOL_IF(WJR_PP_EQ(N, 1), WJR_ASSERT_CHECK_I_NOMESSAGE, \
WJR_ASSERT_CHECK_I_MESSAGE) \
(::wjr::__assert, __VA_ARGS__)

// do nothing
#define WJR_ASSERT_UNCHECK_I(expr, ...)

#define WJR_ALWAYS_ASSERT_UNCHECK_I(expr, ...) (expr)

// level = [0, 2]
// The higher the level, the less likely it is to be detected
// Runtime detect : 1
// Maximize detect : 2
#define WJR_ASSERT_L(level, ...) \
WJR_PP_BOOL_IF(WJR_PP_GT(WJR_DEBUG_LEVEL, level), WJR_ASSERT_CHECK_I, \
WJR_ASSERT_UNCHECK_I) \
(__VA_ARGS__)

// level of assert is zero at default.
#define WJR_ASSERT(...) WJR_ASSERT_L(0, __VA_ARGS__)

#define WJR_ALWAYS_ASSERT_L(level, ...) \
WJR_PP_BOOL_IF(WJR_PP_GT(WJR_DEBUG_LEVEL, level), WJR_ASSERT_CHECK_I, \
WJR_ALWAYS_ASSERT_UNCHECK_I) \
(__VA_ARGS__)

// level of assert is zero at default.
#define WJR_ALWAYS_ASSERT(...) WJR_ALWAYS_ASSERT_L(0, __VA_ARGS__)

#define WJR_ASSERT_ASSUME_L(level, ...) \
WJR_ASSERT_L(level, __VA_ARGS__); \
__WJR_ASSERT_ASSUME_L_ASSUME(__VA_ARGS__)
#define __WJR_ASSERT_ASSUME_L_ASSUME(expr, ...) WJR_ASSUME(expr)

#define WJR_ASSERT_ASSUME(...) WJR_ASSERT_ASSUME_L(0, __VA_ARGS__)

} // namespace wjr

#endif // WJR_ASSERT_HPP__
1 change: 1 addition & 0 deletions include/wjr/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define WJR_MATH_HPP__

#include <wjr/math/bit.hpp>
#include <wjr/math/compare.hpp>
#include <wjr/math/div.hpp>
#include <wjr/math/neg.hpp>

Expand Down
1 change: 1 addition & 0 deletions include/wjr/math/bit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define WJR_MATH_BIT_HPP__

#include <wjr/math/clz.hpp>
#include <wjr/math/ctz.hpp>

namespace wjr {

Expand Down
2 changes: 1 addition & 1 deletion include/wjr/math/clz.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#ifndef WJR_MATH_CLZ_HPP__
#define WJR_MATH_CLZ_HPP__

#include <wjr/math/ctz.hpp>
#include <wjr/math/popcount.hpp>

namespace wjr {

Expand Down
2 changes: 0 additions & 2 deletions include/wjr/math/details.hpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
#ifndef WJR_MATH_DETAILS_HPP__
#define WJR_MATH_DETAILS_HPP__

#include <array>

#include <wjr/stack_allocator.hpp>
#include <wjr/type_traits.hpp>

Expand Down
156 changes: 148 additions & 8 deletions include/wjr/math/div.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ WJR_NOINLINE WJR_CONSTEXPR20 T div_qr_1_without_shift(T *dst, T &rem, const T *s
}

template <typename T>
WJR_NOINLINE WJR_CONSTEXPR20 T div_qr_1_with_shift(T *dst, T &rem, const T *src, size_t n,
const div2by1_divider<T> &div) {
WJR_CONSTEXPR20 T div_qr_1_with_shift(T *dst, T &rem, const T *src, size_t n,
const div2by1_divider<T> &div) {
WJR_ASSERT_ASSUME(n >= 1);
WJR_ASSERT(div.get_shift() != 0);
WJR_ASSERT(WJR_IS_SAME_OR_DECR_P(dst, n, src, n));
Expand Down Expand Up @@ -157,9 +157,8 @@ WJR_INTRINSIC_CONSTEXPR20 void div_qr_1(T *dst, T &rem, const T *src, size_t n,
}

template <typename T>
WJR_NOINLINE WJR_CONSTEXPR20 T div_qr_2_without_shift(T *dst, T *rem, const T *src,
size_t n,
const div3by2_divider<T> &div) {
WJR_CONSTEXPR20 T div_qr_2_without_shift(T *dst, T *rem, const T *src, size_t n,
const div3by2_divider<T> &div) {
WJR_ASSERT_ASSUME(n >= 2);
WJR_ASSERT(WJR_IS_SAME_OR_DECR_P(dst, n, src, n));

Expand Down Expand Up @@ -199,8 +198,8 @@ WJR_NOINLINE WJR_CONSTEXPR20 T div_qr_2_without_shift(T *dst, T *rem, const T *s
}

template <typename T>
WJR_NOINLINE WJR_CONSTEXPR20 T div_qr_2_with_shift(T *dst, T *rem, const T *src, size_t n,
const div3by2_divider<T> &div) {
WJR_CONSTEXPR20 T div_qr_2_with_shift(T *dst, T *rem, const T *src, size_t n,
const div3by2_divider<T> &div) {
WJR_ASSERT_ASSUME(n >= 2);
WJR_ASSERT(div.get_shift() != 0);
WJR_ASSERT(WJR_IS_SAME_OR_DECR_P(dst, n, src, n));
Expand Down Expand Up @@ -561,10 +560,151 @@ WJR_INTRINSIC_CONSTEXPR20 void div_qr_s(T *dst, T *rem, const T *src, size_t n,
return div_qr_2(dst, rem, src, n, div);
}
default: {
WJR_UNREACHABLE();
break;
}
}

unsigned int adjust = src[n - 1] >= div[m - 1];
if (n + adjust >= 2 * m) {
T *sp;
T *dp;

const auto shift = clz(div[m - 1]);
const size_t alloc = n + 1 + (shift != 0 ? m : 0);
unique_stack_ptr ptr(math_details::stack_alloc, sizeof(T) * alloc);
auto stk = static_cast<T *>(ptr.get());
sp = stk;

if (shift != 0) {
dp = stk + (n + 1);
lshift_n(dp, div, m, shift);
sp[n] = lshift_n(sp, src, n, shift);
} else {
dp = div;
std::copy_n(src, n, sp);
sp[n] = 0;
}

n += adjust;

const auto dinv = div3by2_divider<T>::reciprocal(dp[m - 2], dp[m - 1]);

if (m < dc_div_qr_threshold) {
sb_div_qr_s(dst, sp, n, dp, m, dinv);
} else {
dc_div_qr_s(dst, sp, n, dp, m, dinv);
}

rshift_n(rem, sp, m, shift);
return;
}

// 2 * m > n + adjust

auto qn = n - m;
dst[qn] = 0;
qn += adjust;

if (qn == 0) {
std::copy_n(src, m, rem);
return;
}

T *sp, *dp;
size_t st;

st = m - qn; // st = m - qn = 2 * m - n + adjust > 2 * adjust

const auto shift = clz(div[m - 1]);

const size_t alloc = 2 * qn + (shift != 0 ? qn : 0);
unique_stack_ptr ptr(math_details::stack_alloc, sizeof(T) * alloc);
auto stk = static_cast<T *>(ptr.get());
sp = stk;

if (shift != 0) {
dp = stk + 2 * qn;
lshift_n(dp, div + st, qn, shift, div[st - 1]);
if (adjust) {
sp[2 * qn - 1] =
lshift_n(sp, src + n - 2 * qn + 1, 2 * qn - 1, shift, src[n - 2 * qn]);
} else {
lshift_n(sp, src + n - 2 * qn, 2 * qn, shift, src[n - 2 * qn - 1]);
}
} else {
dp = div + st;
if (adjust) {
std::copy_n(src + n - 2 * qn + 1, 2 * qn - 1, sp);
sp[2 * qn - 1] = 0;
} else {
std::copy_n(src + n - 2 * qn, 2 * qn, sp);
}
}

if (qn == 1) {
const auto dinv = div2by1_divider<T>::reciprocal(dp[0]);
auto hi = sp[1];
dst[0] = div2by1_divider<T>::divide(dp[0], dinv, sp[0], hi);
sp[0] = hi;
} else if (qn == 2) {
const auto lo = dp[0];
const auto hi = dp[1];
const auto dinv = div3by2_divider<T>::reciprocal(lo, hi);
div_qr_2_without_shift(dst, sp, sp, 4, div3by2_divider<T>(lo, hi, dinv, 0u));
} else {
const auto lo = dp[qn - 2];
const auto hi = dp[qn - 1];
const auto dinv = div3by2_divider<T>::reciprocal(lo, hi);
if (qn < dc_div_qr_threshold) {
sb_div_qr_s(dst, sp, 2 * qn, dp, qn, dinv);
} else {
dc_div_qr_s(dst, sp, 2 * qn, dp, qn, dinv);
}
}

WJR_ASSUME(st >= 1);

T fix = rshift_n(sp, sp, qn, shift);

unique_stack_ptr ptr2(math_details::stack_alloc, sizeof(T) * m);
auto stk2 = static_cast<T *>(ptr2.get());
auto rp = stk2;

unsigned int cf;

if (!shift) {
if (qn >= st) {
mul_s(rp, dst, qn, div, st);
} else {
mul_s(rp, div, st, dst, qn);
}
cf = subc_n(rem, src, rp, st);
} else {
constexpr auto digits = std::numeric_limits<T>::digits;
T mask = (1ull << (digits - shift)) - 1;
if (st != 1) {
if (qn >= st - 1) {
mul_s(rp, dst, qn, div, st - 1);
} else {
mul_s(rp, div, st - 1, dst, qn);
}
rp[m - 1] = addmul_1(rp + st - 1, dst, qn, div[st - 1] & mask);
cf = subc_n(rem, src, rp, st - 1);
rem[st - 1] = subc((src[st - 1] & mask) | fix, rp[st - 1], cf, cf);
} else {
rp[m - 1] = mul_1(rp, dst, qn, div[0] & mask);
rem[0] = subc((src[0] & mask) | fix, rp[0], 0u, cf);
}
}

cf = subc_n(rem + st, sp, rp + st, qn, cf);

while (cf != 0) {
subc_1(dst, dst, qn, 1u);
cf -= addc_n(rem, rem, div, m);
}

return;
}

template <typename T>
Expand Down
3 changes: 1 addition & 2 deletions include/wjr/math/mul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
#include <wjr/math/div-impl.hpp>

#include <wjr/math/add.hpp>
#include <wjr/math/clz.hpp>
#include <wjr/math/compare.hpp>
#include <wjr/math/ctz.hpp>
#include <wjr/math/shift.hpp>
#include <wjr/math/sub.hpp>

Expand Down
1 change: 1 addition & 0 deletions include/wjr/math/replace.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

namespace wjr {

// maybe deprected
// size_t ret = find_not_n(src, from, n);
// set_n(dst, to, ret)
// return ret;
Expand Down
1 change: 0 additions & 1 deletion include/wjr/math/sub.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,6 @@ WJR_INTRINSIC_CONSTEXPR_E ssize_t abs_subc_s(T *dst, const T *src0, size_t n,

auto cf = subc_s(dst, src0, m + idx, src1, m);
WJR_ASSERT(cf == 0);
(void)(cf);

ssize_t ret = m + idx;
WJR_ASSUME(ret > 0);
Expand Down
1 change: 1 addition & 0 deletions include/wjr/preprocessor/arithmatic/add.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define WJR_PREPROCESSOR_ARITHMATIC_ADD_HPP__

#include <wjr/preprocessor/arithmatic/basic.hpp>
#include <wjr/preprocessor/details/basic.hpp>

#define WJR_PP_ADD(x, y) WJR_PP_ADD_I(x, y)
#define WJR_PP_ADD_I(x, y) \
Expand Down
2 changes: 0 additions & 2 deletions include/wjr/preprocessor/arithmatic/basic.hpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
#ifndef WJR_PREPROCESSOR_ARITHMATIC_BASIC_HPP__
#define WJR_PREPROCESSOR_ARITHMATIC_BASIC_HPP__

#include <wjr/preprocessor/details/basic.hpp>

#define WJR_PP_ARITHMATIC_FROM_NUMBER(x) WJR_PP_ARITHMATIC_FROM_NUMBER_I(x)
#define WJR_PP_ARITHMATIC_FROM_NUMBER_I(x) WJR_PP_ARITHMATIC_FROM_NUMBER_##x

Expand Down
1 change: 1 addition & 0 deletions include/wjr/preprocessor/arithmatic/cmp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <wjr/preprocessor/arithmatic/basic.hpp>
#include <wjr/preprocessor/arithmatic/inc.hpp>
#include <wjr/preprocessor/arithmatic/neg.hpp>
#include <wjr/preprocessor/details/basic.hpp>
#include <wjr/preprocessor/logical/basic.hpp>
#include <wjr/preprocessor/logical/bool.hpp>

Expand Down
Loading

0 comments on commit 8ea7070

Please sign in to comment.