Skip to content

Commit

Permalink
start SIMD
Browse files Browse the repository at this point in the history
  • Loading branch information
wjr-z committed Aug 5, 2024
1 parent 44ee874 commit efc6984
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 17 deletions.
6 changes: 6 additions & 0 deletions include/wjr/simd/detail.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#ifndef WJR_SIMD_DETAIL_HPP__
#define WJR_SIMD_DETAIL_HPP__

#include <wjr/simd/simd_cast.hpp>
#include <wjr/simd/simd_mask.hpp>

namespace wjr {

namespace simd_abi {
Expand All @@ -19,6 +22,9 @@ inline constexpr vector_aligned_t vector_aligned{};
template <typename T, typename Abi>
class simd;

/// Shorthand for the simd specialization that uses the fixed_size ABI tag
/// with N elements of type T.
template <typename T, size_t N>
using fixed_size_simd = simd<T, simd_abi::fixed_size<N>>;

} // namespace wjr

#endif // WJR_SIMD_DETAIL_HPP__
67 changes: 63 additions & 4 deletions include/wjr/simd/simd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ class simd<T, simd_abi::fixed_size<N>> {
// Total width of the vector in bits: N elements of sizeof(T) bytes each.
static constexpr size_t BitWidth = sizeof(T) * 8 * N;
// Storage type of the whole vector (m_data has this type).
// NOTE(review): presumably uint_t<W> is the W-bit unsigned integer — confirm.
using int_type = uint_t<BitWidth>;

// Compile-time constraints on the specialization; the messages tell the user
// which template argument was rejected.
static_assert(std::is_unsigned_v<T>,
              "simd<T, fixed_size<N>>: T must be an unsigned type");
static_assert(N >= 2, "simd<T, fixed_size<N>>: N must be at least 2");

public:
// Mask type paired with this vector: a basic_simd_mask over the same element
// type, element count and total bit width.
using mask_type = simd_detail::basic_simd_mask<T, N, BitWidth>;

Expand All @@ -31,23 +34,79 @@ class simd<T, simd_abi::fixed_size<N>> {
}

template <typename Flags = element_aligned_t>
void copy_from(const T *mem, Flags flags = {}) noexcept {
void copy_from(const T *mem, Flags = {}) noexcept {
m_data = read_memory<int_type>(mem);
}

template <typename Flags = element_aligned_t>
void copy_to(T *mem, Flags flags = {}) noexcept {
void copy_to(T *mem, Flags = {}) noexcept {
write_memory<int_type>(mem, m_data);
}

friend constexpr mask_type operator==(const simd &lhs, const simd &rhs) noexcept {
return lhs.m_data ^ rhs.m_dat;
constexpr simd &operator&=(const simd &other) noexcept {
m_data &= other.m_data;
return *this;
}

/// Returns a new vector holding the bitwise AND of `lhs` and `rhs`;
/// neither operand is modified.
friend constexpr simd operator&(const simd &lhs, const simd &rhs) noexcept {
    // Copy the left operand, fold the right one in, and return the copy.
    return simd(lhs) &= rhs;
}

/// Bitwise-ORs `other` into this vector in place.
///
/// @return *this, to allow chaining.
constexpr simd &operator|=(const simd &other) noexcept {
    this->m_data |= other.m_data;
    return *this;
}

/// Returns a new vector holding the bitwise OR of `lhs` and `rhs`;
/// neither operand is modified.
friend constexpr simd operator|(const simd &lhs, const simd &rhs) noexcept {
    // Copy the left operand, fold the right one in, and return the copy.
    return simd(lhs) |= rhs;
}

/// Bitwise-XORs `other` into this vector in place.
///
/// @return *this, to allow chaining.
constexpr simd &operator^=(const simd &other) noexcept {
    this->m_data ^= other.m_data;
    return *this;
}

/// Returns a new vector holding the bitwise XOR of `lhs` and `rhs`;
/// neither operand is modified.
friend constexpr simd operator^(const simd &lhs, const simd &rhs) noexcept {
    // Copy the left operand, fold the right one in, and return the copy.
    return simd(lhs) ^= rhs;
}

/// Disabled: this implementation is slow.
// friend constexpr mask_type operator==(const simd &lhs, const simd &rhs) noexcept {
// return ~(lhs.m_data ^ rhs.m_data);
// }

private:
int_type m_data;
};

/// Compile-time trait keyed on a bit width N.
///
/// The primary template answers "false"; the explicit specializations below
/// answer "true" for 8, 16, 32 and 64 bits.
/// NOTE(review): presumably "native" means the width maps directly onto a
/// machine integer or register — confirm with the users of this trait.
template <size_t N>
struct is_native_simd_bit : std::bool_constant<false> {};

template <> struct is_native_simd_bit<8> : std::bool_constant<true> {};
template <> struct is_native_simd_bit<16> : std::bool_constant<true> {};
template <> struct is_native_simd_bit<32> : std::bool_constant<true> {};
template <> struct is_native_simd_bit<64> : std::bool_constant<true> {};

// The 128-bit width is reported as native only when WJR_HAS_SIMD(NATIVE_128BIT)
// evaluates to true for this build.
#if WJR_HAS_SIMD(NATIVE_128BIT)
template <>
struct is_native_simd_bit<128> : std::true_type {};
#endif

// Likewise, the 256-bit width is gated on WJR_HAS_SIMD(NATIVE_256BIT).
#if WJR_HAS_SIMD(NATIVE_256BIT)
template <>
struct is_native_simd_bit<256> : std::true_type {};
#endif

} // namespace wjr

#endif // WJR_SIMD_SIMD_HPP__
26 changes: 13 additions & 13 deletions include/wjr/simd/simd_mask.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#define WJR_SIMD_SIMD_MASK_HPP__

#include <wjr/assert.hpp>
#include <wjr/math/clz.hpp>
#include <wjr/math/ctz.hpp>
#include <wjr/type_traits.hpp>

namespace wjr::simd_detail {
Expand All @@ -10,42 +12,40 @@ template <typename T, size_t Size, size_t BitWidth>
class basic_simd_mask {
// Integer type that stores the whole mask (m_mask has this type).
using mask_type = uint_t<BitWidth>;
// Number of mask bits that belong to each of the Size elements.
constexpr static size_t __mask_bits = BitWidth / Size;
// NOTE(review): presumably the low BitWidth / 2 bits set, obtained by
// converting in_place_max through the half-width integer — confirm the
// semantics of in_place_max.
constexpr static mask_type __half_mask =
static_cast<uint_t<BitWidth / 2>>(in_place_max);
// Value that all() compares the mask against.
// NOTE(review): presumably every bit of mask_type set — confirm.
constexpr static mask_type __full_mask = in_place_max;

public:
WJR_ENABLE_DEFAULT_SPECIAL_MEMBERS(basic_simd_mask);

// Wraps a raw bit pattern as a mask. The constructor is not explicit, so a
// mask_type value converts to basic_simd_mask implicitly.
constexpr basic_simd_mask(mask_type mask) noexcept : m_mask(mask) {}

constexpr int clz() const noexcept {
WJR_PURE WJR_CONSTEXPR20 int clz() const noexcept {
WJR_ASSERT_ASSUME(m_mask != 0);

if constexpr (Size == 2) {
constexpr auto high_mask =
static_cast<mask_type>(static_cast<uint_t<BitWidth / 2>>(in_place_max))
<< (BitWidth / 2);
constexpr auto high_mask = __half_mask << (BitWidth / 2);

return (m_mask & high_mask) ? 0 : 1;
} else {
return clz(m_mask) / __mask_bits;
return ::wjr::clz(m_mask) / __mask_bits;
}
}

template <typename U, WJR_REQUIRES(is_nonbool_integral_v<U> && sizeof(U) < sizeof(T))>
constexpr int ctz(U) const noexcept {
WJR_PURE WJR_CONSTEXPR20 int ctz() const noexcept {
WJR_ASSERT_ASSUME(m_mask != 0);

constexpr size_t elements = sizeof(T) / sizeof(U);
if constexpr (elements == 2) {
constexpr auto low_mask =
static_cast<mask_type>(static_cast<uint_t<BitWidth / 2>>(in_place_max));
if constexpr (Size == 2) {
constexpr auto low_mask = __half_mask;

return (m_mask & low_mask) ? 0 : 1;
} else {
return ctz(m_mask) / __mask_bits;
return ::wjr::ctz(m_mask) / __mask_bits;
}
}

// True when at least one bit of the mask is set.
constexpr bool any() const noexcept { return m_mask != 0; }
// True when the mask equals __full_mask.
WJR_PURE constexpr bool all() const noexcept { return m_mask == __full_mask; }

private:
mask_type m_mask;
Expand Down

0 comments on commit efc6984

Please sign in to comment.