Skip to content

Commit

Permalink
opt math
Browse files Browse the repository at this point in the history
  • Loading branch information
wjr-z committed Dec 21, 2023
1 parent cde193f commit 806b29d
Show file tree
Hide file tree
Showing 13 changed files with 577 additions and 53 deletions.
2 changes: 1 addition & 1 deletion examples/demo/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ endif()
if (MSVC)
set(CMAKE_CXX_FLAGS "$ENV{CXXFLAGS} /O2 /std:c++17 /Zc:preprocessor")
else()
set(CMAKE_CXX_FLAGS "$ENV{CXXFLAGS} -O2 -std=c++17")
set(CMAKE_CXX_FLAGS "$ENV{CXXFLAGS} -O2 -std=c++17 -g")
endif()

set(WJR_INCLUDE_DIR ${PROJECT_SOURCE_DIR}/../../include)
Expand Down
13 changes: 8 additions & 5 deletions examples/demo/src/main.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
#include <chrono>
#include <iostream>
#define WJR_DEBUG_LEVEL 3
#include <random>
#include <wjr/atomic.hpp>
#include <wjr/math/add.hpp>
#include <wjr/math/sub.hpp>
#include <wjr/preprocessor.hpp>

extern auto foo(uint64_t a, uint64_t b, uint64_t c, uint64_t &d) {
return wjr::addc(a, b, 1, d);
}

int main() {

return 0;
}
int main() { return 0; }
17 changes: 0 additions & 17 deletions include/wjr/asm/adc.hpp

This file was deleted.

3 changes: 1 addition & 2 deletions include/wjr/atomic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
#define WJR_ATOMIC_CONCAT(x) WJR_ATOMIC_CONCAT_I x
#define WJR_ATOMIC_CONCAT_I(a, b) a b

#define WJR_ATOMIC_VERIFY(veris) \
WJR_PP_QUEUE_PUT(WJR_PP_QUEUE_TRANSFORM(veris, WJR_ATOMIC_VERIFY_IMPL))
#define WJR_ATOMIC_VERIFY(veris) WJR_PP_TRANSFORM_PUT(veris, WJR_ATOMIC_VERIFY_IMPL)
#define WJR_ATOMIC_VERIFY_IMPL(ptr) \
WJR_ASSERT_L(1, \
reinterpret_cast<::wjr::uintptr_t>(ptr) % \
Expand Down
3 changes: 3 additions & 0 deletions include/wjr/math.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
#ifndef WJR_MATH_HPP__
#define WJR_MATH_HPP__

#include <wjr/math/add.hpp>
#include <wjr/math/sub.hpp>

#endif // WJR_MATH_HPP__
196 changes: 196 additions & 0 deletions include/wjr/math/add.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
#ifndef WJR_MATH_ADD_HPP__
#define WJR_MATH_ADD_HPP__

#include <iostream>
#include <wjr/type_traits.hpp>

namespace wjr {

template <typename T, typename U>
WJR_INTRINSIC_CONSTEXPR T fallback_addc(T a, T b, U c_in, U &c_out) {
T ret = a;
U c = 0;
ret += b;
c = ret < b;
ret += c_in;
c |= ret < c_in;
c_out = c;
return ret;
}

#if WJR_HAS_BUILTIN(__builtin_addc) || WJR_HAS_CLANG(5, 0, 0)
#define WJR_HAS_BUILTIN_ADDC WJR_HAS_DEF
#endif

#if WJR_HAS_FEATURE(INLINE_ASM) && defined(WJR_X86) && \
(defined(WJR_COMPILER_CLANG) || defined(WJR_COMPILER_GCC))
#define WJR_HAS_BUILTIN_ASM_ADDC WJR_HAS_DEF
#endif

#if WJR_HAS_BUILTIN(ASM_ADDC)

template <typename T, typename U>
WJR_INTRINSIC_INLINE T asm_addc_1(T a, T b, U &c_out) {
constexpr auto nd = std::numeric_limits<T>::digits;

#define WJR_REGISTER_BUILTIN_ASM_ADDC(args) \
WJR_PP_TRANSFORM_PUT(args, WJR_REGISTER_BUILTIN_ASM_ADDC_I_CALLER)
#define WJR_REGISTER_BUILTIN_ASM_ADDC_I(suffix, type) \
if constexpr (nd == std::numeric_limits<type>::digits) { \
unsigned char cf = 0; \
asm volatile("stc\n\t" \
"adc{" #suffix " %2, %0| %0, %2}\n\t" \
"setb %b1\n\t" \
: "=r"(a), "=r"(cf) \
: "%r"(b), "0"(a) \
: "cc"); \
c_out = cf; \
return a; \
} else
#define WJR_REGISTER_BUILTIN_ASM_ADDC_I_CALLER(args) WJR_REGISTER_BUILTIN_ASM_ADDC_I args

WJR_REGISTER_BUILTIN_ASM_ADDC(
((b, uint8_t), (w, uint16_t), (l, uint32_t), (q, uint64_t))) {
static_assert(nd <= 64, "not supported yet");
}

#undef WJR_REGISTER_BUILTIN_ASM_ADDC_I_CALLER
#undef WJR_REGISTER_BUILTIN_ASM_ADDC_I
#undef WJR_REGISTER_BUILTIN_ASM_ADDC
}

#endif // WJR_HAS_BUILTIN(ASM_ADDC)

#if WJR_HAS_BUILTIN(ADDC)

template <typename T, typename U>
WJR_INTRINSIC_INLINE T builtin_addc(T a, T b, U c_in, U &c_out) {
constexpr auto nd = std::numeric_limits<T>::digits;

#define WJR_REGISTER_BUILTIN_ADDC(args) \
WJR_PP_TRANSFORM_PUT(args, WJR_REGISTER_BUILTIN_ADDC_I_CALLER)
#define WJR_REGISTER_BUILTIN_ADDC_I(suffix, type) \
if constexpr (nd <= std::numeric_limits<type>::digits) { \
type __c_out = 0; \
T ret = __builtin_addc##suffix(a, b, static_cast<type>(c_in), &__c_out); \
c_out = static_cast<U>(__c_out); \
return ret; \
} else
#define WJR_REGISTER_BUILTIN_ADDC_I_CALLER(args) WJR_REGISTER_BUILTIN_ADDC_I args

WJR_REGISTER_BUILTIN_ADDC(((b, unsigned char), (s, unsigned short), (, unsigned int),
(l, unsigned long), (ll, unsigned long long))) {
static_assert(nd <= 64, "not supported yet");
}

#undef WJR_REGISTER_BUILTIN_ADDC_I_CALLER
#undef WJR_REGISTER_BUILTIN_ADDC_I
#undef WJR_REGISTER_BUILTIN_ADDC
}

#endif // WJR_HAS_BUILTIN(ADDC)

template <typename T, typename U>
WJR_INTRINSIC_CONSTEXPR T addc(T a, T b, type_identity_t<U> c_in, U &c_out) {
WJR_ASSERT_L(1, c_in == 0 || c_in == 1);
WJR_ASSUME((c_in == 0 || c_in == 1));

#if !WJR_HAS_BUILTIN(ADDC) && !WJR_HAS_BUILTIN(ASM_ADDC)
return fallback_addc(a, b, c_in, c_out);
#else
constexpr auto is_constant_or_zero = [](const auto &x) -> int {
return is_constant_p(x) ? (x == 0 ? 2 : 1) : 0;
};

// The compiler should be able to optimize the judgment condition of if when enabling
// optimization. If it doesn't work, then there should be a issue
if (is_constant_evaluated() ||
// constant value is zero or constant value number greater or equal than 2
(is_constant_or_zero(a) + is_constant_or_zero(b) + is_constant_or_zero(c_in) >=
2)) {
return fallback_addc(a, b, c_in, c_out);
}

#if WJR_HAS_BUILTIN(ASM_ADDC)
if (is_constant_p(c_in) && c_in == 1) {
return asm_addc_1(a, b, c_out);
}
#endif

#if WJR_HAS_BUILTIN(ADDC)
return builtin_addc(a, b, c_in, c_out);
#else
return fallback_addc(a, b, c_in, c_out);
#endif // WJR_HAS_BUILTIN(ADDC)

#endif
}

template <size_t div, typename T, typename U>
WJR_INTRINSIC_CONSTEXPR U addc_n_res(const T *src0, const T *src1, T *dst, U c_in,
size_t n) {

constexpr size_t mask = div - 1;

n &= mask;

if (WJR_UNLIKELY(n == 0)) {
return c_in;
}

src0 += n;
src1 += n;
dst += n;

#define WJR_REGISTER_ADDC_RES_CASE_CALLER(idx) \
case idx: { \
dst[-idx] = addc(src0[-idx], src1[-idx], c_in, c_in); \
WJR_FALLTHROUGH; \
}

#define WJR_REGISTER_ADDC_RES_SWITCH_CALLER(size) \
if constexpr (div == size) { \
switch (n) { \
WJR_PP_TRANSFORM_PUT( \
WJR_PP_QUEUE_TRANSFORM( \
WJR_PP_QUEUE_REVERSE((WJR_PP_IOTA(WJR_PP_DEC(size)))), WJR_PP_INC), \
WJR_REGISTER_ADDC_RES_CASE_CALLER) \
} \
return c_in; \
} else

WJR_REGISTER_ADDC_RES_SWITCH_CALLER(2)
WJR_REGISTER_ADDC_RES_SWITCH_CALLER(4) WJR_REGISTER_ADDC_RES_SWITCH_CALLER(8) {
static_assert(div <= 8, "not support yet");
}

#undef WJR_REGISTER_ADDC_RES_SWITCH_CALLER
#undef WJR_REGISTER_ADDC_RES_CASE_CALLER
} // namespace wjr

template <typename T, typename U>
WJR_INTRINSIC_CONSTEXPR U addc_n_impl(const T *src0, const T *src1, T *dst, U c_in,
size_t n) {
size_t m = n / 4;
for (size_t i = 0; i < m; ++i) {
dst[0] = addc(src0[0], src1[0], c_in, c_in);
dst[1] = addc(src0[1], src1[1], c_in, c_in);
dst[2] = addc(src0[2], src1[2], c_in, c_in);
dst[3] = addc(src0[3], src1[3], c_in, c_in);

src0 += 4;
src1 += 4;
dst += 4;
}

return addc_n_res<4>(src0, src1, dst, c_in, n);
}

template <typename T, typename U>
WJR_INTRINSIC_CONSTEXPR U addc_n(const T *src0, const T *src1, T *dst, U c_in, size_t n) {
return addc_n_impl(src0, src1, dst, c_in, n);
}

} // namespace wjr

#endif // WJR_MATH_ADD_HPP__
Loading

0 comments on commit 806b29d

Please sign in to comment.