Skip to content

Commit

Permalink
fix: fixed incorrect definition and use of rocprim specific type trai…
Browse files Browse the repository at this point in the history
…ts for 128-bit integers

This fixes an issue where, in certain situations, using the radix key codec on 128-bit integers would not compile due to an ambiguous template specialization.
  • Loading branch information
Naraenda committed Sep 5, 2024
1 parent 380871b commit 61ed896
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 153 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ Documentation for rocPRIM is available at
### Fixes

* Fixed an issue where `rocprim::partial_sort_copy` would yield a compile error if the input iterator is const.
* Fixed incorrect type traits for 128-bit signed and unsigned integers.
* Fixed compilation issue when `rocprim::radix_key_codec<...>` is specialized with a 128-bit integer.

### Deprecations

Expand Down
99 changes: 19 additions & 80 deletions rocprim/include/rocprim/thread/radix_key_codec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,33 +78,6 @@ struct radix_key_codec_integral<Key,
}
};

// Integral key codec chosen when `Key` is exactly `__uint128_t`.
//
// Encoding and decoding go through `::rocprim::detail::bit_cast` without
// modifying any bits of the key.
template<class Key, class BitKey>
struct radix_key_codec_integral<
    Key,
    BitKey,
    typename std::enable_if<std::is_same<Key, __uint128_t>::value>::type>
{
    typedef BitKey bit_key_type;

    // Converts a key into its bit-key representation.
    ROCPRIM_HOST_DEVICE static bit_key_type encode(Key key)
    {
        return ::rocprim::detail::bit_cast<bit_key_type>(key);
    }

    // Converts a bit-key back into the key type.
    ROCPRIM_HOST_DEVICE static Key decode(bit_key_type bit_key)
    {
        return ::rocprim::detail::bit_cast<Key>(bit_key);
    }

    // Returns the `length` bits of `bit_key` that begin at bit `start`.
    // `Descending` is not consulted in this body. The mask is built with
    // `1u << length`, so `length` must be smaller than the width of `unsigned int`.
    template<bool Descending>
    ROCPRIM_HOST_DEVICE static unsigned int
        extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length)
    {
        const unsigned int digit_mask   = (1u << length) - 1;
        const unsigned int shifted_bits = static_cast<unsigned int>(bit_key >> start);
        return shifted_bits & digit_mask;
    }
};

template<class Key, class BitKey>
struct radix_key_codec_integral<Key,
BitKey,
Expand Down Expand Up @@ -135,36 +108,6 @@ struct radix_key_codec_integral<Key,
}
};

// Integral key codec chosen when `Key` is exactly `__int128_t`.
//
// The bit-key is the key's bit pattern with its most significant bit flipped;
// decoding flips that bit back before converting to `Key`.
template<class Key, class BitKey>
struct radix_key_codec_integral<Key,
                                BitKey,
                                typename std::enable_if<std::is_same<Key, __int128_t>::value>::type>
{
    typedef BitKey bit_key_type;

    // Mask with only the most significant bit of `bit_key_type` set.
    static constexpr bit_key_type sign_bit = bit_key_type(1) << (sizeof(bit_key_type) * 8 - 1);

    // Converts a key into its bit-key representation (top bit toggled).
    ROCPRIM_HOST_DEVICE static bit_key_type encode(Key key)
    {
        return sign_bit ^ ::rocprim::detail::bit_cast<bit_key_type>(key);
    }

    // Undoes `encode`: toggles the top bit again, then converts to `Key`.
    ROCPRIM_HOST_DEVICE static Key decode(bit_key_type bit_key)
    {
        // Cast back to bit_key_type so the value handed to bit_cast has the
        // same type as after the original in-place `^=`.
        const bit_key_type restored = static_cast<bit_key_type>(bit_key ^ sign_bit);
        return ::rocprim::detail::bit_cast<Key>(restored);
    }

    // Returns the `length` bits of `bit_key` that begin at bit `start`.
    // `Descending` is not consulted in this body. The mask is built with
    // `1u << length`, so `length` must be smaller than the width of `unsigned int`.
    template<bool Descending>
    ROCPRIM_HOST_DEVICE static unsigned int
        extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length)
    {
        const unsigned int digit_mask   = (1u << length) - 1;
        const unsigned int shifted_bits = static_cast<unsigned int>(bit_key >> start);
        return shifted_bits & digit_mask;
    }
};

template<class Key, class BitKey>
struct radix_key_codec_floating
{
Expand Down Expand Up @@ -222,19 +165,7 @@ struct radix_key_codec_base

template<class Key>
struct radix_key_codec_base<Key, typename std::enable_if<::rocprim::is_integral<Key>::value>::type>
: radix_key_codec_integral<Key, typename std::make_unsigned<Key>::type>
{};

// Codec base selected when `Key` is exactly `__int128_t`: it inherits the
// integral codec instantiated with `__uint128_t` as the bit-key type.
template<class Key>
struct radix_key_codec_base<Key,
typename std::enable_if<std::is_same<Key, __int128_t>::value>::type>
: radix_key_codec_integral<Key, __uint128_t>
{};

template<class Key>
struct radix_key_codec_base<Key,
typename std::enable_if<std::is_same<Key, __uint128_t>::value>::type>
: radix_key_codec_integral<Key, __uint128_t>
: radix_key_codec_integral<Key, typename ::rocprim::make_unsigned<Key>::type>
{};

template<>
Expand Down Expand Up @@ -279,21 +210,29 @@ template<>
struct radix_key_codec_base<double> : radix_key_codec_floating<double, unsigned long long>
{};

template<class T, class = void>
struct radix_key_fundamental
template<class T>
struct has_bit_key_type
{
static constexpr bool value = false;
template<class U>
static std::true_type check(typename U::bit_key_type*);

template<class U>
static std::false_type check(...);

using result = decltype(check<T>(nullptr));
};

template<class T>
struct radix_key_fundamental<
T,
::rocprim::detail::void_t<typename radix_key_codec_base<T>::bit_key_type>>
{
static constexpr bool value = true;
};
using radix_key_fundamental = typename has_bit_key_type<radix_key_codec_base<T>>::result;

} // end namespace detail
static_assert(radix_key_fundamental<int>::value, "'int' should be fundamental");
static_assert(!radix_key_fundamental<int*>::value, "'int*' should not be fundamental");
static_assert(radix_key_fundamental<__int128_t>::value, "'__int128_t' should be fundamental");
static_assert(radix_key_fundamental<__uint128_t>::value, "'__uint128_t' should be fundamental");
static_assert(!radix_key_fundamental<__int128_t*>::value,
"'__int128_t*' should not be fundamental");

} // namespace detail

/// \brief Key encoder, decoder and bit-extractor for radix-based sorts.
///
Expand Down
114 changes: 77 additions & 37 deletions rocprim/include/rocprim/type_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@

BEGIN_ROCPRIM_NAMESPACE

/// \brief Behaves like std::is_floating_point, but also includes half-precision and bfloat16-precision
/// floating point type (rocprim::half).
/// \brief Extension of `std::is_floating_point`, which includes support for \ref rocprim::half and \ref rocprim::bfloat16.
template<class T>
struct is_floating_point
: std::integral_constant<
Expand All @@ -47,60 +46,95 @@ struct is_floating_point
std::is_same<::rocprim::bfloat16, typename std::remove_cv<T>::type>::value
> {};

/// \brief Alias for std::is_integral.
/// \brief Extension of `std::is_integral`, which includes support for 128-bit integers.
template<class T>
using is_integral = std::is_integral<T>;
struct is_integral
: std::integral_constant<
bool,
std::is_integral<T>::value
|| std::is_same<__int128_t, typename std::remove_cv<T>::type>::value
|| std::is_same<__uint128_t, typename std::remove_cv<T>::type>::value>
{};

/// \brief Behaves like std::is_arithmetic, but also includes half-precision and bfloat16-precision
/// floating point type (\ref rocprim::half).
/// \brief Extension of `std::is_arithmetic`, which includes support for \ref rocprim::half , \ref rocprim::bfloat16 and 128-bit integers.
template<class T>
struct is_arithmetic
: std::integral_constant<
bool,
std::is_arithmetic<T>::value ||
std::is_same<::rocprim::half, typename std::remove_cv<T>::type>::value ||
std::is_same<::rocprim::bfloat16, typename std::remove_cv<T>::type>::value
> {};
bool,
std::is_arithmetic<T>::value
|| std::is_same<::rocprim::half, typename std::remove_cv<T>::type>::value
|| std::is_same<::rocprim::bfloat16, typename std::remove_cv<T>::type>::value
|| std::is_same<__int128_t, typename std::remove_cv<T>::type>::value
|| std::is_same<__uint128_t, typename std::remove_cv<T>::type>::value>
{};

/// \brief Behaves like std::is_fundamental, but also includes half-precision and bfloat16-precision
/// floating point type (\ref rocprim::half).
/// \brief Extension of `std::is_fundamental`, which includes support for \ref rocprim::half , \ref rocprim::bfloat16 and 128-bit integers.
template<class T>
struct is_fundamental
: std::integral_constant<
bool,
std::is_fundamental<T>::value ||
std::is_same<::rocprim::half, typename std::remove_cv<T>::type>::value ||
std::is_same<::rocprim::bfloat16, typename std::remove_cv<T>::type>::value
> {};
: std::integral_constant<
bool,
std::is_fundamental<T>::value
|| std::is_same<::rocprim::half, typename std::remove_cv<T>::type>::value
|| std::is_same<::rocprim::bfloat16, typename std::remove_cv<T>::type>::value
|| std::is_same<__int128_t, typename std::remove_cv<T>::type>::value
|| std::is_same<__uint128_t, typename std::remove_cv<T>::type>::value>
{};

/// \brief Alias for std::is_unsigned.
/// \brief Extension of `std::is_unsigned`, which includes support for 128-bit integers.
template<class T>
using is_unsigned = std::is_unsigned<T>;
struct is_unsigned
: std::integral_constant<
bool,
std::is_unsigned<T>::value
|| std::is_same<__uint128_t, typename std::remove_cv<T>::type>::value>
{};

/// \brief Behaves like std::is_signed, but also includes half-precision and bfloat16-precision
/// floating point type (\ref rocprim::half).
/// \brief Extension of `std::is_signed`, which includes support for \ref rocprim::half , \ref rocprim::bfloat16 and 128-bit integers.
template<class T>
struct is_signed
: std::integral_constant<
bool,
std::is_signed<T>::value ||
std::is_same<::rocprim::half, typename std::remove_cv<T>::type>::value ||
std::is_same<::rocprim::bfloat16, typename std::remove_cv<T>::type>::value
> {};
bool,
std::is_signed<T>::value
|| std::is_same<::rocprim::half, typename std::remove_cv<T>::type>::value
|| std::is_same<::rocprim::bfloat16, typename std::remove_cv<T>::type>::value
|| std::is_same<__int128_t, typename std::remove_cv<T>::type>::value>
{};

/// \brief Behaves like std::is_scalar, but also includes half-precision and bfloat16-precision
/// floating point type (\ref rocprim::half).
/// \brief Extension of `std::is_scalar`, which includes support for \ref rocprim::half , \ref rocprim::bfloat16 and 128-bit integers.
template<class T>
struct is_scalar
: std::integral_constant<
bool,
std::is_scalar<T>::value ||
std::is_same<::rocprim::half, typename std::remove_cv<T>::type>::value ||
std::is_same<::rocprim::bfloat16, typename std::remove_cv<T>::type>::value
> {};
bool,
std::is_scalar<T>::value
|| std::is_same<::rocprim::half, typename std::remove_cv<T>::type>::value
|| std::is_same<::rocprim::bfloat16, typename std::remove_cv<T>::type>::value
|| std::is_same<__int128_t, typename std::remove_cv<T>::type>::value
|| std::is_same<__uint128_t, typename std::remove_cv<T>::type>::value>
{};

/// \brief Extension of `std::make_unsigned`, which includes support for 128-bit integers.
///
/// For every type other than the 128-bit integers this defers to `std::make_unsigned<T>`.
/// `__int128_t` and `__uint128_t` map to `__uint128_t`; const/volatile qualifiers on the
/// input are carried over to the result, matching what `std::make_unsigned` does for the
/// standard integer types.
template<class T>
struct make_unsigned : std::make_unsigned<T>
{};

#ifndef DOXYGEN_SHOULD_SKIP_THIS // skip specialized versions
// The cv-qualified forms are specialized explicitly so that they never reach
// std::make_unsigned, which is not required to accept 128-bit integers.
template<>
struct make_unsigned<__int128_t>
{
    using type = __uint128_t;
};

template<>
struct make_unsigned<const __int128_t>
{
    using type = const __uint128_t;
};

template<>
struct make_unsigned<volatile __int128_t>
{
    using type = volatile __uint128_t;
};

template<>
struct make_unsigned<const volatile __int128_t>
{
    using type = const volatile __uint128_t;
};

template<>
struct make_unsigned<__uint128_t>
{
    using type = __uint128_t;
};

template<>
struct make_unsigned<const __uint128_t>
{
    using type = const __uint128_t;
};

template<>
struct make_unsigned<volatile __uint128_t>
{
    using type = volatile __uint128_t;
};

template<>
struct make_unsigned<const volatile __uint128_t>
{
    using type = const volatile __uint128_t;
};
#endif

/// \brief Behaves like std::is_compound, but also supports half-precision
/// floating point type (\ref rocprim::half). `value` for rocprim::half is `false`.
static_assert(std::is_same<make_unsigned<__int128_t>::type, __uint128_t>::value,
"'__int128_t' needs to implement 'make_unsigned' trait.");

/// \brief Extension of `std::is_compound`, which includes support for \ref rocprim::half , \ref rocprim::bfloat16 and 128-bit integers.
template<class T>
struct is_compound
: std::integral_constant<
Expand Down Expand Up @@ -143,6 +177,12 @@ struct get_unsigned_bits_type<T,8>
{
typedef uint64_t unsigned_type;
};

// Specialization for a size parameter of 16 (presumably a size in bytes, judging
// by the `<T,8>` -> `uint64_t` sibling): the unsigned type is `__uint128_t`.
template<typename T>
struct get_unsigned_bits_type<T, 16>
{
    using unsigned_type = __uint128_t;
};
#endif // DOXYGEN_SHOULD_SKIP_THIS

#ifndef DOXYGEN_SHOULD_SKIP_THIS
Expand Down
37 changes: 1 addition & 36 deletions test/rocprim/test_utils_data_generation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,38 +268,6 @@ inline OutputIter segmented_generate_n(OutputIter it, size_t size, Generator&& g
return it + size;
}

/// Fills `size` elements starting at `it` with random `__int128_t` values.
///
/// This overload participates in overload resolution only when the value type of
/// `OutputIter` is exactly `__int128_t`. Values are drawn from a
/// `std::uniform_int_distribution` whose bounds are `min` and `max` passed through
/// `test_utils::saturate_cast`, and each draw is converted to `__int128_t`.
///
/// \param it   output iterator to write to
/// \param size number of elements to generate
/// \param min  lower bound handed to the distribution
/// \param max  upper bound handed to the distribution
/// \param gen  random engine used by the distribution
/// \return whatever `segmented_generate_n` returns for the filled range
template<class OutputIter, class U, class V, class Generator>
inline auto generate_random_data_n(OutputIter it, size_t size, U min, V max, Generator&& gen)
    -> std::enable_if_t<std::is_same<it_value_t<OutputIter>, __int128_t>::value, OutputIter>
{
    using T = it_value_t<OutputIter>;

    // T is pinned to __int128_t by the enable_if above, so the narrow fallback has to be
    // signed. Do not derive that from std::is_signed<T>: in strict ISO mode a standard
    // library may not classify __int128 as arithmetic, std::is_signed then yields false,
    // and `unsigned int` would be picked, clamping negative bounds.
    using dis_type =
        typename std::conditional<is_valid_for_int_distribution<T>::value, T, int>::type;
    std::uniform_int_distribution<dis_type> distribution(test_utils::saturate_cast<dis_type>(min),
                                                         test_utils::saturate_cast<dis_type>(max));

    return segmented_generate_n(it, size, [&]() { return static_cast<T>(distribution(gen)); });
}

/// Fills `size` elements starting at `it` with random `__uint128_t` values.
///
/// This overload participates in overload resolution only when the value type of
/// `OutputIter` is exactly `__uint128_t`. Values are drawn from a
/// `std::uniform_int_distribution` whose bounds are `min` and `max` passed through
/// `test_utils::saturate_cast`, and each draw is converted to `__uint128_t`.
template<class OutputIter, class U, class V, class Generator>
inline auto generate_random_data_n(OutputIter it, size_t size, U min, V max, Generator&& gen)
    -> std::enable_if_t<std::is_same<it_value_t<OutputIter>, __uint128_t>::value, OutputIter>
{
    using T = it_value_t<OutputIter>;

    // Narrow stand-in used when T itself cannot parameterize the distribution.
    using narrow_type = std::conditional_t<std::is_signed<T>::value, int, unsigned int>;
    using dis_type = std::conditional_t<is_valid_for_int_distribution<T>::value, T, narrow_type>;

    std::uniform_int_distribution<dis_type> distribution(test_utils::saturate_cast<dis_type>(min),
                                                         test_utils::saturate_cast<dis_type>(max));

    return segmented_generate_n(it,
                                size,
                                [&]() -> T { return static_cast<T>(distribution(gen)); });
}

template<class OutputIter, class U, class V, class Generator>
inline auto generate_random_data_n(OutputIter it, size_t size, U min, V max, Generator&& gen)
-> std::enable_if_t<rocprim::is_integral<it_value_t<OutputIter>>::value, OutputIter>
Expand All @@ -309,10 +277,7 @@ inline auto generate_random_data_n(OutputIter it, size_t size, U min, V max, Gen
using dis_type = typename std::conditional<
is_valid_for_int_distribution<T>::value,
T,
typename std::conditional<std::is_signed<T>::value,
int,
unsigned int>::type
>::type;
typename std::conditional<rocprim::is_signed<T>::value, int, unsigned int>::type>::type;
std::uniform_int_distribution<dis_type> distribution(test_utils::saturate_cast<dis_type>(min),
test_utils::saturate_cast<dis_type>(max));

Expand Down

0 comments on commit 61ed896

Please sign in to comment.