From 608a043a44f1bed78210163499402ae4b9419050 Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Wed, 19 Jun 2024 13:59:13 +0000 Subject: [PATCH] Review adaptations --- benchmark/benchmark_device_partial_sort.hpp | 7 +- docs/reference/ops_summary.rst | 2 +- .../rocprim/device/device_nth_element.hpp | 80 ++-- .../rocprim/device/device_partial_sort.hpp | 141 ++++--- test/rocprim/test_device_partial_sort.cpp | 351 +++++++++++++----- 5 files changed, 393 insertions(+), 188 deletions(-) diff --git a/benchmark/benchmark_device_partial_sort.hpp b/benchmark/benchmark_device_partial_sort.hpp index d0dba07da..d16c9fb1c 100644 --- a/benchmark/benchmark_device_partial_sort.hpp +++ b/benchmark/benchmark_device_partial_sort.hpp @@ -101,10 +101,9 @@ struct device_partial_sort_benchmark : public config_autotune_interface size * sizeof(*d_keys_input), hipMemcpyHostToDevice)); - ::rocprim::less lesser_op; - - void* d_temporary_storage = nullptr; - size_t temporary_storage_bytes = 0; + rocprim::less lesser_op; + void* d_temporary_storage = nullptr; + size_t temporary_storage_bytes = 0; HIP_CHECK(rocprim::partial_sort_copy(d_temporary_storage, temporary_storage_bytes, d_keys_input, diff --git a/docs/reference/ops_summary.rst b/docs/reference/ops_summary.rst index c308eb6bd..9dbf13d68 100644 --- a/docs/reference/ops_summary.rst +++ b/docs/reference/ops_summary.rst @@ -32,7 +32,7 @@ Rearrangement ================ * ``sort`` rearranges the sequence by sorting it. It could be according to a comparison operator or a value using a radix approach -* ``partial_sort`` rearranges the sequence by sorting it up to and including the middle index, according to a comparison operator. +* ``partial_sort`` rearranges the sequence by sorting it up to and including a given index, according to a comparison operator. * ``nth_element`` places the nth element in its sorted position, with elements less-than before, and greater after, according to a comparison operator. * ``exchange`` rearranges the elements according to a different stride configuration which is equivalent to a tensor axis transposition * ``shuffle`` rotates the elements diff --git a/rocprim/include/rocprim/device/device_nth_element.hpp b/rocprim/include/rocprim/device/device_nth_element.hpp index f3090699c..6b387b594 100644 --- a/rocprim/include/rocprim/device/device_nth_element.hpp +++ b/rocprim/include/rocprim/device/device_nth_element.hpp @@ -43,28 +43,29 @@ namespace detail { template -ROCPRIM_INLINE hipError_t nth_element_impl( - void* temporary_storage, - size_t& storage_size, - KeysIterator keys, - size_t nth, - size_t size, - BinaryFunction compare_function, - hipStream_t stream, - bool debug_synchronous, - typename std::iterator_traits::value_type* keys_double_buffer = nullptr) +ROCPRIM_INLINE +hipError_t + nth_element_impl(void* temporary_storage, + size_t& storage_size, + KeysIterator keys, + size_t nth, + size_t size, + BinaryFunction compare_function, + hipStream_t stream, + bool debug_synchronous, + typename std::iterator_traits::value_type* keys_double_buffer + = nullptr) { using key_type = typename std::iterator_traits::value_type; using config = wrapped_nth_element_config; - detail::target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); + target_arch target_arch; + hipError_t result = host_target_arch(stream, target_arch); if(result != hipSuccess) { return result; } - const detail::nth_element_config_params params - = detail::dispatch_target_arch(target_arch); + const nth_element_config_params params = dispatch_target_arch(target_arch); constexpr unsigned int num_partitions = 3; const unsigned int num_buckets = params.number_of_buckets; @@ -73,19 +74,18 @@ ROCPRIM_INLINE hipError_t nth_element_impl( const unsigned int num_items_per_threads = params.kernel_config.items_per_thread; const unsigned int num_threads_per_block = params.kernel_config.block_size; const unsigned int num_items_per_block = num_threads_per_block * num_items_per_threads; - const unsigned int num_blocks = detail::ceiling_div(size, num_items_per_block); + const unsigned int num_blocks = ceiling_div(size, num_items_per_block); - // oracles stores the bucket that correlates with the index - key_type* tree = nullptr; - size_t* buckets = nullptr; - detail::n_th_element_iteration_data* nth_element_data = nullptr; - bool* equality_buckets = nullptr; - detail::nth_element_onesweep_lookback_state* lookback_states = nullptr; + key_type* tree = nullptr; + size_t* buckets = nullptr; + n_th_element_iteration_data* nth_element_data = nullptr; + bool* equality_buckets = nullptr; + nth_element_onesweep_lookback_state* lookback_states = nullptr; key_type* keys_buffer = nullptr; { - using namespace detail::temp_storage; + using namespace temp_storage; hipError_t partition_result; if(keys_double_buffer == nullptr) @@ -141,22 +141,22 @@ ROCPRIM_INLINE hipError_t nth_element_impl( std::cout << "storage_size: " << storage_size << '\n'; } - return detail::nth_element_keys_impl(keys, - keys_buffer, - tree, - nth, - size, - buckets, - equality_buckets, - lookback_states, - num_buckets, - stop_recursion_size, - num_threads_per_block, - num_items_per_threads, - nth_element_data, - compare_function, - stream, - debug_synchronous); + return nth_element_keys_impl(keys, + keys_buffer, + tree, + nth, + size, + buckets, + equality_buckets, + lookback_states, + num_buckets, + stop_recursion_size, + num_threads_per_block, + num_items_per_threads, + nth_element_data, + compare_function, + stream, + debug_synchronous); } } // namespace detail @@ -195,7 +195,7 @@ ROCPRIM_INLINE hipError_t nth_element_impl( /// The signature of the function should be equivalent to the following: /// bool f(const T &a, const T &b);. The signature does not need to have /// const &, but function object must not modify the objects passed to it. -/// The comperator must meet the C++ named requirement Compare. +/// The comparator must meet the C++ named requirement Compare. /// The default value is `BinaryFunction()`. /// \param [in] stream [optional] HIP stream object. Default is `0` (default stream). /// \param [in] debug_synchronous [optional] If true, synchronization after every kernel @@ -295,7 +295,7 @@ ROCPRIM_INLINE hipError_t nth_element(void* temporary_storage, /// The signature of the function should be equivalent to the following: /// bool f(const T &a, const T &b);. The signature does not need to have /// const &, but function object must not modify the objects passed to it. -/// The comperator must meet the C++ named requirement Compare. +/// The comparator must meet the C++ named requirement Compare. /// The default value is `BinaryFunction()`. /// \param [in] stream [optional] HIP stream object. Default is `0` (default stream). /// \param [in] debug_synchronous [optional] If true, synchronization after every kernel diff --git a/rocprim/include/rocprim/device/device_partial_sort.hpp b/rocprim/include/rocprim/device/device_partial_sort.hpp index cddc23aa1..e9bacae0d 100644 --- a/rocprim/include/rocprim/device/device_partial_sort.hpp +++ b/rocprim/include/rocprim/device/device_partial_sort.hpp @@ -58,10 +58,11 @@ namespace detail } \ while(0) -template +template hipError_t partial_sort_impl(void* temporary_storage, size_t& storage_size, - KeysIterator keys, + KeysIterator keys_in, + KeysIterator keys_out, size_t middle, size_t size, BinaryFunction compare_function, @@ -69,35 +70,42 @@ hipError_t partial_sort_impl(void* temporary_storage, bool debug_synchronous) { using key_type = typename std::iterator_traits::value_type; - using config - = detail::default_or_custom_config>; + using config = default_or_custom_config>; using config_merge_sort = typename config::merge_sort; using config_nth_element = typename config::nth_element; + if(size != 0 && middle >= size) + { + return hipErrorInvalidValue; + } + size_t storage_size_nth_element{}; // non-null placeholder so that no buffer is allocated for keys key_type* keys_buffer_placeholder = reinterpret_cast(1); - RETURN_ON_ERROR(nth_element_impl(nullptr, - storage_size_nth_element, - keys, - middle, - size, - compare_function, - stream, - debug_synchronous, - keys_buffer_placeholder)); - + const bool full_sort = middle + 1 == size; + if(!full_sort) + { + RETURN_ON_ERROR(nth_element_impl(nullptr, + storage_size_nth_element, + keys_in, + middle, + size, + compare_function, + stream, + debug_synchronous, + keys_buffer_placeholder)); + } size_t storage_size_merge_sort{}; RETURN_ON_ERROR( merge_sort_impl(nullptr, storage_size_merge_sort, - keys, - keys, + keys_in, + keys_out, static_cast(nullptr), // values_input static_cast(nullptr), // values_output - middle, + (!inplace || full_sort) ? middle + 1 : middle, compare_function, stream, debug_synchronous, @@ -107,12 +115,14 @@ hipError_t partial_sort_impl(void* temporary_storage, void* temporary_storage_nth_element = nullptr; void* temporary_storage_merge_sort = nullptr; key_type* keys_buffer = nullptr; + key_type* keys_output_nth_element = nullptr; const hipError_t partition_result = temp_storage::partition( temporary_storage, storage_size, temp_storage::make_linear_partition( temp_storage::ptr_aligned_array(&keys_buffer, size), + temp_storage::ptr_aligned_array(&keys_output_nth_element, inplace ? 0 : size), temp_storage::make_partition(&temporary_storage_nth_element, storage_size_nth_element), temp_storage::make_partition(&temporary_storage_merge_sort, storage_size_merge_sort))); @@ -126,36 +136,51 @@ hipError_t partial_sort_impl(void* temporary_storage, return hipSuccess; } - if(middle > size) + if(!inplace) { - return hipErrorInvalidValue; + RETURN_ON_ERROR(transform(keys_in, + keys_output_nth_element, + size, + rocprim::identity(), + stream, + debug_synchronous)); } - if(middle < size) + if(!full_sort) { - RETURN_ON_ERROR(nth_element_impl(temporary_storage_nth_element, - storage_size_nth_element, - keys, - middle, - size, - compare_function, - stream, - debug_synchronous, - keys_buffer)); + RETURN_ON_ERROR( + nth_element_impl(temporary_storage_nth_element, + storage_size_nth_element, + inplace ? keys_in : keys_output_nth_element, + middle, + size, + compare_function, + stream, + debug_synchronous, + keys_buffer)); } if(middle == 0) { + if(!inplace) + { + RETURN_ON_ERROR(transform(keys_output_nth_element, + keys_out, + 1, + rocprim::identity(), + stream, + debug_synchronous)); + } return hipSuccess; } return merge_sort_impl(temporary_storage_merge_sort, storage_size_merge_sort, - keys, - keys, + inplace ? keys_in : keys_output_nth_element, + keys_out, static_cast(nullptr), // values_input static_cast(nullptr), // values_output - middle, + (!inplace || full_sort) ? middle + 1 : middle, compare_function, stream, debug_synchronous, @@ -171,7 +196,7 @@ hipError_t partial_sort_impl(void* temporary_storage, /// * The contents of the inputs are not altered by the function. /// * Returns the required size of `temporary_storage` in `storage_size` /// if `temporary_storage` is a null pointer. -/// * Accepts custom compare_functions for nth_element across the device. +/// * Accepts custom compare_functions for partial_sort_copy across the device. /// * Streams in graph capture mode are not supported /// /// \tparam Config [optional] configuration of the primitive. It has to be `partial_sort_config`. @@ -184,18 +209,18 @@ hipError_t partial_sort_impl(void* temporary_storage, /// /// \param [in] temporary_storage pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to -/// `storage_size` and function returns without performing the nth_element rearrangement. +/// `storage_size` and function returns without performing the partial_sort_copy rearrangement. /// \param [in,out] storage_size reference to a size (in bytes) of `temporary_storage`. /// \param [in] keys_input iterator to the input range. /// \param [out] keys_output iterator to the output range. No overlap at all is allowed between `keys_input` and `keys_output`. -/// `keys_output` should be able to be written and read from for `size` elements. +/// `keys_output` should be able to be write to at least `middle` + 1 elements. /// \param [in] middle The index of the point till where it is sorted in the input range. /// \param [in] size number of element in the input range. /// \param [in] compare_function binary operation function object that will be used for comparison. /// The signature of the function should be equivalent to the following: /// bool f(const T &a, const T &b);. The signature does not need to have /// const &, but function object must not modify the objects passed to it. -/// The comperator must meet the C++ named requirement Compare. +/// The comparator must meet the C++ named requirement Compare. /// The default value is `BinaryFunction()`. /// \param [in] stream [optional] HIP stream object. Default is `0` (default stream). /// \param [in] debug_synchronous [optional] If true, synchronization after every kernel @@ -206,7 +231,7 @@ hipError_t partial_sort_impl(void* temporary_storage, /// /// \par Example /// \parblock -/// In this example a device-level nth_element is performed where input keys are +/// In this example a device-level partial_sort_copy is performed where input keys are /// represented by an array of unsigned integers. /// /// \code{.cpp} @@ -216,7 +241,7 @@ hipError_t partial_sort_impl(void* temporary_storage, /// size_t input_size; // e.g., 8 /// size_t middle; // e.g., 4 /// unsigned int * keys_input; // e.g., [ 6, 3, 5, 4, 1, 8, 2, 7 ] -/// unsigned int * keys_output; // empty array of 8 elements +/// unsigned int * keys_output; // e.g., [ 9, 9, 9, 9, 9, 9, 9, 9 ] /// /// size_t temporary_storage_size_bytes; /// void * temporary_storage_ptr = nullptr; @@ -234,7 +259,7 @@ hipError_t partial_sort_impl(void* temporary_storage, /// temporary_storage_ptr, temporary_storage_size_bytes, /// keys_input, keys_output, middle, input_size /// ); -/// // possible keys_output: [ 1, 2, 3, 4, 5, 8, 7, 6 ] +/// // possible keys_output: [ 1, 2, 3, 4, 5, 9, 9, 9 ] /// \endcode /// \endparblock template::value_type>::value, "KeysInputIterator and KeysOutputIterator must have the same value_type"); - RETURN_ON_ERROR(transform(keys_input, - keys_output, - size, - ::rocprim::identity(), - stream, - debug_synchronous)); - - return detail::partial_sort_impl(temporary_storage, - storage_size, - keys_output, - middle, - size, - compare_function, - stream, - debug_synchronous); + return detail::partial_sort_impl( + temporary_storage, + storage_size, + keys_input, + keys_output, + middle, + size, + compare_function, + stream, + debug_synchronous); } /// \brief Rearranges elements such that the range [0, middle) contains the sorted middle smallest elements in the range [0, size). @@ -281,7 +301,7 @@ hipError_t partial_sort_copy(void* temporary_storage, /// * The contents of the inputs are not altered by the function. /// * Returns the required size of `temporary_storage` in `storage_size` /// if `temporary_storage` is a null pointer. -/// * Accepts custom compare_functions for nth_element across the device. +/// * Accepts custom compare_functions for partial_sort across the device. /// * Streams in graph capture mode are not supported /// /// \tparam Config [optional] configuration of the primitive. It has to be `partial_sort_config`. @@ -292,7 +312,7 @@ hipError_t partial_sort_copy(void* temporary_storage, /// /// \param [in] temporary_storage pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to -/// `storage_size` and function returns without performing the nth_element rearrangement. +/// `storage_size` and function returns without performing the partial_sort rearrangement. /// \param [in,out] storage_size reference to a size (in bytes) of `temporary_storage`. /// \param [in,out] keys iterator to the input range. /// \param [in] middle The index of the point till where it is sorted in the input range. @@ -301,7 +321,7 @@ hipError_t partial_sort_copy(void* temporary_storage, /// The signature of the function should be equivalent to the following: /// bool f(const T &a, const T &b);. The signature does not need to have /// const &, but function object must not modify the objects passed to it. -/// The comperator must meet the C++ named requirement Compare. +/// The comparator must meet the C++ named requirement Compare. /// The default value is `BinaryFunction()`. /// \param [in] stream [optional] HIP stream object. Default is `0` (default stream). /// \param [in] debug_synchronous [optional] If true, synchronization after every kernel @@ -312,7 +332,7 @@ hipError_t partial_sort_copy(void* temporary_storage, /// /// \par Example /// \parblock -/// In this example a device-level nth_element is performed where input keys are +/// In this example a device-level partial_sort is performed where input keys are /// represented by an array of unsigned integers. /// /// \code{.cpp} @@ -328,7 +348,7 @@ hipError_t partial_sort_copy(void* temporary_storage, /// // Get required size of the temporary storage /// rocprim::partial_sort( /// temporary_storage_ptr, temporary_storage_size_bytes, -/// keys, nth, input_size +/// keys, middle, input_size /// ); /// /// // allocate temporary storage @@ -337,7 +357,7 @@ hipError_t partial_sort_copy(void* temporary_storage, /// // perform partial_sort /// rocprim::partial_sort( /// temporary_storage_ptr, temporary_storage_size_bytes, -/// keys, nth, input_size +/// keys, middle, input_size /// ); /// // possible keys: [ 1, 2, 3, 4, 5, 8, 7, 6 ] /// \endcode @@ -358,6 +378,7 @@ hipError_t partial_sort(void* temporary_storage, return detail::partial_sort_impl(temporary_storage, storage_size, keys, + keys, middle, size, compare_function, diff --git a/test/rocprim/test_device_partial_sort.cpp b/test/rocprim/test_device_partial_sort.cpp index 3cf701108..077d1fc59 100644 --- a/test/rocprim/test_device_partial_sort.cpp +++ b/test/rocprim/test_device_partial_sort.cpp @@ -31,6 +31,7 @@ #include "../common_test_header.hpp" // required rocprim headers +#include #include #include #include @@ -96,20 +97,25 @@ using RocprimDevicePartialSortTestsParams = ::testing::Types< nth_element_config<128, 4, 32, 16, rocprim::block_radix_rank_algorithm::basic>>>>; template -void inline compare_cpp_14(InputVector input, - OutputVector output, - size_t middle, - CompareFunction compare_op) +void inline compare_partial_sort_cpp_14(InputVector input, + OutputVector output, + size_t middle, + CompareFunction compare_op) { using key_type = typename InputVector::value_type; + if(input.size() == 0) + { + return; + } + // Calculate sorted input results on host std::vector sorted_input(input); std::sort(sorted_input.begin(), sorted_input.end(), compare_op); // Calculate sorted output results on host std::vector sorted_output(output); - std::sort(sorted_output.begin() + middle, sorted_output.end(), compare_op); + std::sort(sorted_output.begin() + middle + 1, sorted_output.end(), compare_op); // Check if the values are the same ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(sorted_output, sorted_input)); @@ -117,24 +123,29 @@ void inline compare_cpp_14(InputVector input, #if CPP17 template -void inline compare_cpp_17(InputVector input, - OutputVector output, - size_t middle, - CompareFunction compare_op) +void inline compare_partial_sort_cpp_17(InputVector input, + OutputVector output, + size_t middle, + CompareFunction compare_op) { using key_type = typename InputVector::value_type; + if(input.size() == 0) + { + return; + } + // Calculate sorted input results on host std::vector sorted_input(input); std::partial_sort(sorted_input.begin(), - sorted_input.begin() + middle, + sorted_input.begin() + middle + 1, sorted_input.end(), compare_op); - std::sort(sorted_input.begin() + middle, sorted_input.end(), compare_op); + std::sort(sorted_input.begin() + middle + 1, sorted_input.end(), compare_op); // Calculate sorted output results on host std::vector sorted_output(output); - std::sort(sorted_output.begin() + middle, sorted_output.end(), compare_op); + std::sort(sorted_output.begin() + middle + 1, sorted_output.end(), compare_op); // Check if the values are the same ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(sorted_output, sorted_input)); @@ -142,15 +153,15 @@ void inline compare_cpp_17(InputVector input, #endif template -void inline compare(InputVector input, - OutputVector output, - size_t middle, - CompareFunction compare_op) +void inline compare_partial_sort(InputVector input, + OutputVector output, + size_t middle, + CompareFunction compare_op) { - compare_cpp_14(input, output, middle, compare_op); + compare_partial_sort_cpp_14(input, output, middle, compare_op); #if CPP17 // this comparison is only compiled and executed if c++17 is available - compare_cpp_17(input, output, middle, compare_op); + compare_partial_sort_cpp_17(input, output, middle, compare_op); #else ROCPRIM_PRAGMA_MESSAGE("c++17 not available skips direct comparison with std::partial_sort"); #endif @@ -170,8 +181,6 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) const bool debug_synchronous = TestFixture::debug_synchronous; static constexpr bool use_indirect_iterator = TestFixture::use_indirect_iterator; - bool in_place = false; - for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; ++seed_index) { unsigned int seed_value @@ -183,12 +192,10 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) SCOPED_TRACE(testing::Message() << "with size = " << size); std::vector middles = {0}; - if(size > 0) - { - middles.push_back(size); - } + if(size > 1) { + middles.push_back(size - 1); middles.push_back(test_utils::get_random_value(1, size - 1, seed_value)); } @@ -223,15 +230,7 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) hipMemcpyHostToDevice)); key_type* d_output; - if(in_place) - { - d_output = d_input; - } - else - { - HIP_CHECK( - test_common_utils::hipMallocHelper(&d_output, size * sizeof(key_type))); - } + d_output = d_input; const auto input_it = test_utils::wrap_in_indirect_iterator(d_input); @@ -240,29 +239,14 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) // Allocate temporary storage size_t temp_storage_size_bytes{}; - if(in_place) - { - HIP_CHECK(rocprim::partial_sort(nullptr, - temp_storage_size_bytes, - input_it, - middle, - size, - compare_op, - stream, - debug_synchronous)); - } - else - { - HIP_CHECK(rocprim::partial_sort_copy(nullptr, - temp_storage_size_bytes, - input_it, - d_output, - middle, - size, - compare_op, - stream, - debug_synchronous)); - } + HIP_CHECK(rocprim::partial_sort(nullptr, + temp_storage_size_bytes, + input_it, + middle, + size, + compare_op, + stream, + debug_synchronous)); ASSERT_GT(temp_storage_size_bytes, 0); void* d_temp_storage{}; @@ -274,30 +258,238 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) { graph = test_utils::createGraphHelper(stream); } - if(in_place) + HIP_CHECK(rocprim::partial_sort(d_temp_storage, + temp_storage_size_bytes, + input_it, + middle, + size, + compare_op, + stream, + debug_synchronous)); + + HIP_CHECK(hipGetLastError()); + + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + } + + std::vector output(size); + HIP_CHECK(hipMemcpy(output.data(), + d_output, + size * sizeof(key_type), + hipMemcpyDeviceToHost)); + + compare_partial_sort(input, output, middle, compare_op); + + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_temp_storage)); + + if(TestFixture::use_graphs) { - HIP_CHECK(rocprim::partial_sort(d_temp_storage, - temp_storage_size_bytes, - input_it, - middle, - size, - compare_op, - stream, - debug_synchronous)); + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); + } + } + } + } +} + +template +void inline compare_partial_sort_copy_cpp_14(InputVector input, + OutputVector output, + OutputVector orignal_output, + size_t middle, + CompareFunction compare_op) +{ + using key_type = typename InputVector::value_type; + + if(input.size() == 0) + { + return; + } + std::vector expected_output; + // Calculate sorted input results on host + std::vector sorted_input(input); + std::sort(sorted_input.begin(), sorted_input.end(), compare_op); + + expected_output.insert(expected_output.end(), + sorted_input.begin(), + sorted_input.begin() + std::min(middle + 1, sorted_input.size())); + + if(middle + 1 < orignal_output.size()) + { + expected_output.insert(expected_output.end(), + orignal_output.begin() + middle + 1, + orignal_output.end()); + } + + // Check if the values are the same + ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected_output)); +} + +#if CPP17 +template +void inline compare_partial_sort_copy_cpp_17(InputVector input, + OutputVector output, + OutputVector orignal_output, + size_t middle, + CompareFunction compare_op) +{ + using key_type = typename InputVector::value_type; + + if(input.size() == 0) + { + return; + } + + // Calculate sorted input results on host + std::vector sorted_output(orignal_output); + std::partial_sort_copy(input.begin(), + input.end(), + sorted_output.begin(), + sorted_output.begin() + middle + 1, + compare_op); + + // Check if the values are the same + ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(sorted_output, output)); +} +#endif + +template +void inline compare_partial_sort_copy(InputVector input, + OutputVector output, + OutputVector orignal_output, + size_t middle, + CompareFunction compare_op) +{ + compare_partial_sort_copy_cpp_14(input, output, orignal_output, middle, compare_op); +#if CPP17 + // this comparison is only compiled and executed if c++17 is available + compare_partial_sort_copy_cpp_17(input, output, orignal_output, middle, compare_op); +#else + ROCPRIM_PRAGMA_MESSAGE( + "c++17 not available skips direct comparison with std::partial_sort_copy"); +#endif +} + +TYPED_TEST(RocprimDevicePartialSortTests, PartialSortCopy) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + + using key_type = typename TestFixture::key_type; + using compare_function = typename TestFixture::compare_function; + using config = typename TestFixture::config; + const bool debug_synchronous = TestFixture::debug_synchronous; + static constexpr bool use_indirect_iterator = TestFixture::use_indirect_iterator; + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; ++seed_index) + { + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + + for(size_t size : test_utils::get_sizes(seed_value)) + { + SCOPED_TRACE(testing::Message() << "with size = " << size); + + std::vector middles = {0}; + + if(size > 1) + { + middles.push_back(size - 1); + middles.push_back(test_utils::get_random_value(1, size - 1, seed_value)); + } + + for(size_t middle : middles) + { + SCOPED_TRACE(testing::Message() << "with middle = " << middle); + + hipStream_t stream = 0; // default + if(TestFixture::use_graphs) + { + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + + std::vector input; + std::vector output_original; + if(rocprim::is_floating_point::value) + { + input = test_utils::get_random_data(size, -1000, 1000, seed_value); + output_original + = test_utils::get_random_data(size, -1000, 1000, seed_value + 1); } else { - HIP_CHECK(rocprim::partial_sort_copy(d_temp_storage, - temp_storage_size_bytes, - input_it, - d_output, - middle, - size, - compare_op, - stream, - debug_synchronous)); + input = test_utils::get_random_data( + size, + test_utils::numeric_limits::min(), + test_utils::numeric_limits::max(), + seed_value); + output_original = test_utils::get_random_data( + size, + test_utils::numeric_limits::min(), + test_utils::numeric_limits::max(), + seed_value + 1); + } + + key_type* d_input; + HIP_CHECK(test_common_utils::hipMallocHelper(&d_input, size * sizeof(key_type))); + HIP_CHECK(hipMemcpy(d_input, + input.data(), + size * sizeof(key_type), + hipMemcpyHostToDevice)); + + key_type* d_output; + + HIP_CHECK(test_common_utils::hipMallocHelper(&d_output, size * sizeof(key_type))); + HIP_CHECK(hipMemcpy(d_output, + output_original.data(), + size * sizeof(key_type), + hipMemcpyHostToDevice)); + + const auto input_it + = test_utils::wrap_in_indirect_iterator(d_input); + + compare_function compare_op; + + // Allocate temporary storage + size_t temp_storage_size_bytes{}; + + HIP_CHECK(rocprim::partial_sort_copy(nullptr, + temp_storage_size_bytes, + input_it, + d_output, + middle, + size, + compare_op, + stream, + debug_synchronous)); + + ASSERT_GT(temp_storage_size_bytes, 0); + void* d_temp_storage{}; + HIP_CHECK( + test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); + + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); } + HIP_CHECK(rocprim::partial_sort_copy(d_temp_storage, + temp_storage_size_bytes, + input_it, + d_output, + middle, + size, + compare_op, + stream, + debug_synchronous)); + HIP_CHECK(hipGetLastError()); hipGraphExec_t graph_instance; @@ -306,21 +498,16 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); } - // The algorithm sorted [first, middle). Since the order of [middle, last) is not specified, - // sort [middle, last) to compare with expected values. std::vector output(size); HIP_CHECK(hipMemcpy(output.data(), d_output, size * sizeof(key_type), hipMemcpyDeviceToHost)); - compare(input, output, middle, compare_op); + compare_partial_sort_copy(input, output, output_original, middle, compare_op); HIP_CHECK(hipFree(d_input)); - if(!in_place) - { - HIP_CHECK(hipFree(d_output)); - } + HIP_CHECK(hipFree(d_output)); HIP_CHECK(hipFree(d_temp_storage)); if(TestFixture::use_graphs) @@ -328,8 +515,6 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) test_utils::cleanupGraphHelper(graph, graph_instance); HIP_CHECK(hipStreamDestroy(stream)); } - - in_place = !in_place; } } }