From 61ed89605ba835db619c910df8b44160024a321e Mon Sep 17 00:00:00 2001 From: Nara Prasetya Date: Tue, 3 Sep 2024 15:03:47 +0000 Subject: [PATCH] fix: fixed incorrect definition and use of rocprim specific type traits for 128-bit integers This fixes an issue where in certain situations where using radix codec on 128-bit integers would not compile due to ambiguity. --- CHANGELOG.md | 2 + .../rocprim/thread/radix_key_codec.hpp | 99 +++------------ rocprim/include/rocprim/type_traits.hpp | 114 ++++++++++++------ test/rocprim/test_utils_data_generation.hpp | 37 +----- 4 files changed, 99 insertions(+), 153 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b311618d..ff6db71e0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ Documentation for rocPRIM is available at ### Fixes * Fixed an issue where `rocprim::partial_sort_copy` would yield a compile error if the input iterator is const. +* Fixed incorrect 128-bit signed and unsigned integers type traits. +* Fixed compilation issue when `rocprim::radix_key_codec<...>` is specialized with a 128-bit integer. ### Deprecations diff --git a/rocprim/include/rocprim/thread/radix_key_codec.hpp b/rocprim/include/rocprim/thread/radix_key_codec.hpp index 78a05b0fa..c83767890 100644 --- a/rocprim/include/rocprim/thread/radix_key_codec.hpp +++ b/rocprim/include/rocprim/thread/radix_key_codec.hpp @@ -78,33 +78,6 @@ struct radix_key_codec_integral -struct radix_key_codec_integral< - Key, - BitKey, - typename std::enable_if::value>::type> -{ - using bit_key_type = BitKey; - - ROCPRIM_HOST_DEVICE static bit_key_type encode(Key key) - { - return ::rocprim::detail::bit_cast(key); - } - - ROCPRIM_HOST_DEVICE static Key decode(bit_key_type bit_key) - { - return ::rocprim::detail::bit_cast(bit_key); - } - - template - ROCPRIM_HOST_DEVICE static unsigned int - extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) - { - unsigned int mask = (1u << length) - 1; - return static_cast(bit_key >> start) & mask; - } -}; - template struct radix_key_codec_integral -struct radix_key_codec_integral::value>::type> -{ - using bit_key_type = BitKey; - - static constexpr bit_key_type sign_bit = bit_key_type(1) << (sizeof(bit_key_type) * 8 - 1); - - ROCPRIM_HOST_DEVICE static bit_key_type encode(Key key) - { - const auto bit_key = ::rocprim::detail::bit_cast(key); - return sign_bit ^ bit_key; - } - - ROCPRIM_HOST_DEVICE static Key decode(bit_key_type bit_key) - { - bit_key ^= sign_bit; - return ::rocprim::detail::bit_cast(bit_key); - } - - template - ROCPRIM_HOST_DEVICE static unsigned int - extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) - { - unsigned int mask = (1u << length) - 1; - return static_cast(bit_key >> start) & mask; - } -}; - template struct radix_key_codec_floating { @@ -222,19 +165,7 @@ struct radix_key_codec_base template struct radix_key_codec_base::value>::type> - : radix_key_codec_integral::type> -{}; - -template -struct radix_key_codec_base::value>::type> - : radix_key_codec_integral -{}; - -template -struct radix_key_codec_base::value>::type> - : radix_key_codec_integral + : radix_key_codec_integral::type> {}; template<> @@ -279,21 +210,29 @@ template<> struct radix_key_codec_base : radix_key_codec_floating {}; -template -struct radix_key_fundamental +template +struct has_bit_key_type { - static constexpr bool value = false; + template + static std::true_type check(typename U::bit_key_type*); + + template + static std::false_type check(...); + + using result = decltype(check(nullptr)); }; template -struct radix_key_fundamental< - T, - ::rocprim::detail::void_t::bit_key_type>> -{ - static constexpr bool value = true; -}; +using radix_key_fundamental = typename has_bit_key_type>::result; -} // end namespace detail +static_assert(radix_key_fundamental::value, "'int' should be fundamental"); +static_assert(!radix_key_fundamental::value, "'int*' should not be fundamental"); +static_assert(radix_key_fundamental<__int128_t>::value, "'__int128_t' should be fundamental"); +static_assert(radix_key_fundamental<__uint128_t>::value, "'__uint128_t' should be fundamental"); +static_assert(!radix_key_fundamental<__int128_t*>::value, + "'__int128_t*' should not be fundamental"); + +} // namespace detail /// \brief Key encoder, decoder and bit-extractor for radix-based sorts. /// diff --git a/rocprim/include/rocprim/type_traits.hpp b/rocprim/include/rocprim/type_traits.hpp index 9824ef05c..91bcb14dc 100644 --- a/rocprim/include/rocprim/type_traits.hpp +++ b/rocprim/include/rocprim/type_traits.hpp @@ -36,8 +36,7 @@ BEGIN_ROCPRIM_NAMESPACE -/// \brief Behaves like std::is_floating_point, but also includes half-precision and bfloat16-precision -/// floating point type (rocprim::half). +/// \brief Extension of `std::is_floating_point`, which includes support for \ref rocprim::half and \ref rocprim::bfloat16. template struct is_floating_point : std::integral_constant< @@ -47,60 +46,95 @@ struct is_floating_point std::is_same<::rocprim::bfloat16, typename std::remove_cv::type>::value > {}; -/// \brief Alias for std::is_integral. +/// \brief Extension of `std::is_integral`, which includes support for 128-bit integers. template -using is_integral = std::is_integral; +struct is_integral + : std::integral_constant< + bool, + std::is_integral::value + || std::is_same<__int128_t, typename std::remove_cv::type>::value + || std::is_same<__uint128_t, typename std::remove_cv::type>::value> +{}; -/// \brief Behaves like std::is_arithmetic, but also includes half-precision and bfloat16-precision -/// floating point type (\ref rocprim::half). +/// \brief Extension of `std::is_arithmetic`, which includes support for \ref rocprim::half , \ref rocprim::bfloat16 and 128-bit integers. template struct is_arithmetic : std::integral_constant< - bool, - std::is_arithmetic::value || - std::is_same<::rocprim::half, typename std::remove_cv::type>::value || - std::is_same<::rocprim::bfloat16, typename std::remove_cv::type>::value - > {}; + bool, + std::is_arithmetic::value + || std::is_same<::rocprim::half, typename std::remove_cv::type>::value + || std::is_same<::rocprim::bfloat16, typename std::remove_cv::type>::value + || std::is_same<__int128_t, typename std::remove_cv::type>::value + || std::is_same<__uint128_t, typename std::remove_cv::type>::value> +{}; -/// \brief Behaves like std::is_fundamental, but also includes half-precision and bfloat16-precision -/// floating point type (\ref rocprim::half). +/// \brief Extension of `std::is_fundamental`, which includes support for \ref rocprim::half , \ref rocprim::bfloat16 and 128-bit integers. template struct is_fundamental - : std::integral_constant< - bool, - std::is_fundamental::value || - std::is_same<::rocprim::half, typename std::remove_cv::type>::value || - std::is_same<::rocprim::bfloat16, typename std::remove_cv::type>::value -> {}; + : std::integral_constant< + bool, + std::is_fundamental::value + || std::is_same<::rocprim::half, typename std::remove_cv::type>::value + || std::is_same<::rocprim::bfloat16, typename std::remove_cv::type>::value + || std::is_same<__int128_t, typename std::remove_cv::type>::value + || std::is_same<__uint128_t, typename std::remove_cv::type>::value> +{}; -/// \brief Alias for std::is_unsigned. +/// \brief Extension of `std::is_unsigned`, which includes support for 128-bit integers. template -using is_unsigned = std::is_unsigned; +struct is_unsigned + : std::integral_constant< + bool, + std::is_unsigned::value + || std::is_same<__uint128_t, typename std::remove_cv::type>::value> +{}; -/// \brief Behaves like std::is_signed, but also includes half-precision and bfloat16-precision -/// floating point type (\ref rocprim::half). +/// \brief Extension of `std::is_signed`, which includes support for \ref rocprim::half , \ref rocprim::bfloat16 and 128-bit integers. template struct is_signed : std::integral_constant< - bool, - std::is_signed::value || - std::is_same<::rocprim::half, typename std::remove_cv::type>::value || - std::is_same<::rocprim::bfloat16, typename std::remove_cv::type>::value - > {}; + bool, + std::is_signed::value + || std::is_same<::rocprim::half, typename std::remove_cv::type>::value + || std::is_same<::rocprim::bfloat16, typename std::remove_cv::type>::value + || std::is_same<__int128_t, typename std::remove_cv::type>::value> +{}; -/// \brief Behaves like std::is_scalar, but also includes half-precision and bfloat16-precision -/// floating point type (\ref rocprim::half). +/// \brief Extension of `std::is_scalar`, which includes support for \ref rocprim::half , \ref rocprim::bfloat16 and 128-bit integers. template struct is_scalar : std::integral_constant< - bool, - std::is_scalar::value || - std::is_same<::rocprim::half, typename std::remove_cv::type>::value || - std::is_same<::rocprim::bfloat16, typename std::remove_cv::type>::value - > {}; + bool, + std::is_scalar::value + || std::is_same<::rocprim::half, typename std::remove_cv::type>::value + || std::is_same<::rocprim::bfloat16, typename std::remove_cv::type>::value + || std::is_same<__int128_t, typename std::remove_cv::type>::value + || std::is_same<__uint128_t, typename std::remove_cv::type>::value> +{}; + +/// \brief Extension of `std::make_unsigned`, which includes support for 128-bit integers. +template +struct make_unsigned : std::make_unsigned +{}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // skip specialized versions +template<> +struct make_unsigned<__int128_t> +{ + using type = __uint128_t; +}; + +template<> +struct make_unsigned<__uint128_t> +{ + using type = __uint128_t; +}; +#endif -/// \brief Behaves like std::is_compound, but also supports half-precision -/// floating point type (\ref rocprim::half). `value` for rocprim::half is `false`. +static_assert(std::is_same::type, __uint128_t>::value, + "'__int128_t' needs to implement 'make_unsigned' trait."); + +/// \brief Extension of `std::is_compound`, which includes support for \ref rocprim::half , \ref rocprim::bfloat16 and 128-bit integers. template struct is_compound : std::integral_constant< @@ -143,6 +177,12 @@ struct get_unsigned_bits_type { typedef uint64_t unsigned_type; }; + +template +struct get_unsigned_bits_type +{ + typedef __uint128_t unsigned_type; +}; #endif // DOXYGEN_SHOULD_SKIP_THIS #ifndef DOXYGEN_SHOULD_SKIP_THIS diff --git a/test/rocprim/test_utils_data_generation.hpp b/test/rocprim/test_utils_data_generation.hpp index ae618d021..d45d9acef 100644 --- a/test/rocprim/test_utils_data_generation.hpp +++ b/test/rocprim/test_utils_data_generation.hpp @@ -268,38 +268,6 @@ inline OutputIter segmented_generate_n(OutputIter it, size_t size, Generator&& g return it + size; } -template -inline auto generate_random_data_n(OutputIter it, size_t size, U min, V max, Generator&& gen) - -> std::enable_if_t, __int128_t>::value, OutputIter> -{ - using T = it_value_t; - - using dis_type = typename std::conditional< - is_valid_for_int_distribution::value, - T, - typename std::conditional::value, int, unsigned int>::type>::type; - std::uniform_int_distribution distribution(test_utils::saturate_cast(min), - test_utils::saturate_cast(max)); - - return segmented_generate_n(it, size, [&]() { return static_cast(distribution(gen)); }); -} - -template -inline auto generate_random_data_n(OutputIter it, size_t size, U min, V max, Generator&& gen) - -> std::enable_if_t, __uint128_t>::value, OutputIter> -{ - using T = it_value_t; - - using dis_type = typename std::conditional< - is_valid_for_int_distribution::value, - T, - typename std::conditional::value, int, unsigned int>::type>::type; - std::uniform_int_distribution distribution(test_utils::saturate_cast(min), - test_utils::saturate_cast(max)); - - return segmented_generate_n(it, size, [&]() { return static_cast(distribution(gen)); }); -} - template inline auto generate_random_data_n(OutputIter it, size_t size, U min, V max, Generator&& gen) -> std::enable_if_t>::value, OutputIter> @@ -309,10 +277,7 @@ inline auto generate_random_data_n(OutputIter it, size_t size, U min, V max, Gen using dis_type = typename std::conditional< is_valid_for_int_distribution::value, T, - typename std::conditional::value, - int, - unsigned int>::type - >::type; + typename std::conditional::value, int, unsigned int>::type>::type; std::uniform_int_distribution distribution(test_utils::saturate_cast(min), test_utils::saturate_cast(max));