Skip to content

Commit

Permalink
Moved swap_pair into its callable
Browse files Browse the repository at this point in the history
  • Loading branch information
SadiinsoSnowfall authored Jan 14, 2025
1 parent a7cfc03 commit 77f3c9a
Show file tree
Hide file tree
Showing 11 changed files with 195 additions and 93 deletions.
2 changes: 1 addition & 1 deletion include/eve/module/core/detail/simd/x86/basic_shuffle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ basic_shuffle_(EVE_SUPPORTS(avx_),
[](auto i, auto c)
{
Pattern r;
return (i < c / 2 ? r(i, c) : r(i - c / 2, c)) << 1;
return (r(i,c) % 2) << 1;
});

auto const m = as_indexes<wide<T, N>>(fixed_pattern);
Expand Down
6 changes: 3 additions & 3 deletions include/eve/module/core/regular/bit_shr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ namespace eve
template<typename Options>
struct bit_shr_t : strict_elementwise_callable<bit_shr_t, Options>
{
template<eve::value T, integral_value S>
constexpr EVE_FORCEINLINE as_wide_as_t<T, S> operator()(T v, S s) const
template<integral_value T, integral_value S>
constexpr EVE_FORCEINLINE as_wide_as_t<T, S> operator()(T v, S s) const
{
return EVE_DISPATCH_CALL(v, s);
}

template<eve::integral_value T, std::ptrdiff_t S>
template<integral_value T, std::ptrdiff_t S>
constexpr EVE_FORCEINLINE T operator()(T v, index_t<S> s) const
{
constexpr std::ptrdiff_t l = sizeof(element_type_t<T>) * 8;
Expand Down
45 changes: 38 additions & 7 deletions include/eve/module/core/regular/bit_swap_pairs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,34 @@
#include <eve/module/core/regular/bit_cast.hpp>
#include <eve/module/core/regular/bit_xor.hpp>
#include <eve/module/core/regular/bit_shl.hpp>
#include <eve/module/core/regular/convert.hpp>
#include <eve/module/core/constant/one.hpp>
#include <eve/traits/max_lanes.hpp>

namespace eve
{

template<typename Options>
struct bit_swap_pairs_t : strict_elementwise_callable<bit_swap_pairs_t, Options>
{
template<eve::integral_value T, integral_value I0, integral_value I1>
constexpr EVE_FORCEINLINE T operator()(T v, I0 i0, I1 i1) const noexcept
template<typename T, typename I0, typename I1>
struct result
{
using type = std::conditional_t<scalar_value<T> && scalar_value<I0> && scalar_value<I1>, T, as_wide_t<T, max_lanes_t<T, I0, I1>>>;
};

template<integral_value T, integral_value I0, integral_value I1>
EVE_FORCEINLINE constexpr typename result<T, I0, I1>::type operator()(T v, I0 i0, I1 i1) const noexcept
requires same_lanes_or_scalar<T, I0, I1>
{
return EVE_DISPATCH_CALL(v, i0, i1);
}

template<integral_value T, std::ptrdiff_t I0, std::ptrdiff_t I1>
EVE_FORCEINLINE constexpr T operator()(T a, index_t<I0> i0, index_t<I1> i1) const noexcept
{
return EVE_DISPATCH_CALL(a, i0, i1);
}

EVE_CALLABLE_OBJECT(bit_swap_pairs_t, bit_swap_pairs_);
};
Expand Down Expand Up @@ -83,16 +98,24 @@ namespace eve
namespace detail
{
template<callable_options O, conditional_expr C, value T, integral_value I0, integral_value I1>
constexpr T bit_swap_pairs_(EVE_REQUIRES(cpu_), C const& cx, O const&, T a, I0 i0, I1 i1) noexcept
constexpr auto bit_swap_pairs_(EVE_REQUIRES(cpu_), C const& cx, O const&, T a, I0 i0, I1 i1) noexcept
{
auto i0m = if_else(cx, i0, zero);
auto i1m = if_else(cx, i1, zero);
if constexpr (scalar_value<T> && scalar_value<I0> && scalar_value<I1>)
{
return bit_swap_pairs(a, cx ? i0 : 0, cx ? i1 : 0);
}
else
{
using MC = max_lanes_t<T, I0, I1>;

return bit_swap_pairs(a, i0m, i1m);
auto i0m = if_else(cx, as_wide_t<I0, MC>{i0}, zero);
auto i1m = if_else(cx, as_wide_t<I1, MC>{i1}, zero);
return bit_swap_pairs(as_wide_t<T, MC>{a}, i0m, i1m);
}
}

template<callable_options O, value T, integral_value I0, integral_value I1>
constexpr T bit_swap_pairs_(EVE_REQUIRES(cpu_), O const&, T a, I0 i0, I1 i1) noexcept
constexpr auto bit_swap_pairs_(EVE_REQUIRES(cpu_), O const&, T a, I0 i0, I1 i1) noexcept
{
// 1 if the bits of a at i0 and i1 are different, 0 otherwise
auto x = bit_and(
Expand All @@ -106,5 +129,13 @@ namespace eve
// if the bits are different, swap them by toggling both
return bit_xor(a, bit_shl(x, i1), bit_shl(x, i0));
}

template<callable_options O, typename T, std::ptrdiff_t I0, std::ptrdiff_t I1>
EVE_FORCEINLINE T bit_swap_pairs_(EVE_REQUIRES(cpu_), O const& o, T x, index_t<I0>, index_t<I1>) noexcept
{
constexpr std::ptrdiff_t C = sizeof(element_type_t<T>) * 8;
static_assert((I0 >= 0) && (I1 >= 0) && (I0 < C) && (I1 < C), "some index(es) are out or range");
return bit_swap_pairs[o](x, I0, I1);
}
}
}
6 changes: 4 additions & 2 deletions include/eve/module/core/regular/byte_swap_pairs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ namespace eve
struct byte_swap_pairs_t : strict_elementwise_callable<byte_swap_pairs_t, Options>
{
template<integral_value T, std::ptrdiff_t I0, std::ptrdiff_t I1>
EVE_FORCEINLINE T operator()(T a, index_t<I0> const & i0, index_t<I1> const & i1) const noexcept
{ return EVE_DISPATCH_CALL(a, i0, i1); }
EVE_FORCEINLINE T operator()(T a, index_t<I0> i0, index_t<I1> i1) const noexcept
{
return EVE_DISPATCH_CALL(a, i0, i1);
}

EVE_CALLABLE_OBJECT(byte_swap_pairs_t, byte_swap_pairs_);
};
Expand Down
13 changes: 2 additions & 11 deletions include/eve/module/core/regular/impl/byte_swap_pairs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,7 @@
namespace eve::detail
{
template<typename T, std::ptrdiff_t I0, std::ptrdiff_t I1, callable_options O>
EVE_FORCEINLINE T byte_swap_pairs_(EVE_REQUIRES(cpu_),
O const &,
T x ,
index_t<I0> const & ,
index_t<I1> const &) noexcept
EVE_FORCEINLINE T byte_swap_pairs_(EVE_REQUIRES(cpu_), O const &, T x, index_t<I0>, index_t<I1>) noexcept
{
if constexpr(simd_value<T>)
{
Expand Down Expand Up @@ -54,12 +50,7 @@ namespace eve::detail

// Masked case
template<conditional_expr C, typename T, std::ptrdiff_t I0, std::ptrdiff_t I1, callable_options O>
EVE_FORCEINLINE T byte_swap_pairs_(EVE_REQUIRES(cpu_),
C const& cond,
O const &,
T t,
index_t<I0> const & i0,
index_t<I1> const & i1) noexcept
EVE_FORCEINLINE T byte_swap_pairs_(EVE_REQUIRES(cpu_), C const& cond, O const&, T t, index_t<I0> i0, index_t<I1> i1) noexcept
{
return mask_op(cond, eve::byte_swap_pairs, t, i0, i1);
}
Expand Down
28 changes: 10 additions & 18 deletions include/eve/module/core/regular/impl/swap_pairs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,16 @@

namespace eve::detail
{
template<value T, std::ptrdiff_t I0, std::ptrdiff_t I1>
EVE_FORCEINLINE T
swap_pairs_(EVE_SUPPORTS(cpu_), T x
, index_t<I0> const &
, index_t<I1> const & ) noexcept
template<callable_options O, simd_value T, std::ptrdiff_t I0, std::ptrdiff_t I1>
EVE_FORCEINLINE constexpr T swap_pairs_(EVE_REQUIRES(cpu_), O const&, T x, index_t<I0>, index_t<I1>) noexcept
{
[[maybe_unused]] constexpr std::ptrdiff_t C = scalar_value<T> ? 1 : cardinal_v<T>;
EVE_ASSERT((I0 >= 0) && (I1 >= 0) && (I0 < C) && (I1 < C), "some index(es) are out or range");
if constexpr(simd_value<T>)
{
auto p = [](auto i, auto){
return (i == I0) ? I1 :(i == I1 ? I0 : i) ;
};
return eve::shuffle(x, eve::as_pattern(p));
}
else if constexpr(scalar_value<T>)
{
return x;
}
constexpr std::ptrdiff_t C = cardinal_v<T>;
static_assert((I0 >= 0) && (I1 >= 0) && (I0 < C) && (I1 < C), "some index(es) are out or range");

auto p = [](auto i, auto){
return (i == I0) ? I1 :(i == I1 ? I0 : i) ;
};

return eve::shuffle(x, eve::as_pattern(p));
}
}
92 changes: 51 additions & 41 deletions include/eve/module/core/regular/swap_pairs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,47 +11,57 @@

namespace eve
{
//TODO DOC
//================================================================================================
//! @addtogroup core_bitops
//! @{
//! @var swap_pairs
//! @brief swap chosen pair of elements.
//!
//! @groupheader{Header file}
//!
//! @code
//! #include <eve/module/core.hpp>
//! @endcode
//!
//! @groupheader{Callable Signatures}
//!
//! @code
//! namespace eve
//! {
//! template<value T, std::ptrdiff_t I0, std::ptrdiff_t I1 >
//! T swap_pairs(T x, index_t<I0> const & i0, index_t<I1> const & i1);
//! @endcode
//!
//! **Parameters**
//!
//! * `x` : [argument](@ref eve::value).
//! * `i0` : first index
//! * `i1` : second index
//!
//! **Return value**
//!
//! Return x with element i0 and i1 swapped. Action on scalar is identity.
//! Assert if i0 or i1 are out of range.
//!
//! @groupheader{Example}
//!
//! @godbolt{doc/core/swap_pairs.cpp}
//================================================================================================
EVE_MAKE_CALLABLE(swap_pairs_, swap_pairs);
//================================================================================================
//! @}
//================================================================================================
template<typename Options>
struct swap_pairs_t : callable<swap_pairs_t, Options>
{
template<simd_value T, std::ptrdiff_t I0, std::ptrdiff_t I1>
EVE_FORCEINLINE T operator()(T x, index_t<I0> i0, index_t<I1> i1) const noexcept
{
return EVE_DISPATCH_CALL(x, i0, i1);
}

EVE_CALLABLE_OBJECT(swap_pairs_t, swap_pairs_);
};

//================================================================================================
//! @addtogroup core_bitops
//! @{
//! @var swap_pairs
//! @brief swap chosen pair of elements.
//!
//! @groupheader{Header file}
//!
//! @code
//! #include <eve/module/core.hpp>
//! @endcode
//!
//! @groupheader{Callable Signatures}
//!
//! @code
//! namespace eve
//! {
//! template<simd_value T, std::ptrdiff_t I0, std::ptrdiff_t I1 >
//! T swap_pairs(T x, index_t<I0> i0, index_t<I1> i1);
//! @endcode
//!
//! **Parameters**
//!
//! * `x` : [argument](@ref eve::simd_value).
//! * `i0` : first index
//! * `i1` : second index
//!
//! **Return value**
//!
//! Return x with element i0 and i1 swapped.
//!
//! @groupheader{Example}
//!
//! @godbolt{doc/core/swap_pairs.cpp}
//================================================================================================
inline constexpr auto swap_pairs = functor<swap_pairs_t>;
//================================================================================================
//! @}
//================================================================================================
}

#include <eve/module/core/regular/impl/swap_pairs.hpp>
51 changes: 51 additions & 0 deletions include/eve/traits/max_lanes.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
//==================================================================================================
/*
EVE - Expressive Vector Engine
Copyright : EVE Project Contributors
SPDX-License-Identifier: BSL-1.0
*/
//==================================================================================================
#pragma once

namespace eve
{
namespace detail
{
template<typename... Ts>
consteval auto compute_max_lanes()
{
std::ptrdiff_t cards[] = { cardinal_v<Ts>... };

auto max_card = cards[0];
for(auto c : cards) max_card = max_card < c ? c : max_card;

return max_card;
}
}

//================================================================================================
//! @addtogroup traits
//! @{
//! @var max_lanes
//!
//! @tparam Ts Types to process
//!
//! @brief A meta function for getting a maximum lane count of given wide or scalar types.
//! @}
//================================================================================================
template <typename... Ts>
inline constexpr auto max_lanes_v = detail::compute_max_lanes<Ts...>();

//================================================================================================
//! @addtogroup traits
//! @{
//! @var max_lanes
//!
//! @tparam Ts Types to process
//!
//! @brief The cardinal type of the maximum lane count of given wide or scalar types.
//! @}
//================================================================================================
template <typename... Ts>
using max_lanes_t = fixed<max_lanes_v<Ts...>>;
}
14 changes: 12 additions & 2 deletions test/doc/core/bit_swap_pairs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,19 @@ int main()
std::cout << std::showbase << std::hex;
std::cout << "<- wi0 = " << wi0 << "\n";
std::cout << "<- wi1 = " << wi1 << "\n";
std::cout << "<- wi2 = " << wi2 << "\n";
std::cout << "<- wi2 = " << wi2 << "\n\n";

std::cout << "-> bit_swap_pairs(wi0, wi1, wi2) = " << eve::bit_swap_pairs(wi0, wi1, wi2) << "\n";
std::cout << "-> bit_swap_pairs[ignore_last(2)](wi0, wi1, wi2) = " << eve::bit_swap_pairs[eve::ignore_last(2)](wi0, wi1, wi2) << "\n";
std::cout << "-> bit_swap_pairs[wi3 > 0](wi0, wi1, wi3) = " << eve::bit_swap_pairs[wi3 >= 0](wi0, wi1, wi3) << "\n";
std::cout << "-> bit_swap_pairs[wi3 > 0](wi0, wi1, wi3) = " << eve::bit_swap_pairs[wi3 >= 0](wi0, wi1, wi3) << "\n\n";

std::cout << "-> bit_swap_pairs(wi0, 3, 2) = " << eve::bit_swap_pairs(wi0, 3, 2) << "\n";
std::cout << "-> bit_swap_pairs[ignore_last(2)](wi0, 3, 2) = " << eve::bit_swap_pairs[eve::ignore_last(2)](wi0, 3, 2) << "\n";
std::cout << "-> bit_swap_pairs[wi3 > 0](wi0, 3, 2) = " << eve::bit_swap_pairs[wi3 >= 0](wi0, 3, 2) << "\n\n";

auto i3 = eve::index_t<3>{};
auto i2 = eve::index_t<2>{};
std::cout << "-> bit_swap_pairs(wi0, i3, i2) = " << eve::bit_swap_pairs(wi0, i3, i2) << "\n";
std::cout << "-> bit_swap_pairs[ignore_last(2)](wi0, i3, i2) = " << eve::bit_swap_pairs[eve::ignore_last(2)](wi0, i3, i2) << "\n";
std::cout << "-> bit_swap_pairs[wi3 > 0](wi0, i3, i2) = " << eve::bit_swap_pairs[wi3 >= 0](wi0, i3, i2) << "\n\n";
}
17 changes: 15 additions & 2 deletions test/unit/module/core/bit_swap_pairs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,21 @@ TTS_CASE_WITH("Check behavior of bit_swap_pairs(simd) on integral types",
using v_t = eve::element_type_t<T>;
using eve::bit_swap_pairs;

TTS_EQUAL(bit_swap_pairs(a0, 0u, 7u), tts::map([](auto e) -> v_t { return eve::bit_swap_pairs(e, 0u, 7u); }, a0)) << a0 << '\n';
TTS_EQUAL(eve::bit_swap_pairs[t](a0, 0u, 7u), eve::if_else(t, eve::bit_swap_pairs(a0, 0u, 7u), a0));
// full scalar
TTS_EQUAL(eve::bit_swap_pairs(0b01010101, 0, 7), 0b11010100);
TTS_EQUAL(eve::bit_swap_pairs(0b01010101, eve::index<0>, eve::index<7>), 0b11010100);

// scalar into wide
using wt = eve::wide<int, eve::fixed<4>>;
TTS_EQUAL(eve::bit_swap_pairs(0b01010101, wt{0}, wt{7}), wt{0b11010100});

// wide
TTS_EQUAL(bit_swap_pairs(a0, 0u, 7), tts::map([](auto e) -> v_t { return eve::bit_swap_pairs(e, 0, 7u); }, a0)) << a0 << '\n';
TTS_EQUAL(bit_swap_pairs(a0, eve::index<0>, eve::index<7>), tts::map([](auto e) -> v_t { return eve::bit_swap_pairs(e, eve::index<0>, eve::index<7>); }, a0)) << a0 << '\n';

// wide masked
TTS_EQUAL(eve::bit_swap_pairs[t](a0, 0u, 7), eve::if_else(t, eve::bit_swap_pairs(a0, 0, 7u), a0));
TTS_EQUAL(eve::bit_swap_pairs[t](a0, eve::index<0>, eve::index<7>), eve::if_else(t, eve::bit_swap_pairs(a0, eve::index<0>, eve::index<7>), a0));

eve::wide<int, typename T::cardinal_type> wn{[](auto i, auto) { return -i; }};
TTS_EQUAL(eve::bit_swap_pairs[wn > 0](a0, wn, 7), a0);
Expand Down
Loading

0 comments on commit 77f3c9a

Please sign in to comment.