diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c206dbe0..1d0573522 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,10 +5,11 @@ Full documentation for rocPRIM is available at [https://rocprim.readthedocs.io/e ## [Unreleased rocPRIM-3.0.0 for ROCm 6.0.0] ### Added ### Changed - - Removed deprecated functionality: `reduce_by_key_config`, `MatchAny`, `scan_config`, `scan_by_key_config` and `radix_sort_config`. - - Renamed `scan_config_v2` to `scan_config`, `scan_by_key_config_v2` to `scan_by_key_config`, `radix_sort_config_v2` to `radix_sort_config`, `reduce_by_key_config_v2` to `reduce_by_key_config`, `radix_sort_config_v2` to `radix_sort_config`. - - Removed support for custom config types for device algorithms. +- Removed deprecated functionality: `reduce_by_key_config`, `MatchAny`, `scan_config`, `scan_by_key_config` and `radix_sort_config`. +- Renamed `scan_config_v2` to `scan_config`, `scan_by_key_config_v2` to `scan_by_key_config`, `radix_sort_config_v2` to `radix_sort_config`, `reduce_by_key_config_v2` to `reduce_by_key_config`, `radix_sort_config_v2` to `radix_sort_config`. +- Removed support for custom config types for device algorithms. - `host_warp_size()` was moved into `rocprim/device/config_types.hpp`, and now uses either a `device_id` or a `stream` parameter to query the proper device and a `device_id` out parameter. The return type is `hipError_t`. +- Added support for __int128_t in `device_radix_sort` and `block_radix_sort`. ### Fixed ## [Unreleased rocPRIM-2.13.1 for ROCm 5.7.0] diff --git a/rocprim/include/rocprim/detail/radix_sort.hpp b/rocprim/include/rocprim/detail/radix_sort.hpp index 66ba2e356..32ff17e87 100644 --- a/rocprim/include/rocprim/detail/radix_sort.hpp +++ b/rocprim/include/rocprim/detail/radix_sort.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -67,6 +67,33 @@ struct radix_key_codec_integral +struct radix_key_codec_integral< + Key, + BitKey, + typename std::enable_if::value>::type> +{ + using bit_key_type = BitKey; + + ROCPRIM_DEVICE ROCPRIM_INLINE static bit_key_type encode(Key key) + { + return __builtin_bit_cast(bit_key_type, key); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE static Key decode(bit_key_type bit_key) + { + return __builtin_bit_cast(Key, bit_key); + } + + template + ROCPRIM_DEVICE static unsigned int + extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) + { + unsigned int mask = (1u << length) - 1; + return static_cast(bit_key >> start) & mask; + } +}; + template struct radix_key_codec_integral::value>::type> { @@ -97,6 +124,36 @@ struct radix_key_codec_integral +struct radix_key_codec_integral::value>::type> +{ + using bit_key_type = BitKey; + + static constexpr bit_key_type sign_bit = bit_key_type(1) << (sizeof(bit_key_type) * 8 - 1); + + ROCPRIM_DEVICE ROCPRIM_INLINE static bit_key_type encode(Key key) + { + const bit_key_type bit_key = __builtin_bit_cast(bit_key_type, key); + return sign_bit ^ bit_key; + } + + ROCPRIM_DEVICE ROCPRIM_INLINE static Key decode(bit_key_type bit_key) + { + bit_key ^= sign_bit; + return __builtin_bit_cast(Key, bit_key); + } + + template + ROCPRIM_DEVICE static unsigned int + extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) + { + unsigned int mask = (1u << length) - 1; + return static_cast(bit_key >> start) & mask; + } +}; + template struct float_bit_mask; @@ -199,6 +256,18 @@ struct radix_key_codec_base< typename std::enable_if<::rocprim::is_integral::value>::type > : radix_key_codec_integral::type> { }; +template +struct radix_key_codec_base::value>::type> + : radix_key_codec_integral +{}; + +template +struct radix_key_codec_base::value>::type> + : radix_key_codec_integral +{}; + template<> struct radix_key_codec_base { diff --git a/test/rocprim/test_block_radix_sort.cpp b/test/rocprim/test_block_radix_sort.cpp index f903cee50..bcc032b2b 100644 --- a/test/rocprim/test_block_radix_sort.cpp +++ b/test/rocprim/test_block_radix_sort.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -40,7 +40,7 @@ struct RocprimBlockRadixSort; struct Integral; #define suite_name RocprimBlockRadixSort -#define warp_params BlockParamsIntegral +#define warp_params BlockParamsIntegralExtended #define name_suffix Integral #include "test_block_radix_sort.hpp" diff --git a/test/rocprim/test_device_radix_sort.cpp.in b/test/rocprim/test_device_radix_sort.cpp.in index bd37a9fb5..b56e85889 100644 --- a/test/rocprim/test_device_radix_sort.cpp.in +++ b/test/rocprim/test_device_radix_sort.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -49,6 +49,10 @@ #endif #if ROCPRIM_TEST_TYPE_SLICE == 0 +#if defined(__GNUC__) || defined(__clang__) + INSTANTIATE(params<__int128_t, __int128_t>) + INSTANTIATE(params<__uint128_t, __uint128_t>) +#endif INSTANTIATE(params) INSTANTIATE(params) INSTANTIATE(params) diff --git a/test/rocprim/test_utils_assertions.hpp b/test/rocprim/test_utils_assertions.hpp index cb5e713c3..3e9ff4e93 100644 --- a/test/rocprim/test_utils_assertions.hpp +++ b/test/rocprim/test_utils_assertions.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2021-2023 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -240,6 +240,81 @@ void assert_bit_eq(const std::vector& result, const std::vector& expected) } } +#if defined(__GNUC__) || defined(__clang__) +inline void assert_bit_eq(const std::vector<__int128_t>& result, + const std::vector<__int128_t>& expected) +{ + ASSERT_EQ(result.size(), expected.size()); + + auto to_string = [](__int128_t value) + { + static const char* charmap = "0123456789"; + + std::string result; + result.reserve(41); // max. 40 digits possible ( uint64_t has 20) plus sign + __uint128_t helper = (value < 0) ? -value : value; + + do + { + result += charmap[helper % 10]; + helper /= 10; + } + while(helper); + if(value < 0) + { + result += "-"; + } + std::reverse(result.begin(), result.end()); + return result; + }; + + for(size_t i = 0; i < result.size(); i++) + { + if(!bit_equal(result[i], expected[i])) + { + FAIL() << "Expected strict/bitwise equality of these values: " << std::endl + << " result[i]: " << to_string(result[i]) << std::endl + << " expected[i]: " << to_string(expected[i]) << std::endl + << "where index = " << i; + } + } +} + +inline void assert_bit_eq(const std::vector<__uint128_t>& result, + const std::vector<__uint128_t>& expected) +{ + ASSERT_EQ(result.size(), expected.size()); + + auto to_string = [](__uint128_t value) + { + static const char* charmap = "0123456789"; + + std::string result; + result.reserve(40); // max. 40 digits possible ( uint64_t has 20) + __uint128_t helper = value; + + do + { + result += charmap[helper % 10]; + helper /= 10; + } + while(helper); + std::reverse(result.begin(), result.end()); + return result; + }; + + for(size_t i = 0; i < result.size(); i++) + { + if(!bit_equal(result[i], expected[i])) + { + FAIL() << "Expected strict/bitwise equality of these values: " << std::endl + << " result[i]: " << to_string(result[i]) << std::endl + << " expected[i]: " << to_string(expected[i]) << std::endl + << "where index = " << i; + } + } +} +#endif } #endif //ROCPRIM_TEST_UTILS_ASSERTIONS_HPP diff --git a/test/rocprim/test_utils_data_generation.hpp b/test/rocprim/test_utils_data_generation.hpp index 4c8881f12..00621e48c 100644 --- a/test/rocprim/test_utils_data_generation.hpp +++ b/test/rocprim/test_utils_data_generation.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2021-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2021-2023 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -159,6 +159,90 @@ void add_special_values(std::vector& source, seed_type seed_value) } } +template +inline auto get_random_data(size_t size, U min, V max, seed_type seed_value) -> + typename std::enable_if::value, std::vector>::type +{ + engine_type gen{seed_value}; + using dis_type = typename std::conditional< + is_valid_for_int_distribution::value, + T, + typename std::conditional::value, int, unsigned int>::type>::type; + std::uniform_int_distribution distribution(static_cast(min), + static_cast(max)); + std::vector data(size); + size_t segment_size = size / random_data_generation_segments; + if(segment_size != 0) + { + for(uint32_t segment_index = 0; segment_index < random_data_generation_segments; + segment_index++) + { + if(segment_index % random_data_generation_repeat_strides == 0) + { + T repeated_value = static_cast(distribution(gen)); + std::fill(data.begin() + segment_size * segment_index, + data.begin() + segment_size * (segment_index + 1), + repeated_value); + } + else + { + std::generate(data.begin() + segment_size * segment_index, + data.begin() + segment_size * (segment_index + 1), + [&]() { return static_cast(distribution(gen)); }); + } + } + } + else + { + std::generate(data.begin(), + data.end(), + [&]() { return static_cast(distribution(gen)); }); + } + return data; +} + +template +inline auto get_random_data(size_t size, U min, V max, seed_type seed_value) -> + typename std::enable_if::value, std::vector>::type +{ + engine_type gen{seed_value}; + using dis_type = typename std::conditional< + is_valid_for_int_distribution::value, + T, + typename std::conditional::value, int, unsigned int>::type>::type; + std::uniform_int_distribution distribution(static_cast(min), + static_cast(max)); + std::vector data(size); + size_t segment_size = size / random_data_generation_segments; + if(segment_size != 0) + { + for(uint32_t segment_index = 0; segment_index < random_data_generation_segments; + segment_index++) + { + if(segment_index % random_data_generation_repeat_strides == 0) + { + T repeated_value = static_cast(distribution(gen)); + std::fill(data.begin() + segment_size * segment_index, + data.begin() + segment_size * (segment_index + 1), + repeated_value); + } + else + { + std::generate(data.begin() + segment_size * segment_index, + data.begin() + segment_size * (segment_index + 1), + [&]() { return static_cast(distribution(gen)); }); + } + } + } + else + { + std::generate(data.begin(), + data.end(), + [&]() { return static_cast(distribution(gen)); }); + } + return data; +} + template inline auto get_random_data(size_t size, U min, V max, seed_type seed_value) -> typename std::enable_if::value, std::vector>::type diff --git a/test/rocprim/test_utils_sort_comparator.hpp b/test/rocprim/test_utils_sort_comparator.hpp index 71eed3cc7..fa7be26ce 100644 --- a/test/rocprim/test_utils_sort_comparator.hpp +++ b/test/rocprim/test_utils_sort_comparator.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -65,6 +65,46 @@ struct key_comparator +struct key_comparator::value>::type> +{ + static constexpr Key radix_mask_upper + = EndBit == 8 * sizeof(Key) ? ~Key(0) : (Key(1) << EndBit) - 1; + static constexpr Key radix_mask_bottom = (Key(1) << StartBit) - 1; + static constexpr Key radix_mask = radix_mask_upper ^ radix_mask_bottom; + + bool operator()(const Key& lhs, const Key& rhs) const + { + Key l = lhs & radix_mask; + Key r = rhs & radix_mask; + return Descending ? (r < l) : (l < r); + } +}; + +template +struct key_comparator::value>::type> +{ + static constexpr Key radix_mask_upper + = EndBit == 8 * sizeof(Key) ? ~Key(0) : (Key(1) << EndBit) - 1; + static constexpr Key radix_mask_bottom = (Key(1) << StartBit) - 1; + static constexpr Key radix_mask = radix_mask_upper ^ radix_mask_bottom; + + bool operator()(const Key& lhs, const Key& rhs) const + { + Key l = lhs & radix_mask; + Key r = rhs & radix_mask; + return Descending ? (r < l) : (l < r); + } +}; + template struct key_comparator BlockParamsIntegral; +typedef ::testing::Types), + block_param_type(uint8_t, short), + block_param_type(int8_t, float), + block_param_type(__uint128_t, short), + block_param_type(__int128_t, float)> + BlockParamsIntegralExtended; + typedef ::testing::Types< block_param_type(float, long), block_param_type(double, test_utils::custom_test_type),