From 8a5242be75964787d1c0e5924506819ef98db66b Mon Sep 17 00:00:00 2001 From: Nol Moonen Date: Fri, 26 Apr 2024 14:58:39 +0000 Subject: [PATCH 01/14] naive implementation --- .../rocprim/device/device_nth_element.hpp | 263 ++++++++++++++++++ test/rocprim/CMakeLists.txt | 1 + test/rocprim/test_device_partial_sort.cpp | 236 ++++++++++++++++ 3 files changed, 500 insertions(+) create mode 100644 test/rocprim/test_device_partial_sort.cpp diff --git a/rocprim/include/rocprim/device/device_nth_element.hpp b/rocprim/include/rocprim/device/device_nth_element.hpp index c1beaec65..680dc05db 100644 --- a/rocprim/include/rocprim/device/device_nth_element.hpp +++ b/rocprim/include/rocprim/device/device_nth_element.hpp @@ -28,6 +28,7 @@ #include "../config.hpp" #include "config_types.hpp" +#include "device_merge_sort.hpp" #include "device_nth_element_config.hpp" #include "device_transform.hpp" @@ -335,6 +336,268 @@ ROCPRIM_INLINE hipError_t nth_element(void* temporary_storage, debug_synchronous); } +namespace detail +{ + +template +hipError_t partial_sort_impl(void* temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + size_t middle, + size_t size, + BinaryFunction compare_function, + hipStream_t stream, + bool debug_synchronous) +{ + using key_type = typename std::iterator_traits::value_type; + static_assert( + std::is_same::value_type>::value, + "KeysInputIterator and KeysOutputIterator must have the same value_type"); + + using config = wrapped_nth_element_config; + + target_arch target_arch; + hipError_t result = host_target_arch(stream, target_arch); + if(result != hipSuccess) + { + return result; + } + const nth_element_config_params params = dispatch_target_arch(target_arch); + + constexpr unsigned int num_partitions = 3; + const unsigned int num_buckets = params.number_of_buckets; + const unsigned int num_splitters = num_buckets - 1; + const unsigned int stop_recursion_size = num_buckets; + const unsigned int num_items_per_threads = params.kernel_config.items_per_thread; + const unsigned int num_threads_per_block = params.kernel_config.block_size; + const unsigned int num_items_per_block = num_threads_per_block * num_items_per_threads; + const unsigned int num_blocks = ceiling_div(size, num_items_per_block); + + size_t storage_size_merge_sort{}; + // non-null placeholder so that no buffer is allocated for keys + key_type* keys_buffer_placeholder = reinterpret_cast(1); + + result = merge_sort_impl(nullptr, + storage_size_merge_sort, + keys_output, + keys_output, + static_cast(nullptr), // values_input + static_cast(nullptr), // values_output + middle, + compare_function, + stream, + debug_synchronous, + keys_buffer_placeholder, // keys_buffer + static_cast(nullptr)); // values_buffer + if(result != hipSuccess) + { + return result; + } + + key_type* tree = nullptr; + size_t* buckets = nullptr; + n_th_element_iteration_data* nth_element_data = nullptr; + uint8_t* oracles = nullptr; + bool* equality_buckets = nullptr; + nth_element_onesweep_lookback_state* lookback_states = nullptr; + key_type* keys_buffer = nullptr; + void* temporary_storage_merge_sort = nullptr; + + const hipError_t partition_result = temp_storage::partition( + temporary_storage, + storage_size, + temp_storage::make_linear_partition( + temp_storage::ptr_aligned_array(&tree, num_splitters), + temp_storage::ptr_aligned_array(&equality_buckets, num_buckets), + temp_storage::ptr_aligned_array(&buckets, num_buckets), + temp_storage::ptr_aligned_array(&oracles, size), + temp_storage::ptr_aligned_array(&keys_buffer, size), + temp_storage::ptr_aligned_array(&nth_element_data, 1), + temp_storage::ptr_aligned_array(&lookback_states, num_partitions * num_blocks), + temp_storage::make_partition(&temporary_storage_merge_sort, storage_size_merge_sort))); + + if(partition_result != hipSuccess || temporary_storage == nullptr) + { + return partition_result; + } + + if(size == 0) + { + return hipSuccess; + } + + if(middle > size) + { + return hipErrorInvalidValue; + } + + if(debug_synchronous) + { + std::cout << "-----" << '\n'; + std::cout << "size: " << size << '\n'; + std::cout << "num_buckets: " << num_buckets << '\n'; + std::cout << "num_threads_per_block: " << num_threads_per_block << '\n'; + std::cout << "num_blocks: " << num_blocks << '\n'; + std::cout << "storage_size: " << storage_size << '\n'; + } + + if(keys_input != keys_output) + { + hipError_t error = transform(keys_input, + keys_output, + size, + ::rocprim::identity(), + stream, + debug_synchronous); + if(result != hipSuccess) + { + return result; + } + } + + result = nth_element_keys_impl(keys_output, + keys_buffer, + tree, + middle, + size, + buckets, + equality_buckets, + oracles, + lookback_states, + num_buckets, + stop_recursion_size, + num_threads_per_block, + num_items_per_threads, + nth_element_data, + compare_function, + stream, + debug_synchronous); + if(result != hipSuccess) + { + return result; + } + + return merge_sort_impl(temporary_storage_merge_sort, + storage_size_merge_sort, + keys_output, + keys_output, + static_cast(nullptr), // values_input + static_cast(nullptr), // values_output + middle, + compare_function, + stream, + debug_synchronous, + keys_buffer, // keys_buffer + static_cast(nullptr)); // values_buffer +} + +} // namespace detail + +/// \brief Parallel nth_element for device level. +/// +/// `nth_element` function performs a device-wide nth_element, +/// this function sets nth element as if the list was sorted. +/// Also for all values `i` in `[first, nth)` and all values `j` in `[nth, last)` +/// the condition `comp(*j, *i)` is `false` where `comp` is the compare function. +/// +/// \par Overview +/// * The contents of the inputs are not altered by the function. +/// * Returns the required size of `temporary_storage` in `storage_size` +/// if `temporary_storage` in a null pointer. +/// * Accepts custom compare_functions for nth_element across the device. +/// * Does not work with hipGraph +/// +/// \tparam Config [optional] configuration of the primitive. It has to be `radix_sort_config` +/// or a class derived from it. +/// \tparam KeysIterator [inferred] random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam CompareFunction [inferred] Type of binary function that accepts two arguments of the +/// type `KeysIterator` and returns a value convertible to bool. Default type is `::rocprim::less<>.` +/// +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// `storage_size` and function returns without performing the sort operation. +/// \param [in,out] storage_size reference to a size (in bytes) of `temporary_storage`. +/// \param [in] keys_input iterator to the input range. +/// \param [out] keys_output iterator to the output range. Allowed to point to the same elements as `keys_input`. +/// Only complete overlap or no overlap at all is allowed between `keys_input` and `keys_output`. In other words +/// writing to `keys_output[i]` is only allowed to overwrite `keys_input[i]`, any other element must not be changed. +/// \param [in] nth The index of the nth_element in the input range. +/// \param [in] size number of element in the input range. +/// \param [in] compare_function binary operation function object that will be used for comparison. +/// The signature of the function should be equivalent to the following: +/// bool f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// The comperator must meet the C++ named requirement Compare. +/// The default value is `BinaryFunction()`. +/// \param [in] stream [optional] HIP stream object. Default is `0` (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is `false`. +/// +/// \returns `hipSuccess` (`0`) after successful sort; otherwise a HIP runtime error of +/// type `hipError_t`. +/// +/// \par Example +/// \parblock +/// In this example a device-level nth_element is performed where input keys are +/// represented by an array of unsigned integers. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// size_t nth; // e.g., 4 +/// unsigned int * keys_input; // e.g., [ 6, 3, 5, 4, 1, 8, 2, 7 ] +/// unsigned int * keys_output; // empty array of 8 elements +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::nth_element( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, nth, input_size +/// ); +/// +/// // allocate temporary storage +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform nth_element +/// rocprim::nth_element( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, nth, input_size +/// ); +/// // possible keys_output: [ 1, 3, 4, 2, 5, 8, 7, 6 ] +/// \endcode +/// \endparblock +template::value_type>> +hipError_t partial_sort(void* temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + size_t middle, + size_t size, + BinaryFunction compare_function = BinaryFunction(), + hipStream_t stream = 0, + bool debug_synchronous = false) +{ + return detail::partial_sort_impl(temporary_storage, + storage_size, + keys_input, + keys_output, + middle, + size, + compare_function, + stream, + debug_synchronous); +} + /// @} // end of group devicemodule diff --git a/test/rocprim/CMakeLists.txt b/test/rocprim/CMakeLists.txt index 1d61cafa3..3f6e9f02f 100644 --- a/test/rocprim/CMakeLists.txt +++ b/test/rocprim/CMakeLists.txt @@ -260,6 +260,7 @@ add_rocprim_test("rocprim.device_histogram" test_device_histogram.cpp) add_rocprim_test("rocprim.device_merge" test_device_merge.cpp) add_rocprim_test("rocprim.device_merge_sort" test_device_merge_sort.cpp) add_rocprim_cpp17_test("rocprim.nth_element" test_device_nth_element.cpp) +add_rocprim_test("rocprim.device_partial_sort" test_device_partial_sort.cpp) add_rocprim_test("rocprim.device_partition" test_device_partition.cpp) add_rocprim_test_parallel("rocprim.device_radix_sort" test_device_radix_sort.cpp.in) add_rocprim_test("rocprim.device_reduce_by_key" test_device_reduce_by_key.cpp) diff --git a/test/rocprim/test_device_partial_sort.cpp b/test/rocprim/test_device_partial_sort.cpp new file mode 100644 index 000000000..f643e4f96 --- /dev/null +++ b/test/rocprim/test_device_partial_sort.cpp @@ -0,0 +1,236 @@ +// MIT License +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// required test headers +#include "test_utils_assertions.hpp" +#include "test_utils_custom_float_type.hpp" +#include "test_utils_custom_test_types.hpp" +#include "test_utils_data_generation.hpp" +#include "test_utils_types.hpp" + +#include "../common_test_header.hpp" + +// required rocprim headers +#include +#include +#include +#include + +#include +#include +#include + +#include +#include + +// Params for tests +template, + class Config = ::rocprim::default_config, + bool UseGraphs = false> +struct DevicePartialSortParams +{ + using key_type = KeyType; + using compare_function = CompareFunction; + using config = Config; + static constexpr bool use_graphs = UseGraphs; +}; + +template +class RocprimDevicePartialSortTests : public ::testing::Test +{ +public: + using key_type = typename Params::key_type; + using compare_function = typename Params::compare_function; + using config = typename Params::config; + const bool debug_synchronous = false; + bool use_graphs = Params::use_graphs; +}; + +// TODO add custom config +// TODO no graph support +using RocprimDevicePartialSortTestsParams + = ::testing::Types, + DevicePartialSortParams, + DevicePartialSortParams, + DevicePartialSortParams>, + DevicePartialSortParams, + DevicePartialSortParams>, + DevicePartialSortParams, + DevicePartialSortParams, + DevicePartialSortParams, + DevicePartialSortParams, + DevicePartialSortParams, + DevicePartialSortParams, + DevicePartialSortParams>, + DevicePartialSortParams, + DevicePartialSortParams>>; + +TYPED_TEST_SUITE(RocprimDevicePartialSortTests, RocprimDevicePartialSortTestsParams); + +TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + + using key_type = typename TestFixture::key_type; + using compare_function = typename TestFixture::compare_function; + using config = typename TestFixture::config; + const bool debug_synchronous = TestFixture::debug_synchronous; + + bool in_place = false; + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; ++seed_index) + { + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + + for(size_t size : test_utils::get_sizes(seed_value)) + { + SCOPED_TRACE(testing::Message() << "with size = " << size); + + std::vector middles = {0}; + if(size > 0) + { + middles.push_back(size); + } + if(size > 1) + { + middles.push_back(test_utils::get_random_value(1, size - 1, seed_value)); + } + + for(size_t middle : middles) + { + SCOPED_TRACE(testing::Message() << "with middle = " << middle); + + hipStream_t stream = 0; // default + if(TestFixture::use_graphs) + { + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + + std::vector input; + if(rocprim::is_floating_point::value) + { + input = test_utils::get_random_data(size, -1000, 1000, seed_value); + } + else + { + input = test_utils::get_random_data( + size, + test_utils::numeric_limits::min(), + test_utils::numeric_limits::max(), + seed_value); + } + key_type* d_input; + HIP_CHECK(test_common_utils::hipMallocHelper(&d_input, size * sizeof(key_type))); + HIP_CHECK(hipMemcpy(d_input, + input.data(), + size * sizeof(key_type), + hipMemcpyHostToDevice)); + + key_type* d_output; + if(in_place) + { + d_output = d_input; + } + else + { + HIP_CHECK( + test_common_utils::hipMallocHelper(&d_output, size * sizeof(key_type))); + } + + compare_function compare_op; + + // Allocate temporary storage + size_t temp_storage_size_bytes{}; + HIP_CHECK(rocprim::partial_sort(nullptr, + temp_storage_size_bytes, + d_input, + d_output, + middle, + size, + compare_op, + stream, + debug_synchronous)); + ASSERT_GT(temp_storage_size_bytes, 0); + void* d_temp_storage{}; + HIP_CHECK( + test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); + + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + + HIP_CHECK(rocprim::partial_sort(d_temp_storage, + temp_storage_size_bytes, + d_input, + d_output, + middle, + size, + compare_op, + stream, + debug_synchronous)); + HIP_CHECK(hipGetLastError()); + + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + } + + // The algorithm sorted [first, middle). Since the order of [middle, last) is not specified, + // sort [middle, last) to compare with expected values. + std::vector output(size); + HIP_CHECK(hipMemcpy(output.data(), + d_output, + size * sizeof(key_type), + hipMemcpyDeviceToHost)); + std::sort(output.begin() + middle, output.begin() + size, compare_op); + + // Sort input fully to compare + std::sort(input.begin(), input.end(), compare_op); + + ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, input)); + + HIP_CHECK(hipFree(d_input)); + if(!in_place) + { + hipFree(d_output); + } + HIP_CHECK(hipFree(d_temp_storage)); + + if(TestFixture::use_graphs) + { + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); + } + + in_place = !in_place; + } + } + } +} From f0264705876dd9e521d3d6f4062478ec1b13841b Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Fri, 17 May 2024 07:57:19 +0000 Subject: [PATCH 02/14] partial sort benchmark --- benchmark/CMakeLists.txt | 1 + benchmark/benchmark_device_partial_sort.cpp | 123 +++++++++++++ benchmark/benchmark_device_partial_sort.hpp | 180 ++++++++++++++++++++ 3 files changed, 304 insertions(+) create mode 100644 benchmark/benchmark_device_partial_sort.cpp create mode 100644 benchmark/benchmark_device_partial_sort.hpp diff --git a/benchmark/CMakeLists.txt b/benchmark/CMakeLists.txt index c743eca3b..638b7a16b 100644 --- a/benchmark/CMakeLists.txt +++ b/benchmark/CMakeLists.txt @@ -142,6 +142,7 @@ add_rocprim_benchmark(benchmark_device_merge_sort.cpp) add_rocprim_benchmark(benchmark_device_merge_sort_block_sort.cpp) add_rocprim_benchmark(benchmark_device_merge_sort_block_merge.cpp) add_rocprim_benchmark(benchmark_device_nth_element.cpp) +add_rocprim_benchmark(benchmark_device_partial_sort.cpp) add_rocprim_benchmark(benchmark_device_partition.cpp) add_rocprim_benchmark(benchmark_device_radix_sort.cpp) add_rocprim_benchmark(benchmark_device_radix_sort_block_sort.cpp) diff --git a/benchmark/benchmark_device_partial_sort.cpp b/benchmark/benchmark_device_partial_sort.cpp new file mode 100644 index 000000000..a921e537f --- /dev/null +++ b/benchmark/benchmark_device_partial_sort.cpp @@ -0,0 +1,123 @@ +// MIT License +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include "benchmark_device_partial_sort.hpp" +#include "benchmark_utils.hpp" + +// CmdParser +#include "cmdparser.hpp" + +// Google Benchmark +#include + +// HIP API +#include + +#include +#include + +#ifndef DEFAULT_N +const size_t DEFAULT_N = 1024 * 1024 * 32; +#endif + +#define CREATE_BENCHMARK_PARTIAL_SORT(TYPE, SMALL_N) \ + { \ + const device_partial_sort_benchmark instance(SMALL_N); \ + REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + } + +#define CREATE_BENCHMARK(TYPE) \ + { \ + CREATE_BENCHMARK_PARTIAL_SORT(TYPE, true) \ + CREATE_BENCHMARK_PARTIAL_SORT(TYPE, false) \ + } + +int main(int argc, char* argv[]) +{ + cli::Parser parser(argc, argv); + parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("trials", "trials", -1, "number of iterations"); + parser.set_optional("name_format", + "name_format", + "human", + "either: json,human,txt"); + parser.set_optional("seed", "seed", "random", get_seed_message()); + parser.run_and_exit_if_error(); + + // Parse argv + benchmark::Initialize(&argc, argv); + const size_t size = parser.get("size"); + const int trials = parser.get("trials"); + bench_naming::set_format(parser.get("name_format")); + const std::string seed_type = parser.get("seed"); + const managed_seed seed(seed_type); + + // HIP + hipStream_t stream = 0; // default + + // Benchmark info + add_common_benchmark_info(); + benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("seed", seed_type); + + // Add benchmarks + std::vector benchmarks = {}; + CREATE_BENCHMARK(int) + CREATE_BENCHMARK(long long) + CREATE_BENCHMARK(int8_t) + CREATE_BENCHMARK(uint8_t) + CREATE_BENCHMARK(rocprim::half) + CREATE_BENCHMARK(short) + CREATE_BENCHMARK(float) + + using custom_float2 = custom_type; + using custom_double2 = custom_type; + using custom_int2 = custom_type; + using custom_char_double = custom_type; + using custom_longlong_double = custom_type; + + CREATE_BENCHMARK(custom_float2) + CREATE_BENCHMARK(custom_double2) + CREATE_BENCHMARK(custom_int2) + CREATE_BENCHMARK(custom_char_double) + CREATE_BENCHMARK(custom_longlong_double) + + // Use manual timing + for(auto& b : benchmarks) + { + b->UseManualTime(); + b->Unit(benchmark::kMillisecond); + } + + // Force number of iterations + if(trials > 0) + { + for(auto& b : benchmarks) + { + b->Iterations(trials); + } + } + + // Run benchmarks + benchmark::RunSpecifiedBenchmarks(); + return 0; +} diff --git a/benchmark/benchmark_device_partial_sort.hpp b/benchmark/benchmark_device_partial_sort.hpp new file mode 100644 index 000000000..072d1342c --- /dev/null +++ b/benchmark/benchmark_device_partial_sort.hpp @@ -0,0 +1,180 @@ +// MIT License +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#ifndef ROCPRIM_BENCHMARK_DEVICE_PARTIAL_SORT_PARALLEL_HPP_ +#define ROCPRIM_BENCHMARK_DEVICE_PARTIAL_SORT_PARALLEL_HPP_ + +#include "benchmark_utils.hpp" + +// Google Benchmark +#include + +// HIP API +#include + +// rocPRIM +#include + +#include +#include + +#include + +template +struct device_partial_sort_benchmark : public config_autotune_interface +{ + bool small_n = false; + + device_partial_sort_benchmark(bool SmallN) + { + small_n = SmallN; + } + + std::string name() const override + { + using namespace std::string_literals; + return bench_naming::format_name( + "{lvl:device,algo:partial_sort,nth:" + (small_n ? "small"s : "large"s) + + ",key_type:" + std::string(Traits::name()) + ",cfg:default_config}"); + } + + static constexpr unsigned int batch_size = 10; + static constexpr unsigned int warmup_size = 5; + + void run(benchmark::State& state, + size_t size, + const managed_seed& seed, + hipStream_t stream) const override + { + using key_type = Key; + + size_t nth = 10; + + if(!small_n) + { + nth = size / 2; + } + + // Generate data + std::vector keys_input; + if(std::is_floating_point::value) + { + keys_input = get_random_data(size, + static_cast(-1000), + static_cast(1000), + seed.get_0()); + } + else + { + keys_input = get_random_data(size, + std::numeric_limits::min(), + std::numeric_limits::max(), + seed.get_0()); + } + + key_type* d_keys_input; + key_type* d_keys_output; + HIP_CHECK(hipMalloc(&d_keys_input, size * sizeof(*d_keys_input))); + HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(*d_keys_output))); + + HIP_CHECK(hipMemcpy(d_keys_input, + keys_input.data(), + size * sizeof(*d_keys_input), + hipMemcpyHostToDevice)); + + ::rocprim::less lesser_op; + + void* d_temporary_storage = nullptr; + size_t temporary_storage_bytes = 0; + HIP_CHECK(rocprim::partial_sort(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + nth, + size, + lesser_op, + stream, + false)); + + HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + + // Warm-up + for(size_t i = 0; i < warmup_size; i++) + { + HIP_CHECK(rocprim::partial_sort(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + nth, + size, + lesser_op, + stream, + false)); + } + HIP_CHECK(hipDeviceSynchronize()); + + // HIP events creation + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + for(auto _ : state) + { + // Record start event + HIP_CHECK(hipEventRecord(start, stream)); + + for(size_t i = 0; i < batch_size; i++) + { + HIP_CHECK(rocprim::partial_sort(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + nth, + size, + lesser_op, + stream, + false)); + } + + // Record stop event and wait until it completes + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float elapsed_mseconds; + HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); + state.SetIterationTime(elapsed_mseconds / 1000); + } + + // Destroy HIP events + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(*d_keys_input)); + state.SetItemsProcessed(state.iterations() * batch_size * size); + + HIP_CHECK(hipFree(d_temporary_storage)); + HIP_CHECK(hipFree(d_keys_input)); + HIP_CHECK(hipFree(d_keys_output)); + } +}; + +#endif // ROCPRIM_BENCHMARK_DEVICE_PARTIAL_SORT_PARALLEL_HPP_ From 3774e9b776982d1313ce9d2c1488018c435074da Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Fri, 17 May 2024 08:33:29 +0000 Subject: [PATCH 03/14] Made partial_sort in place and created partial_sort_copy --- benchmark/benchmark_device_partial_sort.hpp | 58 ++--- .../rocprim/device/device_nth_element.hpp | 220 +++++++++++++----- test/rocprim/test_device_partial_sort.cpp | 65 ++++-- 3 files changed, 234 insertions(+), 109 deletions(-) diff --git a/benchmark/benchmark_device_partial_sort.hpp b/benchmark/benchmark_device_partial_sort.hpp index 072d1342c..560d2633d 100644 --- a/benchmark/benchmark_device_partial_sort.hpp +++ b/benchmark/benchmark_device_partial_sort.hpp @@ -67,11 +67,11 @@ struct device_partial_sort_benchmark : public config_autotune_interface { using key_type = Key; - size_t nth = 10; + size_t middle = 10; if(!small_n) { - nth = size / 2; + middle = size / 2; } // Generate data @@ -105,30 +105,30 @@ struct device_partial_sort_benchmark : public config_autotune_interface void* d_temporary_storage = nullptr; size_t temporary_storage_bytes = 0; - HIP_CHECK(rocprim::partial_sort(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - nth, - size, - lesser_op, - stream, - false)); + HIP_CHECK(rocprim::partial_sort_copy(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + middle, + size, + lesser_op, + stream, + false)); HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); // Warm-up for(size_t i = 0; i < warmup_size; i++) { - HIP_CHECK(rocprim::partial_sort(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - nth, - size, - lesser_op, - stream, - false)); + HIP_CHECK(rocprim::partial_sort_copy(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + middle, + size, + lesser_op, + stream, + false)); } HIP_CHECK(hipDeviceSynchronize()); @@ -144,15 +144,15 @@ struct device_partial_sort_benchmark : public config_autotune_interface for(size_t i = 0; i < batch_size; i++) { - HIP_CHECK(rocprim::partial_sort(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - nth, - size, - lesser_op, - stream, - false)); + HIP_CHECK(rocprim::partial_sort_copy(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + middle, + size, + lesser_op, + stream, + false)); } // Record stop event and wait until it completes diff --git a/rocprim/include/rocprim/device/device_nth_element.hpp b/rocprim/include/rocprim/device/device_nth_element.hpp index 680dc05db..927b2c885 100644 --- a/rocprim/include/rocprim/device/device_nth_element.hpp +++ b/rocprim/include/rocprim/device/device_nth_element.hpp @@ -339,24 +339,18 @@ ROCPRIM_INLINE hipError_t nth_element(void* temporary_storage, namespace detail { -template -hipError_t partial_sort_impl(void* temporary_storage, - size_t& storage_size, - KeysInputIterator keys_input, - KeysOutputIterator keys_output, - size_t middle, - size_t size, - BinaryFunction compare_function, - hipStream_t stream, - bool debug_synchronous) +template +hipError_t partial_sort_impl(void* temporary_storage, + size_t& storage_size, + KeysIterator keys, + size_t middle, + size_t size, + BinaryFunction compare_function, + hipStream_t stream, + bool debug_synchronous) { - using key_type = typename std::iterator_traits::value_type; - static_assert( - std::is_same::value_type>::value, - "KeysInputIterator and KeysOutputIterator must have the same value_type"); - - using config = wrapped_nth_element_config; + using key_type = typename std::iterator_traits::value_type; + using config = wrapped_nth_element_config; target_arch target_arch; hipError_t result = host_target_arch(stream, target_arch); @@ -381,8 +375,8 @@ hipError_t partial_sort_impl(void* temporary_storage, result = merge_sort_impl(nullptr, storage_size_merge_sort, - keys_output, - keys_output, + keys, + keys, static_cast(nullptr), // values_input static_cast(nullptr), // values_output middle, @@ -443,37 +437,23 @@ hipError_t partial_sort_impl(void* temporary_storage, std::cout << "storage_size: " << storage_size << '\n'; } - if(keys_input != keys_output) - { - hipError_t error = transform(keys_input, - keys_output, - size, - ::rocprim::identity(), - stream, - debug_synchronous); - if(result != hipSuccess) - { - return result; - } - } - - result = nth_element_keys_impl(keys_output, - keys_buffer, - tree, - middle, - size, - buckets, - equality_buckets, - oracles, - lookback_states, - num_buckets, - stop_recursion_size, - num_threads_per_block, - num_items_per_threads, - nth_element_data, - compare_function, - stream, - debug_synchronous); + result = nth_element_keys_impl(keys, + keys_buffer, + tree, + middle, + size, + buckets, + equality_buckets, + oracles, + lookback_states, + num_buckets, + stop_recursion_size, + num_threads_per_block, + num_items_per_threads, + nth_element_data, + compare_function, + stream, + debug_synchronous); if(result != hipSuccess) { return result; @@ -481,8 +461,8 @@ hipError_t partial_sort_impl(void* temporary_storage, return merge_sort_impl(temporary_storage_merge_sort, storage_size_merge_sort, - keys_output, - keys_output, + keys, + keys, static_cast(nullptr), // values_input static_cast(nullptr), // values_output middle, @@ -577,19 +557,35 @@ template::value_type>> -hipError_t partial_sort(void* temporary_storage, - size_t& storage_size, - KeysInputIterator keys_input, - KeysOutputIterator keys_output, - size_t middle, - size_t size, - BinaryFunction compare_function = BinaryFunction(), - hipStream_t stream = 0, - bool debug_synchronous = false) +hipError_t partial_sort_copy(void* temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + size_t middle, + size_t size, + BinaryFunction compare_function = BinaryFunction(), + hipStream_t stream = 0, + bool debug_synchronous = false) { + using key_type = typename std::iterator_traits::value_type; + static_assert( + std::is_same::value_type>::value, + "KeysInputIterator and KeysOutputIterator must have the same value_type"); + + hipError_t error = transform(keys_input, + keys_output, + size, + ::rocprim::identity(), + stream, + debug_synchronous); + if(error != hipSuccess) + { + return error; + } + return detail::partial_sort_impl(temporary_storage, storage_size, - keys_input, keys_output, middle, size, @@ -598,6 +594,106 @@ hipError_t partial_sort(void* temporary_storage, debug_synchronous); } +/// \brief Parallel nth_element for device level. +/// +/// `nth_element` function performs a device-wide nth_element, +/// this function sets nth element as if the list was sorted. +/// Also for all values `i` in `[first, nth)` and all values `j` in `[nth, last)` +/// the condition `comp(*j, *i)` is `false` where `comp` is the compare function. +/// +/// \par Overview +/// * The contents of the inputs are not altered by the function. +/// * Returns the required size of `temporary_storage` in `storage_size` +/// if `temporary_storage` in a null pointer. +/// * Accepts custom compare_functions for nth_element across the device. +/// * Does not work with hipGraph +/// +/// \tparam Config [optional] configuration of the primitive. It has to be `radix_sort_config` +/// or a class derived from it. +/// \tparam KeysIterator [inferred] random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam CompareFunction [inferred] Type of binary function that accepts two arguments of the +/// type `KeysIterator` and returns a value convertible to bool. Default type is `::rocprim::less<>.` +/// +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// `storage_size` and function returns without performing the sort operation. +/// \param [in,out] storage_size reference to a size (in bytes) of `temporary_storage`. +/// \param [in] keys_input iterator to the input range. +/// \param [out] keys_output iterator to the output range. Allowed to point to the same elements as `keys_input`. +/// Only complete overlap or no overlap at all is allowed between `keys_input` and `keys_output`. In other words +/// writing to `keys_output[i]` is only allowed to overwrite `keys_input[i]`, any other element must not be changed. +/// \param [in] nth The index of the nth_element in the input range. +/// \param [in] size number of element in the input range. +/// \param [in] compare_function binary operation function object that will be used for comparison. +/// The signature of the function should be equivalent to the following: +/// bool f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// The comperator must meet the C++ named requirement Compare. +/// The default value is `BinaryFunction()`. +/// \param [in] stream [optional] HIP stream object. Default is `0` (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is `false`. +/// +/// \returns `hipSuccess` (`0`) after successful sort; otherwise a HIP runtime error of +/// type `hipError_t`. +/// +/// \par Example +/// \parblock +/// In this example a device-level nth_element is performed where input keys are +/// represented by an array of unsigned integers. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// size_t nth; // e.g., 4 +/// unsigned int * keys_input; // e.g., [ 6, 3, 5, 4, 1, 8, 2, 7 ] +/// unsigned int * keys_output; // empty array of 8 elements +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::nth_element( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, nth, input_size +/// ); +/// +/// // allocate temporary storage +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform nth_element +/// rocprim::nth_element( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, nth, input_size +/// ); +/// // possible keys_output: [ 1, 3, 4, 2, 5, 8, 7, 6 ] +/// \endcode +/// \endparblock +template::value_type>> +hipError_t partial_sort(void* temporary_storage, + size_t& storage_size, + KeysIterator keys, + size_t middle, + size_t size, + BinaryFunction compare_function = BinaryFunction(), + hipStream_t stream = 0, + bool debug_synchronous = false) +{ + return detail::partial_sort_impl(temporary_storage, + storage_size, + keys, + middle, + size, + compare_function, + stream, + debug_synchronous); +} + /// @} // end of group devicemodule diff --git a/test/rocprim/test_device_partial_sort.cpp b/test/rocprim/test_device_partial_sort.cpp index f643e4f96..50fa45159 100644 --- a/test/rocprim/test_device_partial_sort.cpp +++ b/test/rocprim/test_device_partial_sort.cpp @@ -165,15 +165,30 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) // Allocate temporary storage size_t temp_storage_size_bytes{}; - HIP_CHECK(rocprim::partial_sort(nullptr, - temp_storage_size_bytes, - d_input, - d_output, - middle, - size, - compare_op, - stream, - debug_synchronous)); + if(in_place) + { + HIP_CHECK(rocprim::partial_sort(nullptr, + temp_storage_size_bytes, + d_input, + middle, + size, + compare_op, + stream, + debug_synchronous)); + } + else + { + HIP_CHECK(rocprim::partial_sort_copy(nullptr, + temp_storage_size_bytes, + d_input, + d_output, + middle, + size, + compare_op, + stream, + debug_synchronous)); + } + ASSERT_GT(temp_storage_size_bytes, 0); void* d_temp_storage{}; HIP_CHECK( @@ -184,16 +199,30 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) { graph = test_utils::createGraphHelper(stream); } + if(in_place) + { + HIP_CHECK(rocprim::partial_sort(d_temp_storage, + temp_storage_size_bytes, + d_input, + middle, + size, + compare_op, + stream, + debug_synchronous)); + } + else + { + HIP_CHECK(rocprim::partial_sort_copy(d_temp_storage, + temp_storage_size_bytes, + d_input, + d_output, + middle, + size, + compare_op, + stream, + debug_synchronous)); + } - HIP_CHECK(rocprim::partial_sort(d_temp_storage, - temp_storage_size_bytes, - d_input, - d_output, - middle, - size, - compare_op, - stream, - debug_synchronous)); HIP_CHECK(hipGetLastError()); hipGraphExec_t graph_instance; From 4f7986ffcd4044e8cc25d1f007c9b7a8947d298d Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Fri, 17 May 2024 09:15:53 +0000 Subject: [PATCH 04/14] Add and fix documentation partial_sort --- docs/device_ops/index.rst | 1 + docs/device_ops/partial_sort.rst | 20 +++++ docs/reference/ops_summary.rst | 1 + docs/sphinx/_toc.yml.in | 1 + .../rocprim/device/device_nth_element.hpp | 85 ++++++++----------- 5 files changed, 58 insertions(+), 50 deletions(-) create mode 100644 docs/device_ops/partial_sort.rst diff --git a/docs/device_ops/index.rst b/docs/device_ops/index.rst index e80ed2d31..3c27a1c15 100644 --- a/docs/device_ops/index.rst +++ b/docs/device_ops/index.rst @@ -24,3 +24,4 @@ * :ref:`dev-device_copy` * :ref:`dev-memcpy` * :ref:`dev-nth_element` + * :ref:`dev-partial_sort` diff --git a/docs/device_ops/partial_sort.rst b/docs/device_ops/partial_sort.rst new file mode 100644 index 000000000..6b5632a2e --- /dev/null +++ b/docs/device_ops/partial_sort.rst @@ -0,0 +1,20 @@ +.. meta:: + :description: rocPRIM documentation and API reference library + :keywords: rocPRIM, ROCm, API, documentation + +.. _dev-partial_sort: + + +Partial Sort +------------ + +Configuring the kernel +~~~~~~~~~~~~~~~~~~~~~~ + +.. doxygenstruct:: rocprim::nth_element_config + +partial_sort +~~~~~~~~~~~~ + +.. doxygenfunction:: rocprim::partial_sort(void* temporary_storage, size_t& storage_size, KeysIterator keys, size_t middle, size_t size, BinaryFunction compare_function = BinaryFunction(), hipStream_t stream = 0, bool debug_synchronous = false) +.. doxygenfunction:: rocprim::partial_sort_copy(void* temporary_storage, size_t& storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, size_t middle, size_t size, BinaryFunction compare_function = BinaryFunction(), hipStream_t stream = 0, bool debug_synchronous = false) diff --git a/docs/reference/ops_summary.rst b/docs/reference/ops_summary.rst index 8dc8b678e..c308eb6bd 100644 --- a/docs/reference/ops_summary.rst +++ b/docs/reference/ops_summary.rst @@ -32,6 +32,7 @@ Rearrangement ================ * ``sort`` rearranges the sequence by sorting it. It could be according to a comparison operator or a value using a radix approach +* ``partial_sort`` rearranges the sequence by sorting it up to and including the middle index, according to a comparison operator. * ``nth_element`` places the nth element in its sorted position, with elements less-than before, and greater after, according to a comparison operator. * ``exchange`` rearranges the elements according to a different stride configuration which is equivalent to a tensor axis transposition * ``shuffle`` rotates the elements diff --git a/docs/sphinx/_toc.yml.in b/docs/sphinx/_toc.yml.in index d99e195ca..29e2bf154 100644 --- a/docs/sphinx/_toc.yml.in +++ b/docs/sphinx/_toc.yml.in @@ -23,6 +23,7 @@ subtrees: - file: device_ops/transform.rst - file: device_ops/unique.rst - file: device_ops/sort.rst + - file: device_ops/partial_sort.rst - file: device_ops/nth_element.rst - file: device_ops/merge.rst - file: device_ops/partition.rst diff --git a/rocprim/include/rocprim/device/device_nth_element.hpp b/rocprim/include/rocprim/device/device_nth_element.hpp index 927b2c885..85f93b98f 100644 --- a/rocprim/include/rocprim/device/device_nth_element.hpp +++ b/rocprim/include/rocprim/device/device_nth_element.hpp @@ -475,36 +475,31 @@ hipError_t partial_sort_impl(void* temporary_storage, } // namespace detail -/// \brief Parallel nth_element for device level. -/// -/// `nth_element` function performs a device-wide nth_element, -/// this function sets nth element as if the list was sorted. -/// Also for all values `i` in `[first, nth)` and all values `j` in `[nth, last)` -/// the condition `comp(*j, *i)` is `false` where `comp` is the compare function. +/// \brief Rearranges elements such that the range [0, middle) contains the sorted middle smallest elements in the range [0, size). /// /// \par Overview /// * The contents of the inputs are not altered by the function. /// * Returns the required size of `temporary_storage` in `storage_size` -/// if `temporary_storage` in a null pointer. +/// if `temporary_storage` is a null pointer. /// * Accepts custom compare_functions for nth_element across the device. -/// * Does not work with hipGraph +/// * Streams in graph capture mode are not supported /// -/// \tparam Config [optional] configuration of the primitive. It has to be `radix_sort_config` -/// or a class derived from it. -/// \tparam KeysIterator [inferred] random-access iterator type of the input range. Must meet the +/// \tparam Config [optional] configuration of the primitive. It has to be `nth_element_config`. +/// \tparam KeysInputIterator [inferred] random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam KeysOutputIterator [inferred] random-access iterator type of the output range. Must meet the /// requirements of a C++ InputIterator concept. It can be a simple pointer type. /// \tparam CompareFunction [inferred] Type of binary function that accepts two arguments of the /// type `KeysIterator` and returns a value convertible to bool. Default type is `::rocprim::less<>.` /// /// \param [in] temporary_storage pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to -/// `storage_size` and function returns without performing the sort operation. +/// `storage_size` and function returns without performing the nth_element rearrangement. /// \param [in,out] storage_size reference to a size (in bytes) of `temporary_storage`. /// \param [in] keys_input iterator to the input range. -/// \param [out] keys_output iterator to the output range. Allowed to point to the same elements as `keys_input`. -/// Only complete overlap or no overlap at all is allowed between `keys_input` and `keys_output`. In other words -/// writing to `keys_output[i]` is only allowed to overwrite `keys_input[i]`, any other element must not be changed. -/// \param [in] nth The index of the nth_element in the input range. +/// \param [out] keys_output iterator to the output range. No overlap at all is allowed between `keys_input` and `keys_output`. +/// `keys_output` should be able to be written and read from for `size` elements. +/// \param [in] middle The index of the point till where it is sorted in the input range. /// \param [in] size number of element in the input range. /// \param [in] compare_function binary operation function object that will be used for comparison. /// The signature of the function should be equivalent to the following: @@ -516,7 +511,7 @@ hipError_t partial_sort_impl(void* temporary_storage, /// \param [in] debug_synchronous [optional] If true, synchronization after every kernel /// launch is forced in order to check for errors. Default value is `false`. /// -/// \returns `hipSuccess` (`0`) after successful sort; otherwise a HIP runtime error of +/// \returns `hipSuccess` (`0`) after successful rearrangement; otherwise a HIP runtime error of /// type `hipError_t`. /// /// \par Example @@ -529,27 +524,27 @@ hipError_t partial_sort_impl(void* temporary_storage, /// /// // Prepare input and output (declare pointers, allocate device memory etc.) /// size_t input_size; // e.g., 8 -/// size_t nth; // e.g., 4 +/// size_t middle; // e.g., 4 /// unsigned int * keys_input; // e.g., [ 6, 3, 5, 4, 1, 8, 2, 7 ] /// unsigned int * keys_output; // empty array of 8 elements /// /// size_t temporary_storage_size_bytes; /// void * temporary_storage_ptr = nullptr; /// // Get required size of the temporary storage -/// rocprim::nth_element( +/// rocprim::partial_sort_copy( /// temporary_storage_ptr, temporary_storage_size_bytes, -/// keys_input, keys_output, nth, input_size +/// keys_input, keys_output, middle, input_size /// ); /// /// // allocate temporary storage /// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); /// -/// // perform nth_element -/// rocprim::nth_element( +/// // perform partial_sort +/// rocprim::partial_sort_copy( /// temporary_storage_ptr, temporary_storage_size_bytes, -/// keys_input, keys_output, nth, input_size +/// keys_input, keys_output, middle, input_size /// ); -/// // possible keys_output: [ 1, 3, 4, 2, 5, 8, 7, 6 ] +/// // possible keys_output: [ 1, 2, 3, 4, 5, 8, 7, 6 ] /// \endcode /// \endparblock template Date: Fri, 17 May 2024 11:21:00 +0000 Subject: [PATCH 05/14] Test partial_sort with iterator --- test/rocprim/test_device_partial_sort.cpp | 81 +++++++++++++---------- 1 file changed, 45 insertions(+), 36 deletions(-) diff --git a/test/rocprim/test_device_partial_sort.cpp b/test/rocprim/test_device_partial_sort.cpp index 50fa45159..eff54743e 100644 --- a/test/rocprim/test_device_partial_sort.cpp +++ b/test/rocprim/test_device_partial_sort.cpp @@ -21,6 +21,7 @@ // SOFTWARE. // required test headers +#include "indirect_iterator.hpp" #include "test_utils_assertions.hpp" #include "test_utils_custom_float_type.hpp" #include "test_utils_custom_test_types.hpp" @@ -44,46 +45,50 @@ // Params for tests template, - class Config = ::rocprim::default_config, - bool UseGraphs = false> + class CompareFunction = ::rocprim::less, + class Config = ::rocprim::default_config, + bool UseGraphs = false, + bool UseIndirectIterator = false> struct DevicePartialSortParams { - using key_type = KeyType; - using compare_function = CompareFunction; - using config = Config; - static constexpr bool use_graphs = UseGraphs; + using key_type = KeyType; + using compare_function = CompareFunction; + using config = Config; + static constexpr bool use_graphs = UseGraphs; + static constexpr bool use_indirect_iterator = UseIndirectIterator; }; template class RocprimDevicePartialSortTests : public ::testing::Test { public: - using key_type = typename Params::key_type; - using compare_function = typename Params::compare_function; - using config = typename Params::config; - const bool debug_synchronous = false; - bool use_graphs = Params::use_graphs; + using key_type = typename Params::key_type; + using compare_function = typename Params::compare_function; + using config = typename Params::config; + const bool debug_synchronous = false; + static constexpr bool use_graphs = Params::use_graphs; + static constexpr bool use_indirect_iterator = Params::use_indirect_iterator; }; // TODO add custom config // TODO no graph support -using RocprimDevicePartialSortTestsParams - = ::testing::Types, - DevicePartialSortParams, - DevicePartialSortParams, - DevicePartialSortParams>, - DevicePartialSortParams, - DevicePartialSortParams>, - DevicePartialSortParams, - DevicePartialSortParams, - DevicePartialSortParams, - DevicePartialSortParams, - DevicePartialSortParams, - DevicePartialSortParams, - DevicePartialSortParams>, - DevicePartialSortParams, - DevicePartialSortParams>>; +using RocprimDevicePartialSortTestsParams = ::testing::Types< + DevicePartialSortParams, + DevicePartialSortParams, + DevicePartialSortParams, + DevicePartialSortParams>, + DevicePartialSortParams, + DevicePartialSortParams>, + DevicePartialSortParams, + DevicePartialSortParams, + DevicePartialSortParams, + DevicePartialSortParams, + DevicePartialSortParams, + DevicePartialSortParams, + DevicePartialSortParams>, + DevicePartialSortParams, + DevicePartialSortParams>, + DevicePartialSortParams, rocprim::default_config, false, true>>; TYPED_TEST_SUITE(RocprimDevicePartialSortTests, RocprimDevicePartialSortTestsParams); @@ -93,10 +98,11 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); HIP_CHECK(hipSetDevice(device_id)); - using key_type = typename TestFixture::key_type; - using compare_function = typename TestFixture::compare_function; - using config = typename TestFixture::config; - const bool debug_synchronous = TestFixture::debug_synchronous; + using key_type = typename TestFixture::key_type; + using compare_function = typename TestFixture::compare_function; + using config = typename TestFixture::config; + const bool debug_synchronous = TestFixture::debug_synchronous; + static constexpr bool use_indirect_iterator = TestFixture::use_indirect_iterator; bool in_place = false; @@ -161,6 +167,9 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) test_common_utils::hipMallocHelper(&d_output, size * sizeof(key_type))); } + const auto input_it + = test_utils::wrap_in_indirect_iterator(d_input); + compare_function compare_op; // Allocate temporary storage @@ -169,7 +178,7 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) { HIP_CHECK(rocprim::partial_sort(nullptr, temp_storage_size_bytes, - d_input, + input_it, middle, size, compare_op, @@ -180,7 +189,7 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) { HIP_CHECK(rocprim::partial_sort_copy(nullptr, temp_storage_size_bytes, - d_input, + input_it, d_output, middle, size, @@ -203,7 +212,7 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) { HIP_CHECK(rocprim::partial_sort(d_temp_storage, temp_storage_size_bytes, - d_input, + input_it, middle, size, compare_op, @@ -214,7 +223,7 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) { HIP_CHECK(rocprim::partial_sort_copy(d_temp_storage, temp_storage_size_bytes, - d_input, + input_it, d_output, middle, size, From 8086fcecf05ef1e191984184ba5776ac218b88d7 Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Fri, 17 May 2024 11:31:04 +0000 Subject: [PATCH 06/14] Add partial_sort and partial_sort_copy to the changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e8b3179bb..429da3b3b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ Documentation for rocPRIM is available at * Added large segment support for `rocprim:segmented_reduce`. * Added a parallel `nth_element` device function similar to `std::nth_element`, this function rearranges elements smaller than the n-th before and bigger than the n-th after the n-th element. * Added deterministic (bitwise reproducible) algorithm variants `rocprim::deterministic_inclusive_scan`, `rocprim::deterministic_exclusive_scan`, `rocprim::deterministic_inclusive_scan_by_key`, `rocprim::deterministic_exclusive_scan_by_key`, and `rocprim::deterministic_reduce_by_key`. These provide run-to-run stable results with non-associative operators such as float operations, at the cost of reduced performance. +* Added a parallel `partial_sort` and `partial_sort_copy` device function similar to `std::partial_sort` and `std::partial_sort_copy`, these functions rearranges elements such that the elements are the same as a sorted list up to and including the middle index. ### Changes From 3eaccf25efca3e0b31d0f34f4a6d0823f4a01bda Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Fri, 17 May 2024 11:54:11 +0000 Subject: [PATCH 07/14] Moved partial sort to own file --- benchmark/benchmark_device_partial_sort.hpp | 2 +- .../rocprim/device/device_nth_element.hpp | 344 --------------- .../rocprim/device/device_partial_sort.hpp | 394 ++++++++++++++++++ rocprim/include/rocprim/rocprim.hpp | 1 + test/rocprim/test_device_partial_sort.cpp | 2 +- 5 files changed, 397 insertions(+), 346 deletions(-) create mode 100644 rocprim/include/rocprim/device/device_partial_sort.hpp diff --git a/benchmark/benchmark_device_partial_sort.hpp b/benchmark/benchmark_device_partial_sort.hpp index 560d2633d..d0dba07da 100644 --- a/benchmark/benchmark_device_partial_sort.hpp +++ b/benchmark/benchmark_device_partial_sort.hpp @@ -32,7 +32,7 @@ #include // rocPRIM -#include +#include #include #include diff --git a/rocprim/include/rocprim/device/device_nth_element.hpp b/rocprim/include/rocprim/device/device_nth_element.hpp index 85f93b98f..c1beaec65 100644 --- a/rocprim/include/rocprim/device/device_nth_element.hpp +++ b/rocprim/include/rocprim/device/device_nth_element.hpp @@ -28,7 +28,6 @@ #include "../config.hpp" #include "config_types.hpp" -#include "device_merge_sort.hpp" #include "device_nth_element_config.hpp" #include "device_transform.hpp" @@ -336,349 +335,6 @@ ROCPRIM_INLINE hipError_t nth_element(void* temporary_storage, debug_synchronous); } -namespace detail -{ - -template -hipError_t partial_sort_impl(void* temporary_storage, - size_t& storage_size, - KeysIterator keys, - size_t middle, - size_t size, - BinaryFunction compare_function, - hipStream_t stream, - bool debug_synchronous) -{ - using key_type = typename std::iterator_traits::value_type; - using config = wrapped_nth_element_config; - - target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const nth_element_config_params params = dispatch_target_arch(target_arch); - - constexpr unsigned int num_partitions = 3; - const unsigned int num_buckets = params.number_of_buckets; - const unsigned int num_splitters = num_buckets - 1; - const unsigned int stop_recursion_size = num_buckets; - const unsigned int num_items_per_threads = params.kernel_config.items_per_thread; - const unsigned int num_threads_per_block = params.kernel_config.block_size; - const unsigned int num_items_per_block = num_threads_per_block * num_items_per_threads; - const unsigned int num_blocks = ceiling_div(size, num_items_per_block); - - size_t storage_size_merge_sort{}; - // non-null placeholder so that no buffer is allocated for keys - key_type* keys_buffer_placeholder = reinterpret_cast(1); - - result = merge_sort_impl(nullptr, - storage_size_merge_sort, - keys, - keys, - static_cast(nullptr), // values_input - static_cast(nullptr), // values_output - middle, - compare_function, - stream, - debug_synchronous, - keys_buffer_placeholder, // keys_buffer - static_cast(nullptr)); // values_buffer - if(result != hipSuccess) - { - return result; - } - - key_type* tree = nullptr; - size_t* buckets = nullptr; - n_th_element_iteration_data* nth_element_data = nullptr; - uint8_t* oracles = nullptr; - bool* equality_buckets = nullptr; - nth_element_onesweep_lookback_state* lookback_states = nullptr; - key_type* keys_buffer = nullptr; - void* temporary_storage_merge_sort = nullptr; - - const hipError_t partition_result = temp_storage::partition( - temporary_storage, - storage_size, - temp_storage::make_linear_partition( - temp_storage::ptr_aligned_array(&tree, num_splitters), - temp_storage::ptr_aligned_array(&equality_buckets, num_buckets), - temp_storage::ptr_aligned_array(&buckets, num_buckets), - temp_storage::ptr_aligned_array(&oracles, size), - temp_storage::ptr_aligned_array(&keys_buffer, size), - temp_storage::ptr_aligned_array(&nth_element_data, 1), - temp_storage::ptr_aligned_array(&lookback_states, num_partitions * num_blocks), - temp_storage::make_partition(&temporary_storage_merge_sort, storage_size_merge_sort))); - - if(partition_result != hipSuccess || temporary_storage == nullptr) - { - return partition_result; - } - - if(size == 0) - { - return hipSuccess; - } - - if(middle > size) - { - return hipErrorInvalidValue; - } - - if(debug_synchronous) - { - std::cout << "-----" << '\n'; - std::cout << "size: " << size << '\n'; - std::cout << "num_buckets: " << num_buckets << '\n'; - std::cout << "num_threads_per_block: " << num_threads_per_block << '\n'; - std::cout << "num_blocks: " << num_blocks << '\n'; - std::cout << "storage_size: " << storage_size << '\n'; - } - - result = nth_element_keys_impl(keys, - keys_buffer, - tree, - middle, - size, - buckets, - equality_buckets, - oracles, - lookback_states, - num_buckets, - stop_recursion_size, - num_threads_per_block, - num_items_per_threads, - nth_element_data, - compare_function, - stream, - debug_synchronous); - if(result != hipSuccess) - { - return result; - } - - return merge_sort_impl(temporary_storage_merge_sort, - storage_size_merge_sort, - keys, - keys, - static_cast(nullptr), // values_input - static_cast(nullptr), // values_output - middle, - compare_function, - stream, - debug_synchronous, - keys_buffer, // keys_buffer - static_cast(nullptr)); // values_buffer -} - -} // namespace detail - -/// \brief Rearranges elements such that the range [0, middle) contains the sorted middle smallest elements in the range [0, size). -/// -/// \par Overview -/// * The contents of the inputs are not altered by the function. -/// * Returns the required size of `temporary_storage` in `storage_size` -/// if `temporary_storage` is a null pointer. -/// * Accepts custom compare_functions for nth_element across the device. -/// * Streams in graph capture mode are not supported -/// -/// \tparam Config [optional] configuration of the primitive. It has to be `nth_element_config`. -/// \tparam KeysInputIterator [inferred] random-access iterator type of the input range. Must meet the -/// requirements of a C++ InputIterator concept. It can be a simple pointer type. -/// \tparam KeysOutputIterator [inferred] random-access iterator type of the output range. Must meet the -/// requirements of a C++ InputIterator concept. It can be a simple pointer type. -/// \tparam CompareFunction [inferred] Type of binary function that accepts two arguments of the -/// type `KeysIterator` and returns a value convertible to bool. Default type is `::rocprim::less<>.` -/// -/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When -/// a null pointer is passed, the required allocation size (in bytes) is written to -/// `storage_size` and function returns without performing the nth_element rearrangement. -/// \param [in,out] storage_size reference to a size (in bytes) of `temporary_storage`. -/// \param [in] keys_input iterator to the input range. -/// \param [out] keys_output iterator to the output range. No overlap at all is allowed between `keys_input` and `keys_output`. -/// `keys_output` should be able to be written and read from for `size` elements. -/// \param [in] middle The index of the point till where it is sorted in the input range. -/// \param [in] size number of element in the input range. -/// \param [in] compare_function binary operation function object that will be used for comparison. -/// The signature of the function should be equivalent to the following: -/// bool f(const T &a, const T &b);. The signature does not need to have -/// const &, but function object must not modify the objects passed to it. -/// The comperator must meet the C++ named requirement Compare. -/// The default value is `BinaryFunction()`. -/// \param [in] stream [optional] HIP stream object. Default is `0` (default stream). -/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel -/// launch is forced in order to check for errors. Default value is `false`. -/// -/// \returns `hipSuccess` (`0`) after successful rearrangement; otherwise a HIP runtime error of -/// type `hipError_t`. -/// -/// \par Example -/// \parblock -/// In this example a device-level nth_element is performed where input keys are -/// represented by an array of unsigned integers. -/// -/// \code{.cpp} -/// #include -/// -/// // Prepare input and output (declare pointers, allocate device memory etc.) -/// size_t input_size; // e.g., 8 -/// size_t middle; // e.g., 4 -/// unsigned int * keys_input; // e.g., [ 6, 3, 5, 4, 1, 8, 2, 7 ] -/// unsigned int * keys_output; // empty array of 8 elements -/// -/// size_t temporary_storage_size_bytes; -/// void * temporary_storage_ptr = nullptr; -/// // Get required size of the temporary storage -/// rocprim::partial_sort_copy( -/// temporary_storage_ptr, temporary_storage_size_bytes, -/// keys_input, keys_output, middle, input_size -/// ); -/// -/// // allocate temporary storage -/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); -/// -/// // perform partial_sort -/// rocprim::partial_sort_copy( -/// temporary_storage_ptr, temporary_storage_size_bytes, -/// keys_input, keys_output, middle, input_size -/// ); -/// // possible keys_output: [ 1, 2, 3, 4, 5, 8, 7, 6 ] -/// \endcode -/// \endparblock -template::value_type>> -hipError_t partial_sort_copy(void* temporary_storage, - size_t& storage_size, - KeysInputIterator keys_input, - KeysOutputIterator keys_output, - size_t middle, - size_t size, - BinaryFunction compare_function = BinaryFunction(), - hipStream_t stream = 0, - bool debug_synchronous = false) -{ - using key_type = typename std::iterator_traits::value_type; - static_assert( - std::is_same::value_type>::value, - "KeysInputIterator and KeysOutputIterator must have the same value_type"); - - hipError_t error = transform(keys_input, - keys_output, - size, - ::rocprim::identity(), - stream, - debug_synchronous); - if(error != hipSuccess) - { - return error; - } - - return detail::partial_sort_impl(temporary_storage, - storage_size, - keys_output, - middle, - size, - compare_function, - stream, - debug_synchronous); -} - -/// \brief Rearranges elements such that the range [0, middle) contains the sorted middle smallest elements in the range [0, size). -/// -/// \par Overview -/// * The contents of the inputs are not altered by the function. -/// * Returns the required size of `temporary_storage` in `storage_size` -/// if `temporary_storage` is a null pointer. -/// * Accepts custom compare_functions for nth_element across the device. -/// * Streams in graph capture mode are not supported -/// -/// \tparam Config [optional] configuration of the primitive. It has to be `nth_element_config`. -/// \tparam KeysIterator [inferred] random-access iterator type of the input range. Must meet the -/// requirements of a C++ InputIterator concept. It can be a simple pointer type. -/// \tparam CompareFunction [inferred] Type of binary function that accepts two arguments of the -/// type `KeysIterator` and returns a value convertible to bool. Default type is `::rocprim::less<>.` -/// -/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When -/// a null pointer is passed, the required allocation size (in bytes) is written to -/// `storage_size` and function returns without performing the nth_element rearrangement. -/// \param [in,out] storage_size reference to a size (in bytes) of `temporary_storage`. -/// \param [in,out] keys iterator to the input range. -/// \param [in] middle The index of the point till where it is sorted in the input range. -/// \param [in] size number of element in the input range. -/// \param [in] compare_function binary operation function object that will be used for comparison. -/// The signature of the function should be equivalent to the following: -/// bool f(const T &a, const T &b);. The signature does not need to have -/// const &, but function object must not modify the objects passed to it. -/// The comperator must meet the C++ named requirement Compare. -/// The default value is `BinaryFunction()`. -/// \param [in] stream [optional] HIP stream object. Default is `0` (default stream). -/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel -/// launch is forced in order to check for errors. Default value is `false`. -/// -/// \returns `hipSuccess` (`0`) after successful rearrangement; otherwise a HIP runtime error of -/// type `hipError_t`. -/// -/// \par Example -/// \parblock -/// In this example a device-level nth_element is performed where input keys are -/// represented by an array of unsigned integers. -/// -/// \code{.cpp} -/// #include -/// -/// // Prepare input and output (declare pointers, allocate device memory etc.) -/// size_t input_size; // e.g., 8 -/// size_t middle; // e.g., 4 -/// unsigned int * keys; // e.g., [ 6, 3, 5, 4, 1, 8, 2, 7 ] -/// -/// size_t temporary_storage_size_bytes; -/// void * temporary_storage_ptr = nullptr; -/// // Get required size of the temporary storage -/// rocprim::partial_sort( -/// temporary_storage_ptr, temporary_storage_size_bytes, -/// keys, nth, input_size -/// ); -/// -/// // allocate temporary storage -/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); -/// -/// // perform partial_sort -/// rocprim::partial_sort( -/// temporary_storage_ptr, temporary_storage_size_bytes, -/// keys, nth, input_size -/// ); -/// // possible keys: [ 1, 2, 3, 4, 5, 8, 7, 6 ] -/// \endcode -/// \endparblock -template::value_type>> -hipError_t partial_sort(void* temporary_storage, - size_t& storage_size, - KeysIterator keys, - size_t middle, - size_t size, - BinaryFunction compare_function = BinaryFunction(), - hipStream_t stream = 0, - bool debug_synchronous = false) -{ - return detail::partial_sort_impl(temporary_storage, - storage_size, - keys, - middle, - size, - compare_function, - stream, - debug_synchronous); -} - /// @} // end of group devicemodule diff --git a/rocprim/include/rocprim/device/device_partial_sort.hpp b/rocprim/include/rocprim/device/device_partial_sort.hpp new file mode 100644 index 000000000..1ecf2ac1e --- /dev/null +++ b/rocprim/include/rocprim/device/device_partial_sort.hpp @@ -0,0 +1,394 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DEVICE_PARTIAL_SORT_HPP_ +#define ROCPRIM_DEVICE_DEVICE_PARTIAL_SORT_HPP_ + +#include "detail/device_nth_element.hpp" + +#include "../detail/temp_storage.hpp" + +#include "../config.hpp" + +#include "config_types.hpp" +#include "device_merge_sort.hpp" +#include "device_nth_element_config.hpp" +#include "device_transform.hpp" + +#include +#include + +#include +#include + +BEGIN_ROCPRIM_NAMESPACE + +/// \addtogroup devicemodule +/// @{ + +namespace detail +{ + +template +hipError_t partial_sort_impl(void* temporary_storage, + size_t& storage_size, + KeysIterator keys, + size_t middle, + size_t size, + BinaryFunction compare_function, + hipStream_t stream, + bool debug_synchronous) +{ + using key_type = typename std::iterator_traits::value_type; + using config = wrapped_nth_element_config; + + target_arch target_arch; + hipError_t result = host_target_arch(stream, target_arch); + if(result != hipSuccess) + { + return result; + } + const nth_element_config_params params = dispatch_target_arch(target_arch); + + constexpr unsigned int num_partitions = 3; + const unsigned int num_buckets = params.number_of_buckets; + const unsigned int num_splitters = num_buckets - 1; + const unsigned int stop_recursion_size = params.stop_recursion_size; + const unsigned int num_items_per_threads = params.kernel_config.items_per_thread; + const unsigned int num_threads_per_block = params.kernel_config.block_size; + const unsigned int num_items_per_block = num_threads_per_block * num_items_per_threads; + const unsigned int num_blocks = ceiling_div(size, num_items_per_block); + + size_t storage_size_merge_sort{}; + // non-null placeholder so that no buffer is allocated for keys + key_type* keys_buffer_placeholder = reinterpret_cast(1); + + result = merge_sort_impl(nullptr, + storage_size_merge_sort, + keys, + keys, + static_cast(nullptr), // values_input + static_cast(nullptr), // values_output + middle, + compare_function, + stream, + debug_synchronous, + keys_buffer_placeholder, // keys_buffer + static_cast(nullptr)); // values_buffer + if(result != hipSuccess) + { + return result; + } + + key_type* tree = nullptr; + size_t* buckets = nullptr; + n_th_element_iteration_data* nth_element_data = nullptr; + uint8_t* oracles = nullptr; + bool* equality_buckets = nullptr; + nth_element_onesweep_lookback_state* lookback_states = nullptr; + key_type* keys_buffer = nullptr; + void* temporary_storage_merge_sort = nullptr; + + const hipError_t partition_result = temp_storage::partition( + temporary_storage, + storage_size, + temp_storage::make_linear_partition( + temp_storage::ptr_aligned_array(&tree, num_splitters), + temp_storage::ptr_aligned_array(&equality_buckets, num_buckets), + temp_storage::ptr_aligned_array(&buckets, num_buckets), + temp_storage::ptr_aligned_array(&oracles, size), + temp_storage::ptr_aligned_array(&keys_buffer, size), + temp_storage::ptr_aligned_array(&nth_element_data, 1), + temp_storage::ptr_aligned_array(&lookback_states, num_partitions * num_blocks), + temp_storage::make_partition(&temporary_storage_merge_sort, storage_size_merge_sort))); + + if(partition_result != hipSuccess || temporary_storage == nullptr) + { + return partition_result; + } + + if(size == 0) + { + return hipSuccess; + } + + if(middle > size) + { + return hipErrorInvalidValue; + } + + if(debug_synchronous) + { + std::cout << "-----" << '\n'; + std::cout << "size: " << size << '\n'; + std::cout << "num_buckets: " << num_buckets << '\n'; + std::cout << "num_threads_per_block: " << num_threads_per_block << '\n'; + std::cout << "num_blocks: " << num_blocks << '\n'; + std::cout << "storage_size: " << storage_size << '\n'; + } + + result = nth_element_keys_impl(keys, + keys_buffer, + tree, + middle, + size, + buckets, + equality_buckets, + oracles, + lookback_states, + num_buckets, + stop_recursion_size, + num_threads_per_block, + num_items_per_threads, + nth_element_data, + compare_function, + stream, + debug_synchronous); + if(result != hipSuccess) + { + return result; + } + + return merge_sort_impl(temporary_storage_merge_sort, + storage_size_merge_sort, + keys, + keys, + static_cast(nullptr), // values_input + static_cast(nullptr), // values_output + middle, + compare_function, + stream, + debug_synchronous, + keys_buffer, // keys_buffer + static_cast(nullptr)); // values_buffer +} + +} // namespace detail + +/// \brief Rearranges elements such that the range [0, middle) contains the sorted middle smallest elements in the range [0, size). +/// +/// \par Overview +/// * The contents of the inputs are not altered by the function. +/// * Returns the required size of `temporary_storage` in `storage_size` +/// if `temporary_storage` is a null pointer. +/// * Accepts custom compare_functions for nth_element across the device. +/// * Streams in graph capture mode are not supported +/// +/// \tparam Config [optional] configuration of the primitive. It has to be `nth_element_config`. +/// \tparam KeysInputIterator [inferred] random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam KeysOutputIterator [inferred] random-access iterator type of the output range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam CompareFunction [inferred] Type of binary function that accepts two arguments of the +/// type `KeysIterator` and returns a value convertible to bool. Default type is `::rocprim::less<>.` +/// +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// `storage_size` and function returns without performing the nth_element rearrangement. +/// \param [in,out] storage_size reference to a size (in bytes) of `temporary_storage`. +/// \param [in] keys_input iterator to the input range. +/// \param [out] keys_output iterator to the output range. No overlap at all is allowed between `keys_input` and `keys_output`. +/// `keys_output` should be able to be written and read from for `size` elements. +/// \param [in] middle The index of the point till where it is sorted in the input range. +/// \param [in] size number of element in the input range. +/// \param [in] compare_function binary operation function object that will be used for comparison. +/// The signature of the function should be equivalent to the following: +/// bool f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// The comperator must meet the C++ named requirement Compare. +/// The default value is `BinaryFunction()`. +/// \param [in] stream [optional] HIP stream object. Default is `0` (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is `false`. +/// +/// \returns `hipSuccess` (`0`) after successful rearrangement; otherwise a HIP runtime error of +/// type `hipError_t`. +/// +/// \par Example +/// \parblock +/// In this example a device-level nth_element is performed where input keys are +/// represented by an array of unsigned integers. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// size_t middle; // e.g., 4 +/// unsigned int * keys_input; // e.g., [ 6, 3, 5, 4, 1, 8, 2, 7 ] +/// unsigned int * keys_output; // empty array of 8 elements +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::partial_sort_copy( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, middle, input_size +/// ); +/// +/// // allocate temporary storage +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform partial_sort +/// rocprim::partial_sort_copy( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, middle, input_size +/// ); +/// // possible keys_output: [ 1, 2, 3, 4, 5, 8, 7, 6 ] +/// \endcode +/// \endparblock +template::value_type>> +hipError_t partial_sort_copy(void* temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + size_t middle, + size_t size, + BinaryFunction compare_function = BinaryFunction(), + hipStream_t stream = 0, + bool debug_synchronous = false) +{ + using key_type = typename std::iterator_traits::value_type; + static_assert( + std::is_same::value_type>::value, + "KeysInputIterator and KeysOutputIterator must have the same value_type"); + + hipError_t error = transform(keys_input, + keys_output, + size, + ::rocprim::identity(), + stream, + debug_synchronous); + if(error != hipSuccess) + { + return error; + } + + return detail::partial_sort_impl(temporary_storage, + storage_size, + keys_output, + middle, + size, + compare_function, + stream, + debug_synchronous); +} + +/// \brief Rearranges elements such that the range [0, middle) contains the sorted middle smallest elements in the range [0, size). +/// +/// \par Overview +/// * The contents of the inputs are not altered by the function. +/// * Returns the required size of `temporary_storage` in `storage_size` +/// if `temporary_storage` is a null pointer. +/// * Accepts custom compare_functions for nth_element across the device. +/// * Streams in graph capture mode are not supported +/// +/// \tparam Config [optional] configuration of the primitive. It has to be `nth_element_config`. +/// \tparam KeysIterator [inferred] random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam CompareFunction [inferred] Type of binary function that accepts two arguments of the +/// type `KeysIterator` and returns a value convertible to bool. Default type is `::rocprim::less<>.` +/// +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// `storage_size` and function returns without performing the nth_element rearrangement. +/// \param [in,out] storage_size reference to a size (in bytes) of `temporary_storage`. +/// \param [in,out] keys iterator to the input range. +/// \param [in] middle The index of the point till where it is sorted in the input range. +/// \param [in] size number of element in the input range. +/// \param [in] compare_function binary operation function object that will be used for comparison. +/// The signature of the function should be equivalent to the following: +/// bool f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// The comperator must meet the C++ named requirement Compare. +/// The default value is `BinaryFunction()`. +/// \param [in] stream [optional] HIP stream object. Default is `0` (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is `false`. +/// +/// \returns `hipSuccess` (`0`) after successful rearrangement; otherwise a HIP runtime error of +/// type `hipError_t`. +/// +/// \par Example +/// \parblock +/// In this example a device-level nth_element is performed where input keys are +/// represented by an array of unsigned integers. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// size_t middle; // e.g., 4 +/// unsigned int * keys; // e.g., [ 6, 3, 5, 4, 1, 8, 2, 7 ] +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::partial_sort( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, nth, input_size +/// ); +/// +/// // allocate temporary storage +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform partial_sort +/// rocprim::partial_sort( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, nth, input_size +/// ); +/// // possible keys: [ 1, 2, 3, 4, 5, 8, 7, 6 ] +/// \endcode +/// \endparblock +template::value_type>> +hipError_t partial_sort(void* temporary_storage, + size_t& storage_size, + KeysIterator keys, + size_t middle, + size_t size, + BinaryFunction compare_function = BinaryFunction(), + hipStream_t stream = 0, + bool debug_synchronous = false) +{ + return detail::partial_sort_impl(temporary_storage, + storage_size, + keys, + middle, + size, + compare_function, + stream, + debug_synchronous); +} + +/// @} +// end of group devicemodule + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DEVICE_PARTIAL_SORT_HPP_ diff --git a/rocprim/include/rocprim/rocprim.hpp b/rocprim/include/rocprim/rocprim.hpp index 1bfb027b0..8ba3f4714 100644 --- a/rocprim/include/rocprim/rocprim.hpp +++ b/rocprim/include/rocprim/rocprim.hpp @@ -66,6 +66,7 @@ #include "device/device_merge.hpp" #include "device/device_merge_sort.hpp" #include "device/device_nth_element.hpp" +#include "device/device_partial_sort.hpp" #include "device/device_partition.hpp" #include "device/device_radix_sort.hpp" #include "device/device_reduce.hpp" diff --git a/test/rocprim/test_device_partial_sort.cpp b/test/rocprim/test_device_partial_sort.cpp index eff54743e..b1107ea47 100644 --- a/test/rocprim/test_device_partial_sort.cpp +++ b/test/rocprim/test_device_partial_sort.cpp @@ -33,7 +33,7 @@ // required rocprim headers #include #include -#include +#include #include #include From 88574c5cab35ba30bc41e1c301a15d20c9efcad0 Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Tue, 21 May 2024 13:08:05 +0000 Subject: [PATCH 08/14] Added partial_sort_config --- docs/device_ops/partial_sort.rst | 2 +- .../rocprim/device/device_nth_element.hpp | 220 +++++++++++------- .../rocprim/device/device_partial_sort.hpp | 148 +++++------- .../device/device_partial_sort_config.hpp | 61 +++++ test/rocprim/test_device_partial_sort.cpp | 8 +- 5 files changed, 267 insertions(+), 172 deletions(-) create mode 100644 rocprim/include/rocprim/device/device_partial_sort_config.hpp diff --git a/docs/device_ops/partial_sort.rst b/docs/device_ops/partial_sort.rst index 6b5632a2e..222ed25f1 100644 --- a/docs/device_ops/partial_sort.rst +++ b/docs/device_ops/partial_sort.rst @@ -11,7 +11,7 @@ Partial Sort Configuring the kernel ~~~~~~~~~~~~~~~~~~~~~~ -.. doxygenstruct:: rocprim::nth_element_config +.. doxygenstruct:: rocprim::partial_sort_config partial_sort ~~~~~~~~~~~~ diff --git a/rocprim/include/rocprim/device/device_nth_element.hpp b/rocprim/include/rocprim/device/device_nth_element.hpp index c1beaec65..550ab941a 100644 --- a/rocprim/include/rocprim/device/device_nth_element.hpp +++ b/rocprim/include/rocprim/device/device_nth_element.hpp @@ -39,6 +39,132 @@ BEGIN_ROCPRIM_NAMESPACE +namespace detail +{ + +template +ROCPRIM_INLINE hipError_t nth_element_impl( + void* temporary_storage, + size_t& storage_size, + KeysIterator keys, + size_t nth, + size_t size, + BinaryFunction compare_function, + hipStream_t stream, + bool debug_synchronous, + typename std::iterator_traits::value_type* keys_double_buffer = nullptr) +{ + using key_type = typename std::iterator_traits::value_type; + using config = wrapped_nth_element_config; + + detail::target_arch target_arch; + hipError_t result = host_target_arch(stream, target_arch); + if(result != hipSuccess) + { + return result; + } + const detail::nth_element_config_params params + = detail::dispatch_target_arch(target_arch); + + constexpr unsigned int num_partitions = 3; + const unsigned int num_buckets = params.number_of_buckets; + const unsigned int num_splitters = num_buckets - 1; + const unsigned int stop_recursion_size = params.stop_recursion_size; + const unsigned int num_items_per_threads = params.kernel_config.items_per_thread; + const unsigned int num_threads_per_block = params.kernel_config.block_size; + const unsigned int num_items_per_block = num_threads_per_block * num_items_per_threads; + const unsigned int num_blocks = detail::ceiling_div(size, num_items_per_block); + + // oracles stores the bucket that correlates with the index + uint8_t* oracles = nullptr; + key_type* tree = nullptr; + size_t* buckets = nullptr; + detail::n_th_element_iteration_data* nth_element_data = nullptr; + bool* equality_buckets = nullptr; + detail::nth_element_onesweep_lookback_state* lookback_states = nullptr; + + key_type* keys_buffer = nullptr; + + { + using namespace detail::temp_storage; + + hipError_t partition_result; + if(keys_double_buffer == nullptr) + { + partition_result + = partition(temporary_storage, + storage_size, + make_linear_partition( + ptr_aligned_array(&tree, num_splitters), + ptr_aligned_array(&equality_buckets, num_buckets), + ptr_aligned_array(&buckets, num_buckets), + ptr_aligned_array(&oracles, size), + ptr_aligned_array(&keys_buffer, size), + ptr_aligned_array(&nth_element_data, 1), + ptr_aligned_array(&lookback_states, num_partitions * num_blocks))); + } + else + { + partition_result + = partition(temporary_storage, + storage_size, + make_linear_partition( + ptr_aligned_array(&tree, num_splitters), + ptr_aligned_array(&equality_buckets, num_buckets), + ptr_aligned_array(&buckets, num_buckets), + ptr_aligned_array(&oracles, size), + ptr_aligned_array(&nth_element_data, 1), + ptr_aligned_array(&lookback_states, num_partitions * num_blocks))); + keys_buffer = keys_double_buffer; + } + + if(partition_result != hipSuccess || temporary_storage == nullptr) + { + return partition_result; + } + } + + if((size == 0) || (size == 1 && nth == 0)) + { + return hipSuccess; + } + + if(nth >= size) + { + return hipErrorInvalidValue; + } + + if(debug_synchronous) + { + std::cout << "-----" << '\n'; + std::cout << "size: " << size << '\n'; + std::cout << "num_buckets: " << num_buckets << '\n'; + std::cout << "num_threads_per_block: " << num_threads_per_block << '\n'; + std::cout << "num_blocks: " << num_blocks << '\n'; + std::cout << "storage_size: " << storage_size << '\n'; + } + + return detail::nth_element_keys_impl(keys, + keys_buffer, + tree, + nth, + size, + buckets, + equality_buckets, + oracles, + lookback_states, + num_buckets, + stop_recursion_size, + num_threads_per_block, + num_items_per_threads, + nth_element_data, + compare_function, + stream, + debug_synchronous); +} + +} // namespace detail + /// \addtogroup devicemodule /// @{ @@ -127,91 +253,15 @@ ROCPRIM_INLINE hipError_t nth_element(void* temporary_storage, hipStream_t stream = 0, bool debug_synchronous = false) { - using key_type = typename std::iterator_traits::value_type; - using config = detail::wrapped_nth_element_config; - - detail::target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const detail::nth_element_config_params params - = detail::dispatch_target_arch(target_arch); - - constexpr unsigned int num_partitions = 3; - const unsigned int num_buckets = params.number_of_buckets; - const unsigned int num_splitters = num_buckets - 1; - const unsigned int stop_recursion_size = params.stop_recursion_size; - const unsigned int num_items_per_threads = params.kernel_config.items_per_thread; - const unsigned int num_threads_per_block = params.kernel_config.block_size; - const unsigned int num_items_per_block = num_threads_per_block * num_items_per_threads; - const unsigned int num_blocks = detail::ceiling_div(size, num_items_per_block); - - key_type* tree = nullptr; - size_t* buckets = nullptr; - detail::n_th_element_iteration_data* nth_element_data = nullptr; - bool* equality_buckets = nullptr; - detail::nth_element_onesweep_lookback_state* lookback_states = nullptr; - - key_type* keys_buffer = nullptr; - - { - using namespace detail::temp_storage; - - const hipError_t partition_result - = partition(temporary_storage, - storage_size, - make_linear_partition( - ptr_aligned_array(&tree, num_splitters), - ptr_aligned_array(&equality_buckets, num_buckets), - ptr_aligned_array(&buckets, num_buckets), - ptr_aligned_array(&keys_buffer, size), - ptr_aligned_array(&nth_element_data, 1), - ptr_aligned_array(&lookback_states, num_partitions * num_blocks))); - - if(partition_result != hipSuccess || temporary_storage == nullptr) - { - return partition_result; - } - } - - if((size == 0) || (size == 1 && nth == 0)) - { - return hipSuccess; - } - - if(nth >= size) - { - return hipErrorInvalidValue; - } - - if(debug_synchronous) - { - std::cout << "-----" << '\n'; - std::cout << "size: " << size << '\n'; - std::cout << "num_buckets: " << num_buckets << '\n'; - std::cout << "num_threads_per_block: " << num_threads_per_block << '\n'; - std::cout << "num_blocks: " << num_blocks << '\n'; - std::cout << "storage_size: " << storage_size << '\n'; - } - - return detail::nth_element_keys_impl(keys, - keys_buffer, - tree, - nth, - size, - buckets, - equality_buckets, - lookback_states, - num_buckets, - stop_recursion_size, - num_threads_per_block, - num_items_per_threads, - nth_element_data, - compare_function, - stream, - debug_synchronous); + return detail::nth_element_impl(temporary_storage, + storage_size, + keys, + nth, + size, + compare_function, + stream, + debug_synchronous, + nullptr); } /// \brief Rearrange elements smaller than the n-th before and bigger than n-th after the n-th element. diff --git a/rocprim/include/rocprim/device/device_partial_sort.hpp b/rocprim/include/rocprim/device/device_partial_sort.hpp index 1ecf2ac1e..a2d025e61 100644 --- a/rocprim/include/rocprim/device/device_partial_sort.hpp +++ b/rocprim/include/rocprim/device/device_partial_sort.hpp @@ -29,7 +29,8 @@ #include "config_types.hpp" #include "device_merge_sort.hpp" -#include "device_nth_element_config.hpp" +#include "device_nth_element.hpp" +#include "device_partial_sort_config.hpp" #include "device_transform.hpp" #include @@ -57,66 +58,58 @@ hipError_t partial_sort_impl(void* temporary_storage, bool debug_synchronous) { using key_type = typename std::iterator_traits::value_type; - using config = wrapped_nth_element_config; + using config + = detail::default_or_custom_config>; + using config_merge_sort = typename config::merge_sort; + using config_nth_element = typename config::nth_element; - target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); + size_t storage_size_nth_element{}; + // non-null placeholder so that no buffer is allocated for keys + key_type* keys_buffer_placeholder = reinterpret_cast(1); + + hipError_t result = nth_element_impl(nullptr, + storage_size_nth_element, + keys, + middle, + size, + compare_function, + stream, + debug_synchronous, + keys_buffer_placeholder); if(result != hipSuccess) { return result; } - const nth_element_config_params params = dispatch_target_arch(target_arch); - - constexpr unsigned int num_partitions = 3; - const unsigned int num_buckets = params.number_of_buckets; - const unsigned int num_splitters = num_buckets - 1; - const unsigned int stop_recursion_size = params.stop_recursion_size; - const unsigned int num_items_per_threads = params.kernel_config.items_per_thread; - const unsigned int num_threads_per_block = params.kernel_config.block_size; - const unsigned int num_items_per_block = num_threads_per_block * num_items_per_threads; - const unsigned int num_blocks = ceiling_div(size, num_items_per_block); size_t storage_size_merge_sort{}; - // non-null placeholder so that no buffer is allocated for keys - key_type* keys_buffer_placeholder = reinterpret_cast(1); - result = merge_sort_impl(nullptr, - storage_size_merge_sort, - keys, - keys, - static_cast(nullptr), // values_input - static_cast(nullptr), // values_output - middle, - compare_function, - stream, - debug_synchronous, - keys_buffer_placeholder, // keys_buffer - static_cast(nullptr)); // values_buffer + result = merge_sort_impl(nullptr, + storage_size_merge_sort, + keys, + keys, + static_cast(nullptr), // values_input + static_cast(nullptr), // values_output + middle, + compare_function, + stream, + debug_synchronous, + keys_buffer_placeholder, // keys_buffer + static_cast(nullptr)); // values_buffer if(result != hipSuccess) { return result; } - key_type* tree = nullptr; - size_t* buckets = nullptr; - n_th_element_iteration_data* nth_element_data = nullptr; - uint8_t* oracles = nullptr; - bool* equality_buckets = nullptr; - nth_element_onesweep_lookback_state* lookback_states = nullptr; - key_type* keys_buffer = nullptr; - void* temporary_storage_merge_sort = nullptr; + void* temporary_storage_nth_element = nullptr; + void* temporary_storage_merge_sort = nullptr; + key_type* keys_buffer = nullptr; const hipError_t partition_result = temp_storage::partition( temporary_storage, storage_size, temp_storage::make_linear_partition( - temp_storage::ptr_aligned_array(&tree, num_splitters), - temp_storage::ptr_aligned_array(&equality_buckets, num_buckets), - temp_storage::ptr_aligned_array(&buckets, num_buckets), - temp_storage::ptr_aligned_array(&oracles, size), temp_storage::ptr_aligned_array(&keys_buffer, size), - temp_storage::ptr_aligned_array(&nth_element_data, 1), - temp_storage::ptr_aligned_array(&lookback_states, num_partitions * num_blocks), + temp_storage::make_partition(&temporary_storage_nth_element, storage_size_nth_element), temp_storage::make_partition(&temporary_storage_merge_sort, storage_size_merge_sort))); if(partition_result != hipSuccess || temporary_storage == nullptr) @@ -134,50 +127,35 @@ hipError_t partial_sort_impl(void* temporary_storage, return hipErrorInvalidValue; } - if(debug_synchronous) - { - std::cout << "-----" << '\n'; - std::cout << "size: " << size << '\n'; - std::cout << "num_buckets: " << num_buckets << '\n'; - std::cout << "num_threads_per_block: " << num_threads_per_block << '\n'; - std::cout << "num_blocks: " << num_blocks << '\n'; - std::cout << "storage_size: " << storage_size << '\n'; - } - - result = nth_element_keys_impl(keys, - keys_buffer, - tree, - middle, - size, - buckets, - equality_buckets, - oracles, - lookback_states, - num_buckets, - stop_recursion_size, - num_threads_per_block, - num_items_per_threads, - nth_element_data, - compare_function, - stream, - debug_synchronous); - if(result != hipSuccess) + if(middle < size) { - return result; + result = nth_element_impl(temporary_storage_nth_element, + storage_size_nth_element, + keys, + middle, + size, + compare_function, + stream, + debug_synchronous, + keys_buffer); + if(result != hipSuccess) + { + return result; + } } - return merge_sort_impl(temporary_storage_merge_sort, - storage_size_merge_sort, - keys, - keys, - static_cast(nullptr), // values_input - static_cast(nullptr), // values_output - middle, - compare_function, - stream, - debug_synchronous, - keys_buffer, // keys_buffer - static_cast(nullptr)); // values_buffer + return merge_sort_impl(temporary_storage_merge_sort, + storage_size_merge_sort, + keys, + keys, + static_cast(nullptr), // values_input + static_cast(nullptr), // values_output + middle, + compare_function, + stream, + debug_synchronous, + keys_buffer, // keys_buffer + static_cast(nullptr)); // values_buffer } } // namespace detail @@ -191,7 +169,7 @@ hipError_t partial_sort_impl(void* temporary_storage, /// * Accepts custom compare_functions for nth_element across the device. /// * Streams in graph capture mode are not supported /// -/// \tparam Config [optional] configuration of the primitive. It has to be `nth_element_config`. +/// \tparam Config [optional] configuration of the primitive. It has to be `partial_sort_config`. /// \tparam KeysInputIterator [inferred] random-access iterator type of the input range. Must meet the /// requirements of a C++ InputIterator concept. It can be a simple pointer type. /// \tparam KeysOutputIterator [inferred] random-access iterator type of the output range. Must meet the @@ -305,7 +283,7 @@ hipError_t partial_sort_copy(void* temporary_storage, /// * Accepts custom compare_functions for nth_element across the device. /// * Streams in graph capture mode are not supported /// -/// \tparam Config [optional] configuration of the primitive. It has to be `nth_element_config`. +/// \tparam Config [optional] configuration of the primitive. It has to be `partial_sort_config`. /// \tparam KeysIterator [inferred] random-access iterator type of the input range. Must meet the /// requirements of a C++ InputIterator concept. It can be a simple pointer type. /// \tparam CompareFunction [inferred] Type of binary function that accepts two arguments of the diff --git a/rocprim/include/rocprim/device/device_partial_sort_config.hpp b/rocprim/include/rocprim/device/device_partial_sort_config.hpp new file mode 100644 index 000000000..29aa4eb56 --- /dev/null +++ b/rocprim/include/rocprim/device/device_partial_sort_config.hpp @@ -0,0 +1,61 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DEVICE_PARTIAL_SORT_CONFIG_HPP_ +#define ROCPRIM_DEVICE_DEVICE_PARTIAL_SORT_CONFIG_HPP_ + +#include "config_types.hpp" + +#include "device_nth_element_config.hpp" + +/// \addtogroup primitivesmodule_deviceconfigs +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief Configuration of device-level partial sort. +/// +/// \tparam NthElementConfig - configuration of device-level nth element operation. +/// Must be \p nth_element_config or \p default_config. +/// \tparam MergeSortConfig - configuration of device-level merge sort operation. +/// Must be \p merge_sort_config or \p default_config. +template +struct partial_sort_config +{ + /// \brief Configuration of device-level nth element operation. + using nth_element = NthElementConfig; + /// \brief Configuration of device-level merge sort operation. + using merge_sort = MergeSortConfig; +}; + +namespace detail +{ + +template +using default_partial_sort_config = partial_sort_config; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group primitivesmodule_deviceconfigs + +#endif // ROCPRIM_DEVICE_DEVICE_PARTIAL_SORT_CONFIG_HPP_ diff --git a/test/rocprim/test_device_partial_sort.cpp b/test/rocprim/test_device_partial_sort.cpp index b1107ea47..6eef7542c 100644 --- a/test/rocprim/test_device_partial_sort.cpp +++ b/test/rocprim/test_device_partial_sort.cpp @@ -88,7 +88,13 @@ using RocprimDevicePartialSortTestsParams = ::testing::Types< DevicePartialSortParams>, DevicePartialSortParams, DevicePartialSortParams>, - DevicePartialSortParams, rocprim::default_config, false, true>>; + DevicePartialSortParams, rocprim::default_config, false, true>, + DevicePartialSortParams< + int, + ::rocprim::less, + rocprim::partial_sort_config< + rocprim:: + nth_element_config<128, 4, 32, 16, rocprim::block_radix_rank_algorithm::basic>>>>; TYPED_TEST_SUITE(RocprimDevicePartialSortTests, RocprimDevicePartialSortTestsParams); From 14c48a83db1db56697f826a2613f0bc43cc25b25 Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Thu, 30 May 2024 12:48:14 +0000 Subject: [PATCH 09/14] Merge with nth_element_remove_oracle branch --- rocprim/include/rocprim/device/device_nth_element.hpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/rocprim/include/rocprim/device/device_nth_element.hpp b/rocprim/include/rocprim/device/device_nth_element.hpp index 550ab941a..f3090699c 100644 --- a/rocprim/include/rocprim/device/device_nth_element.hpp +++ b/rocprim/include/rocprim/device/device_nth_element.hpp @@ -76,7 +76,6 @@ ROCPRIM_INLINE hipError_t nth_element_impl( const unsigned int num_blocks = detail::ceiling_div(size, num_items_per_block); // oracles stores the bucket that correlates with the index - uint8_t* oracles = nullptr; key_type* tree = nullptr; size_t* buckets = nullptr; detail::n_th_element_iteration_data* nth_element_data = nullptr; @@ -98,7 +97,6 @@ ROCPRIM_INLINE hipError_t nth_element_impl( ptr_aligned_array(&tree, num_splitters), ptr_aligned_array(&equality_buckets, num_buckets), ptr_aligned_array(&buckets, num_buckets), - ptr_aligned_array(&oracles, size), ptr_aligned_array(&keys_buffer, size), ptr_aligned_array(&nth_element_data, 1), ptr_aligned_array(&lookback_states, num_partitions * num_blocks))); @@ -112,7 +110,6 @@ ROCPRIM_INLINE hipError_t nth_element_impl( ptr_aligned_array(&tree, num_splitters), ptr_aligned_array(&equality_buckets, num_buckets), ptr_aligned_array(&buckets, num_buckets), - ptr_aligned_array(&oracles, size), ptr_aligned_array(&nth_element_data, 1), ptr_aligned_array(&lookback_states, num_partitions * num_blocks))); keys_buffer = keys_double_buffer; @@ -151,7 +148,6 @@ ROCPRIM_INLINE hipError_t nth_element_impl( size, buckets, equality_buckets, - oracles, lookback_states, num_buckets, stop_recursion_size, From 74cdef02f6a788107c37e1c32af96a59a299d071 Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Thu, 13 Jun 2024 14:33:36 +0000 Subject: [PATCH 10/14] Created c++17 test for partial_sort --- test/rocprim/CMakeLists.txt | 2 +- test/rocprim/test_device_partial_sort.cpp | 67 ++++++++++++++++++++--- 2 files changed, 61 insertions(+), 8 deletions(-) diff --git a/test/rocprim/CMakeLists.txt b/test/rocprim/CMakeLists.txt index 3f6e9f02f..433bf442e 100644 --- a/test/rocprim/CMakeLists.txt +++ b/test/rocprim/CMakeLists.txt @@ -260,7 +260,7 @@ add_rocprim_test("rocprim.device_histogram" test_device_histogram.cpp) add_rocprim_test("rocprim.device_merge" test_device_merge.cpp) add_rocprim_test("rocprim.device_merge_sort" test_device_merge_sort.cpp) add_rocprim_cpp17_test("rocprim.nth_element" test_device_nth_element.cpp) -add_rocprim_test("rocprim.device_partial_sort" test_device_partial_sort.cpp) +add_rocprim_cpp17_test("rocprim.device_partial_sort" test_device_partial_sort.cpp) add_rocprim_test("rocprim.device_partition" test_device_partition.cpp) add_rocprim_test_parallel("rocprim.device_radix_sort" test_device_radix_sort.cpp.in) add_rocprim_test("rocprim.device_reduce_by_key" test_device_reduce_by_key.cpp) diff --git a/test/rocprim/test_device_partial_sort.cpp b/test/rocprim/test_device_partial_sort.cpp index 6eef7542c..148c392b5 100644 --- a/test/rocprim/test_device_partial_sort.cpp +++ b/test/rocprim/test_device_partial_sort.cpp @@ -70,7 +70,6 @@ class RocprimDevicePartialSortTests : public ::testing::Test static constexpr bool use_indirect_iterator = Params::use_indirect_iterator; }; -// TODO add custom config // TODO no graph support using RocprimDevicePartialSortTestsParams = ::testing::Types< DevicePartialSortParams, @@ -96,6 +95,64 @@ using RocprimDevicePartialSortTestsParams = ::testing::Types< rocprim:: nth_element_config<128, 4, 32, 16, rocprim::block_radix_rank_algorithm::basic>>>>; +template +void inline compare_cpp_14(InputVector input, + OutputVector output, + size_t middle, + CompareFunction compare_op) +{ + using key_type = typename InputVector::value_type; + + // Calculate sorted input results on host + std::vector sorted_input(input); + std::sort(sorted_input.begin(), sorted_input.end(), compare_op); + + // Calculate sorted output results on host + std::vector sorted_output(output); + std::sort(sorted_output.begin() + middle, sorted_output.end(), compare_op); + + // Check if the values are the same + ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(sorted_output, sorted_input)); +} + +#if CPP17 +template +void inline compare_cpp_17(InputVector input, + OutputVector output, + size_t middle, + CompareFunction compare_op) +{ + using key_type = typename InputVector::value_type; + + // Calculate sorted input results on host + std::vector sorted_input(input); + std::partial_sort(sorted_input.begin(), sorted_input.begin() + middle, sorted_input.end(), compare_op); + std::sort(sorted_input.begin() + middle, sorted_input.end(), compare_op); + + // Calculate sorted output results on host + std::vector sorted_output(output); + std::sort(sorted_output.begin() + middle, sorted_output.end(), compare_op); + + // Check if the values are the same + ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(sorted_output, sorted_input)); +} +#endif + +template +void inline compare(InputVector input, + OutputVector output, + size_t middle, + CompareFunction compare_op) +{ + compare_cpp_14(input, output, middle, compare_op); +#if CPP17 + // this comparison is only compiled and executed if c++17 is available + compare_cpp_17(input, output, middle, compare_op); +#else + ROCPRIM_PRAGMA_MESSAGE("c++17 not available skips direct comparison with std::partial_sort"); +#endif +} + TYPED_TEST_SUITE(RocprimDevicePartialSortTests, RocprimDevicePartialSortTestsParams); TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) @@ -253,17 +310,13 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) d_output, size * sizeof(key_type), hipMemcpyDeviceToHost)); - std::sort(output.begin() + middle, output.begin() + size, compare_op); - - // Sort input fully to compare - std::sort(input.begin(), input.end(), compare_op); - ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, input)); + compare(input, output, middle, compare_op); HIP_CHECK(hipFree(d_input)); if(!in_place) { - hipFree(d_output); + HIP_CHECK(hipFree(d_output)); } HIP_CHECK(hipFree(d_temp_storage)); From 711624a0cbc8b6140459e040ad7494f1f04889f9 Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Fri, 14 Jun 2024 06:14:16 +0000 Subject: [PATCH 11/14] Cleanup code based on nth_element review --- benchmark/benchmark_device_partial_sort.cpp | 2 +- .../rocprim/device/device_partial_sort.hpp | 107 +++++++++--------- test/rocprim/test_device_partial_sort.cpp | 5 +- 3 files changed, 60 insertions(+), 54 deletions(-) diff --git a/benchmark/benchmark_device_partial_sort.cpp b/benchmark/benchmark_device_partial_sort.cpp index a921e537f..49db25d25 100644 --- a/benchmark/benchmark_device_partial_sort.cpp +++ b/benchmark/benchmark_device_partial_sort.cpp @@ -80,7 +80,7 @@ int main(int argc, char* argv[]) benchmark::AddCustomContext("seed", seed_type); // Add benchmarks - std::vector benchmarks = {}; + std::vector benchmarks{}; CREATE_BENCHMARK(int) CREATE_BENCHMARK(long long) CREATE_BENCHMARK(int8_t) diff --git a/rocprim/include/rocprim/device/device_partial_sort.hpp b/rocprim/include/rocprim/device/device_partial_sort.hpp index a2d025e61..cddc23aa1 100644 --- a/rocprim/include/rocprim/device/device_partial_sort.hpp +++ b/rocprim/include/rocprim/device/device_partial_sort.hpp @@ -47,6 +47,17 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { +#define RETURN_ON_ERROR(...) \ + do \ + { \ + hipError_t error = (__VA_ARGS__); \ + if(error != hipSuccess) \ + { \ + return error; \ + } \ + } \ + while(0) + template hipError_t partial_sort_impl(void* temporary_storage, size_t& storage_size, @@ -67,38 +78,31 @@ hipError_t partial_sort_impl(void* temporary_storage, // non-null placeholder so that no buffer is allocated for keys key_type* keys_buffer_placeholder = reinterpret_cast(1); - hipError_t result = nth_element_impl(nullptr, - storage_size_nth_element, - keys, - middle, - size, - compare_function, - stream, - debug_synchronous, - keys_buffer_placeholder); - if(result != hipSuccess) - { - return result; - } + RETURN_ON_ERROR(nth_element_impl(nullptr, + storage_size_nth_element, + keys, + middle, + size, + compare_function, + stream, + debug_synchronous, + keys_buffer_placeholder)); size_t storage_size_merge_sort{}; - result = merge_sort_impl(nullptr, - storage_size_merge_sort, - keys, - keys, - static_cast(nullptr), // values_input - static_cast(nullptr), // values_output - middle, - compare_function, - stream, - debug_synchronous, - keys_buffer_placeholder, // keys_buffer - static_cast(nullptr)); // values_buffer - if(result != hipSuccess) - { - return result; - } + RETURN_ON_ERROR( + merge_sort_impl(nullptr, + storage_size_merge_sort, + keys, + keys, + static_cast(nullptr), // values_input + static_cast(nullptr), // values_output + middle, + compare_function, + stream, + debug_synchronous, + keys_buffer_placeholder, // keys_buffer + static_cast(nullptr))); // values_buffer void* temporary_storage_nth_element = nullptr; void* temporary_storage_merge_sort = nullptr; @@ -129,19 +133,20 @@ hipError_t partial_sort_impl(void* temporary_storage, if(middle < size) { - result = nth_element_impl(temporary_storage_nth_element, - storage_size_nth_element, - keys, - middle, - size, - compare_function, - stream, - debug_synchronous, - keys_buffer); - if(result != hipSuccess) - { - return result; - } + RETURN_ON_ERROR(nth_element_impl(temporary_storage_nth_element, + storage_size_nth_element, + keys, + middle, + size, + compare_function, + stream, + debug_synchronous, + keys_buffer)); + } + + if(middle == 0) + { + return hipSuccess; } return merge_sort_impl(temporary_storage_merge_sort, @@ -253,16 +258,12 @@ hipError_t partial_sort_copy(void* temporary_storage, typename std::iterator_traits::value_type>::value, "KeysInputIterator and KeysOutputIterator must have the same value_type"); - hipError_t error = transform(keys_input, - keys_output, - size, - ::rocprim::identity(), - stream, - debug_synchronous); - if(error != hipSuccess) - { - return error; - } + RETURN_ON_ERROR(transform(keys_input, + keys_output, + size, + ::rocprim::identity(), + stream, + debug_synchronous)); return detail::partial_sort_impl(temporary_storage, storage_size, @@ -367,6 +368,8 @@ hipError_t partial_sort(void* temporary_storage, /// @} // end of group devicemodule +#undef RETURN_ON_ERROR + END_ROCPRIM_NAMESPACE #endif // ROCPRIM_DEVICE_DEVICE_PARTIAL_SORT_HPP_ diff --git a/test/rocprim/test_device_partial_sort.cpp b/test/rocprim/test_device_partial_sort.cpp index 148c392b5..3cf701108 100644 --- a/test/rocprim/test_device_partial_sort.cpp +++ b/test/rocprim/test_device_partial_sort.cpp @@ -126,7 +126,10 @@ void inline compare_cpp_17(InputVector input, // Calculate sorted input results on host std::vector sorted_input(input); - std::partial_sort(sorted_input.begin(), sorted_input.begin() + middle, sorted_input.end(), compare_op); + std::partial_sort(sorted_input.begin(), + sorted_input.begin() + middle, + sorted_input.end(), + compare_op); std::sort(sorted_input.begin() + middle, sorted_input.end(), compare_op); // Calculate sorted output results on host From 608a043a44f1bed78210163499402ae4b9419050 Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Wed, 19 Jun 2024 13:59:13 +0000 Subject: [PATCH 12/14] Review adaptations --- benchmark/benchmark_device_partial_sort.hpp | 7 +- docs/reference/ops_summary.rst | 2 +- .../rocprim/device/device_nth_element.hpp | 80 ++-- .../rocprim/device/device_partial_sort.hpp | 141 ++++--- test/rocprim/test_device_partial_sort.cpp | 351 +++++++++++++----- 5 files changed, 393 insertions(+), 188 deletions(-) diff --git a/benchmark/benchmark_device_partial_sort.hpp b/benchmark/benchmark_device_partial_sort.hpp index d0dba07da..d16c9fb1c 100644 --- a/benchmark/benchmark_device_partial_sort.hpp +++ b/benchmark/benchmark_device_partial_sort.hpp @@ -101,10 +101,9 @@ struct device_partial_sort_benchmark : public config_autotune_interface size * sizeof(*d_keys_input), hipMemcpyHostToDevice)); - ::rocprim::less lesser_op; - - void* d_temporary_storage = nullptr; - size_t temporary_storage_bytes = 0; + rocprim::less lesser_op; + void* d_temporary_storage = nullptr; + size_t temporary_storage_bytes = 0; HIP_CHECK(rocprim::partial_sort_copy(d_temporary_storage, temporary_storage_bytes, d_keys_input, diff --git a/docs/reference/ops_summary.rst b/docs/reference/ops_summary.rst index c308eb6bd..9dbf13d68 100644 --- a/docs/reference/ops_summary.rst +++ b/docs/reference/ops_summary.rst @@ -32,7 +32,7 @@ Rearrangement ================ * ``sort`` rearranges the sequence by sorting it. It could be according to a comparison operator or a value using a radix approach -* ``partial_sort`` rearranges the sequence by sorting it up to and including the middle index, according to a comparison operator. +* ``partial_sort`` rearranges the sequence by sorting it up to and including a given index, according to a comparison operator. * ``nth_element`` places the nth element in its sorted position, with elements less-than before, and greater after, according to a comparison operator. * ``exchange`` rearranges the elements according to a different stride configuration which is equivalent to a tensor axis transposition * ``shuffle`` rotates the elements diff --git a/rocprim/include/rocprim/device/device_nth_element.hpp b/rocprim/include/rocprim/device/device_nth_element.hpp index f3090699c..6b387b594 100644 --- a/rocprim/include/rocprim/device/device_nth_element.hpp +++ b/rocprim/include/rocprim/device/device_nth_element.hpp @@ -43,28 +43,29 @@ namespace detail { template -ROCPRIM_INLINE hipError_t nth_element_impl( - void* temporary_storage, - size_t& storage_size, - KeysIterator keys, - size_t nth, - size_t size, - BinaryFunction compare_function, - hipStream_t stream, - bool debug_synchronous, - typename std::iterator_traits::value_type* keys_double_buffer = nullptr) +ROCPRIM_INLINE +hipError_t + nth_element_impl(void* temporary_storage, + size_t& storage_size, + KeysIterator keys, + size_t nth, + size_t size, + BinaryFunction compare_function, + hipStream_t stream, + bool debug_synchronous, + typename std::iterator_traits::value_type* keys_double_buffer + = nullptr) { using key_type = typename std::iterator_traits::value_type; using config = wrapped_nth_element_config; - detail::target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); + target_arch target_arch; + hipError_t result = host_target_arch(stream, target_arch); if(result != hipSuccess) { return result; } - const detail::nth_element_config_params params - = detail::dispatch_target_arch(target_arch); + const nth_element_config_params params = dispatch_target_arch(target_arch); constexpr unsigned int num_partitions = 3; const unsigned int num_buckets = params.number_of_buckets; @@ -73,19 +74,18 @@ ROCPRIM_INLINE hipError_t nth_element_impl( const unsigned int num_items_per_threads = params.kernel_config.items_per_thread; const unsigned int num_threads_per_block = params.kernel_config.block_size; const unsigned int num_items_per_block = num_threads_per_block * num_items_per_threads; - const unsigned int num_blocks = detail::ceiling_div(size, num_items_per_block); + const unsigned int num_blocks = ceiling_div(size, num_items_per_block); - // oracles stores the bucket that correlates with the index - key_type* tree = nullptr; - size_t* buckets = nullptr; - detail::n_th_element_iteration_data* nth_element_data = nullptr; - bool* equality_buckets = nullptr; - detail::nth_element_onesweep_lookback_state* lookback_states = nullptr; + key_type* tree = nullptr; + size_t* buckets = nullptr; + n_th_element_iteration_data* nth_element_data = nullptr; + bool* equality_buckets = nullptr; + nth_element_onesweep_lookback_state* lookback_states = nullptr; key_type* keys_buffer = nullptr; { - using namespace detail::temp_storage; + using namespace temp_storage; hipError_t partition_result; if(keys_double_buffer == nullptr) @@ -141,22 +141,22 @@ ROCPRIM_INLINE hipError_t nth_element_impl( std::cout << "storage_size: " << storage_size << '\n'; } - return detail::nth_element_keys_impl(keys, - keys_buffer, - tree, - nth, - size, - buckets, - equality_buckets, - lookback_states, - num_buckets, - stop_recursion_size, - num_threads_per_block, - num_items_per_threads, - nth_element_data, - compare_function, - stream, - debug_synchronous); + return nth_element_keys_impl(keys, + keys_buffer, + tree, + nth, + size, + buckets, + equality_buckets, + lookback_states, + num_buckets, + stop_recursion_size, + num_threads_per_block, + num_items_per_threads, + nth_element_data, + compare_function, + stream, + debug_synchronous); } } // namespace detail @@ -195,7 +195,7 @@ ROCPRIM_INLINE hipError_t nth_element_impl( /// The signature of the function should be equivalent to the following: /// bool f(const T &a, const T &b);. The signature does not need to have /// const &, but function object must not modify the objects passed to it. -/// The comperator must meet the C++ named requirement Compare. +/// The comparator must meet the C++ named requirement Compare. /// The default value is `BinaryFunction()`. /// \param [in] stream [optional] HIP stream object. Default is `0` (default stream). /// \param [in] debug_synchronous [optional] If true, synchronization after every kernel @@ -295,7 +295,7 @@ ROCPRIM_INLINE hipError_t nth_element(void* temporary_storage, /// The signature of the function should be equivalent to the following: /// bool f(const T &a, const T &b);. The signature does not need to have /// const &, but function object must not modify the objects passed to it. -/// The comperator must meet the C++ named requirement Compare. +/// The comparator must meet the C++ named requirement Compare. /// The default value is `BinaryFunction()`. /// \param [in] stream [optional] HIP stream object. Default is `0` (default stream). /// \param [in] debug_synchronous [optional] If true, synchronization after every kernel diff --git a/rocprim/include/rocprim/device/device_partial_sort.hpp b/rocprim/include/rocprim/device/device_partial_sort.hpp index cddc23aa1..e9bacae0d 100644 --- a/rocprim/include/rocprim/device/device_partial_sort.hpp +++ b/rocprim/include/rocprim/device/device_partial_sort.hpp @@ -58,10 +58,11 @@ namespace detail } \ while(0) -template +template hipError_t partial_sort_impl(void* temporary_storage, size_t& storage_size, - KeysIterator keys, + KeysIterator keys_in, + KeysIterator keys_out, size_t middle, size_t size, BinaryFunction compare_function, @@ -69,35 +70,42 @@ hipError_t partial_sort_impl(void* temporary_storage, bool debug_synchronous) { using key_type = typename std::iterator_traits::value_type; - using config - = detail::default_or_custom_config>; + using config = default_or_custom_config>; using config_merge_sort = typename config::merge_sort; using config_nth_element = typename config::nth_element; + if(size != 0 && middle >= size) + { + return hipErrorInvalidValue; + } + size_t storage_size_nth_element{}; // non-null placeholder so that no buffer is allocated for keys key_type* keys_buffer_placeholder = reinterpret_cast(1); - RETURN_ON_ERROR(nth_element_impl(nullptr, - storage_size_nth_element, - keys, - middle, - size, - compare_function, - stream, - debug_synchronous, - keys_buffer_placeholder)); - + const bool full_sort = middle + 1 == size; + if(!full_sort) + { + RETURN_ON_ERROR(nth_element_impl(nullptr, + storage_size_nth_element, + keys_in, + middle, + size, + compare_function, + stream, + debug_synchronous, + keys_buffer_placeholder)); + } size_t storage_size_merge_sort{}; RETURN_ON_ERROR( merge_sort_impl(nullptr, storage_size_merge_sort, - keys, - keys, + keys_in, + keys_out, static_cast(nullptr), // values_input static_cast(nullptr), // values_output - middle, + (!inplace || full_sort) ? middle + 1 : middle, compare_function, stream, debug_synchronous, @@ -107,12 +115,14 @@ hipError_t partial_sort_impl(void* temporary_storage, void* temporary_storage_nth_element = nullptr; void* temporary_storage_merge_sort = nullptr; key_type* keys_buffer = nullptr; + key_type* keys_output_nth_element = nullptr; const hipError_t partition_result = temp_storage::partition( temporary_storage, storage_size, temp_storage::make_linear_partition( temp_storage::ptr_aligned_array(&keys_buffer, size), + temp_storage::ptr_aligned_array(&keys_output_nth_element, inplace ? 0 : size), temp_storage::make_partition(&temporary_storage_nth_element, storage_size_nth_element), temp_storage::make_partition(&temporary_storage_merge_sort, storage_size_merge_sort))); @@ -126,36 +136,51 @@ hipError_t partial_sort_impl(void* temporary_storage, return hipSuccess; } - if(middle > size) + if(!inplace) { - return hipErrorInvalidValue; + RETURN_ON_ERROR(transform(keys_in, + keys_output_nth_element, + size, + rocprim::identity(), + stream, + debug_synchronous)); } - if(middle < size) + if(!full_sort) { - RETURN_ON_ERROR(nth_element_impl(temporary_storage_nth_element, - storage_size_nth_element, - keys, - middle, - size, - compare_function, - stream, - debug_synchronous, - keys_buffer)); + RETURN_ON_ERROR( + nth_element_impl(temporary_storage_nth_element, + storage_size_nth_element, + inplace ? keys_in : keys_output_nth_element, + middle, + size, + compare_function, + stream, + debug_synchronous, + keys_buffer)); } if(middle == 0) { + if(!inplace) + { + RETURN_ON_ERROR(transform(keys_output_nth_element, + keys_out, + 1, + rocprim::identity(), + stream, + debug_synchronous)); + } return hipSuccess; } return merge_sort_impl(temporary_storage_merge_sort, storage_size_merge_sort, - keys, - keys, + inplace ? keys_in : keys_output_nth_element, + keys_out, static_cast(nullptr), // values_input static_cast(nullptr), // values_output - middle, + (!inplace || full_sort) ? middle + 1 : middle, compare_function, stream, debug_synchronous, @@ -171,7 +196,7 @@ hipError_t partial_sort_impl(void* temporary_storage, /// * The contents of the inputs are not altered by the function. /// * Returns the required size of `temporary_storage` in `storage_size` /// if `temporary_storage` is a null pointer. -/// * Accepts custom compare_functions for nth_element across the device. +/// * Accepts custom compare_functions for partial_sort_copy across the device. /// * Streams in graph capture mode are not supported /// /// \tparam Config [optional] configuration of the primitive. It has to be `partial_sort_config`. @@ -184,18 +209,18 @@ hipError_t partial_sort_impl(void* temporary_storage, /// /// \param [in] temporary_storage pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to -/// `storage_size` and function returns without performing the nth_element rearrangement. +/// `storage_size` and function returns without performing the partial_sort_copy rearrangement. /// \param [in,out] storage_size reference to a size (in bytes) of `temporary_storage`. /// \param [in] keys_input iterator to the input range. /// \param [out] keys_output iterator to the output range. No overlap at all is allowed between `keys_input` and `keys_output`. -/// `keys_output` should be able to be written and read from for `size` elements. +/// `keys_output` should be able to be write to at least `middle` + 1 elements. /// \param [in] middle The index of the point till where it is sorted in the input range. /// \param [in] size number of element in the input range. /// \param [in] compare_function binary operation function object that will be used for comparison. /// The signature of the function should be equivalent to the following: /// bool f(const T &a, const T &b);. The signature does not need to have /// const &, but function object must not modify the objects passed to it. -/// The comperator must meet the C++ named requirement Compare. +/// The comparator must meet the C++ named requirement Compare. /// The default value is `BinaryFunction()`. /// \param [in] stream [optional] HIP stream object. Default is `0` (default stream). /// \param [in] debug_synchronous [optional] If true, synchronization after every kernel @@ -206,7 +231,7 @@ hipError_t partial_sort_impl(void* temporary_storage, /// /// \par Example /// \parblock -/// In this example a device-level nth_element is performed where input keys are +/// In this example a device-level partial_sort_copy is performed where input keys are /// represented by an array of unsigned integers. /// /// \code{.cpp} @@ -216,7 +241,7 @@ hipError_t partial_sort_impl(void* temporary_storage, /// size_t input_size; // e.g., 8 /// size_t middle; // e.g., 4 /// unsigned int * keys_input; // e.g., [ 6, 3, 5, 4, 1, 8, 2, 7 ] -/// unsigned int * keys_output; // empty array of 8 elements +/// unsigned int * keys_output; // e.g., [ 9, 9, 9, 9, 9, 9, 9, 9 ] /// /// size_t temporary_storage_size_bytes; /// void * temporary_storage_ptr = nullptr; @@ -234,7 +259,7 @@ hipError_t partial_sort_impl(void* temporary_storage, /// temporary_storage_ptr, temporary_storage_size_bytes, /// keys_input, keys_output, middle, input_size /// ); -/// // possible keys_output: [ 1, 2, 3, 4, 5, 8, 7, 6 ] +/// // possible keys_output: [ 1, 2, 3, 4, 5, 9, 9, 9 ] /// \endcode /// \endparblock template::value_type>::value, "KeysInputIterator and KeysOutputIterator must have the same value_type"); - RETURN_ON_ERROR(transform(keys_input, - keys_output, - size, - ::rocprim::identity(), - stream, - debug_synchronous)); - - return detail::partial_sort_impl(temporary_storage, - storage_size, - keys_output, - middle, - size, - compare_function, - stream, - debug_synchronous); + return detail::partial_sort_impl( + temporary_storage, + storage_size, + keys_input, + keys_output, + middle, + size, + compare_function, + stream, + debug_synchronous); } /// \brief Rearranges elements such that the range [0, middle) contains the sorted middle smallest elements in the range [0, size). @@ -281,7 +301,7 @@ hipError_t partial_sort_copy(void* temporary_storage, /// * The contents of the inputs are not altered by the function. /// * Returns the required size of `temporary_storage` in `storage_size` /// if `temporary_storage` is a null pointer. -/// * Accepts custom compare_functions for nth_element across the device. +/// * Accepts custom compare_functions for partial_sort across the device. /// * Streams in graph capture mode are not supported /// /// \tparam Config [optional] configuration of the primitive. It has to be `partial_sort_config`. @@ -292,7 +312,7 @@ hipError_t partial_sort_copy(void* temporary_storage, /// /// \param [in] temporary_storage pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to -/// `storage_size` and function returns without performing the nth_element rearrangement. +/// `storage_size` and function returns without performing the partial_sort rearrangement. /// \param [in,out] storage_size reference to a size (in bytes) of `temporary_storage`. /// \param [in,out] keys iterator to the input range. /// \param [in] middle The index of the point till where it is sorted in the input range. @@ -301,7 +321,7 @@ hipError_t partial_sort_copy(void* temporary_storage, /// The signature of the function should be equivalent to the following: /// bool f(const T &a, const T &b);. The signature does not need to have /// const &, but function object must not modify the objects passed to it. -/// The comperator must meet the C++ named requirement Compare. +/// The comparator must meet the C++ named requirement Compare. /// The default value is `BinaryFunction()`. /// \param [in] stream [optional] HIP stream object. Default is `0` (default stream). /// \param [in] debug_synchronous [optional] If true, synchronization after every kernel @@ -312,7 +332,7 @@ hipError_t partial_sort_copy(void* temporary_storage, /// /// \par Example /// \parblock -/// In this example a device-level nth_element is performed where input keys are +/// In this example a device-level partial_sort is performed where input keys are /// represented by an array of unsigned integers. /// /// \code{.cpp} @@ -328,7 +348,7 @@ hipError_t partial_sort_copy(void* temporary_storage, /// // Get required size of the temporary storage /// rocprim::partial_sort( /// temporary_storage_ptr, temporary_storage_size_bytes, -/// keys, nth, input_size +/// keys, middle, input_size /// ); /// /// // allocate temporary storage @@ -337,7 +357,7 @@ hipError_t partial_sort_copy(void* temporary_storage, /// // perform partial_sort /// rocprim::partial_sort( /// temporary_storage_ptr, temporary_storage_size_bytes, -/// keys, nth, input_size +/// keys, middle, input_size /// ); /// // possible keys: [ 1, 2, 3, 4, 5, 8, 7, 6 ] /// \endcode @@ -358,6 +378,7 @@ hipError_t partial_sort(void* temporary_storage, return detail::partial_sort_impl(temporary_storage, storage_size, keys, + keys, middle, size, compare_function, diff --git a/test/rocprim/test_device_partial_sort.cpp b/test/rocprim/test_device_partial_sort.cpp index 3cf701108..077d1fc59 100644 --- a/test/rocprim/test_device_partial_sort.cpp +++ b/test/rocprim/test_device_partial_sort.cpp @@ -31,6 +31,7 @@ #include "../common_test_header.hpp" // required rocprim headers +#include #include #include #include @@ -96,20 +97,25 @@ using RocprimDevicePartialSortTestsParams = ::testing::Types< nth_element_config<128, 4, 32, 16, rocprim::block_radix_rank_algorithm::basic>>>>; template -void inline compare_cpp_14(InputVector input, - OutputVector output, - size_t middle, - CompareFunction compare_op) +void inline compare_partial_sort_cpp_14(InputVector input, + OutputVector output, + size_t middle, + CompareFunction compare_op) { using key_type = typename InputVector::value_type; + if(input.size() == 0) + { + return; + } + // Calculate sorted input results on host std::vector sorted_input(input); std::sort(sorted_input.begin(), sorted_input.end(), compare_op); // Calculate sorted output results on host std::vector sorted_output(output); - std::sort(sorted_output.begin() + middle, sorted_output.end(), compare_op); + std::sort(sorted_output.begin() + middle + 1, sorted_output.end(), compare_op); // Check if the values are the same ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(sorted_output, sorted_input)); @@ -117,24 +123,29 @@ void inline compare_cpp_14(InputVector input, #if CPP17 template -void inline compare_cpp_17(InputVector input, - OutputVector output, - size_t middle, - CompareFunction compare_op) +void inline compare_partial_sort_cpp_17(InputVector input, + OutputVector output, + size_t middle, + CompareFunction compare_op) { using key_type = typename InputVector::value_type; + if(input.size() == 0) + { + return; + } + // Calculate sorted input results on host std::vector sorted_input(input); std::partial_sort(sorted_input.begin(), - sorted_input.begin() + middle, + sorted_input.begin() + middle + 1, sorted_input.end(), compare_op); - std::sort(sorted_input.begin() + middle, sorted_input.end(), compare_op); + std::sort(sorted_input.begin() + middle + 1, sorted_input.end(), compare_op); // Calculate sorted output results on host std::vector sorted_output(output); - std::sort(sorted_output.begin() + middle, sorted_output.end(), compare_op); + std::sort(sorted_output.begin() + middle + 1, sorted_output.end(), compare_op); // Check if the values are the same ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(sorted_output, sorted_input)); @@ -142,15 +153,15 @@ void inline compare_cpp_17(InputVector input, #endif template -void inline compare(InputVector input, - OutputVector output, - size_t middle, - CompareFunction compare_op) +void inline compare_partial_sort(InputVector input, + OutputVector output, + size_t middle, + CompareFunction compare_op) { - compare_cpp_14(input, output, middle, compare_op); + compare_partial_sort_cpp_14(input, output, middle, compare_op); #if CPP17 // this comparison is only compiled and executed if c++17 is available - compare_cpp_17(input, output, middle, compare_op); + compare_partial_sort_cpp_17(input, output, middle, compare_op); #else ROCPRIM_PRAGMA_MESSAGE("c++17 not available skips direct comparison with std::partial_sort"); #endif @@ -170,8 +181,6 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) const bool debug_synchronous = TestFixture::debug_synchronous; static constexpr bool use_indirect_iterator = TestFixture::use_indirect_iterator; - bool in_place = false; - for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; ++seed_index) { unsigned int seed_value @@ -183,12 +192,10 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) SCOPED_TRACE(testing::Message() << "with size = " << size); std::vector middles = {0}; - if(size > 0) - { - middles.push_back(size); - } + if(size > 1) { + middles.push_back(size - 1); middles.push_back(test_utils::get_random_value(1, size - 1, seed_value)); } @@ -223,15 +230,7 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) hipMemcpyHostToDevice)); key_type* d_output; - if(in_place) - { - d_output = d_input; - } - else - { - HIP_CHECK( - test_common_utils::hipMallocHelper(&d_output, size * sizeof(key_type))); - } + d_output = d_input; const auto input_it = test_utils::wrap_in_indirect_iterator(d_input); @@ -240,29 +239,14 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) // Allocate temporary storage size_t temp_storage_size_bytes{}; - if(in_place) - { - HIP_CHECK(rocprim::partial_sort(nullptr, - temp_storage_size_bytes, - input_it, - middle, - size, - compare_op, - stream, - debug_synchronous)); - } - else - { - HIP_CHECK(rocprim::partial_sort_copy(nullptr, - temp_storage_size_bytes, - input_it, - d_output, - middle, - size, - compare_op, - stream, - debug_synchronous)); - } + HIP_CHECK(rocprim::partial_sort(nullptr, + temp_storage_size_bytes, + input_it, + middle, + size, + compare_op, + stream, + debug_synchronous)); ASSERT_GT(temp_storage_size_bytes, 0); void* d_temp_storage{}; @@ -274,30 +258,238 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) { graph = test_utils::createGraphHelper(stream); } - if(in_place) + HIP_CHECK(rocprim::partial_sort(d_temp_storage, + temp_storage_size_bytes, + input_it, + middle, + size, + compare_op, + stream, + debug_synchronous)); + + HIP_CHECK(hipGetLastError()); + + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + } + + std::vector output(size); + HIP_CHECK(hipMemcpy(output.data(), + d_output, + size * sizeof(key_type), + hipMemcpyDeviceToHost)); + + compare_partial_sort(input, output, middle, compare_op); + + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_temp_storage)); + + if(TestFixture::use_graphs) { - HIP_CHECK(rocprim::partial_sort(d_temp_storage, - temp_storage_size_bytes, - input_it, - middle, - size, - compare_op, - stream, - debug_synchronous)); + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); + } + } + } + } +} + +template +void inline compare_partial_sort_copy_cpp_14(InputVector input, + OutputVector output, + OutputVector orignal_output, + size_t middle, + CompareFunction compare_op) +{ + using key_type = typename InputVector::value_type; + + if(input.size() == 0) + { + return; + } + std::vector expected_output; + // Calculate sorted input results on host + std::vector sorted_input(input); + std::sort(sorted_input.begin(), sorted_input.end(), compare_op); + + expected_output.insert(expected_output.end(), + sorted_input.begin(), + sorted_input.begin() + std::min(middle + 1, sorted_input.size())); + + if(middle + 1 < orignal_output.size()) + { + expected_output.insert(expected_output.end(), + orignal_output.begin() + middle + 1, + orignal_output.end()); + } + + // Check if the values are the same + ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected_output)); +} + +#if CPP17 +template +void inline compare_partial_sort_copy_cpp_17(InputVector input, + OutputVector output, + OutputVector orignal_output, + size_t middle, + CompareFunction compare_op) +{ + using key_type = typename InputVector::value_type; + + if(input.size() == 0) + { + return; + } + + // Calculate sorted input results on host + std::vector sorted_output(orignal_output); + std::partial_sort_copy(input.begin(), + input.end(), + sorted_output.begin(), + sorted_output.begin() + middle + 1, + compare_op); + + // Check if the values are the same + ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(sorted_output, output)); +} +#endif + +template +void inline compare_partial_sort_copy(InputVector input, + OutputVector output, + OutputVector orignal_output, + size_t middle, + CompareFunction compare_op) +{ + compare_partial_sort_copy_cpp_14(input, output, orignal_output, middle, compare_op); +#if CPP17 + // this comparison is only compiled and executed if c++17 is available + compare_partial_sort_copy_cpp_17(input, output, orignal_output, middle, compare_op); +#else + ROCPRIM_PRAGMA_MESSAGE( + "c++17 not available skips direct comparison with std::partial_sort_copy"); +#endif +} + +TYPED_TEST(RocprimDevicePartialSortTests, PartialSortCopy) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + + using key_type = typename TestFixture::key_type; + using compare_function = typename TestFixture::compare_function; + using config = typename TestFixture::config; + const bool debug_synchronous = TestFixture::debug_synchronous; + static constexpr bool use_indirect_iterator = TestFixture::use_indirect_iterator; + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; ++seed_index) + { + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + + for(size_t size : test_utils::get_sizes(seed_value)) + { + SCOPED_TRACE(testing::Message() << "with size = " << size); + + std::vector middles = {0}; + + if(size > 1) + { + middles.push_back(size - 1); + middles.push_back(test_utils::get_random_value(1, size - 1, seed_value)); + } + + for(size_t middle : middles) + { + SCOPED_TRACE(testing::Message() << "with middle = " << middle); + + hipStream_t stream = 0; // default + if(TestFixture::use_graphs) + { + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + + std::vector input; + std::vector output_original; + if(rocprim::is_floating_point::value) + { + input = test_utils::get_random_data(size, -1000, 1000, seed_value); + output_original + = test_utils::get_random_data(size, -1000, 1000, seed_value + 1); } else { - HIP_CHECK(rocprim::partial_sort_copy(d_temp_storage, - temp_storage_size_bytes, - input_it, - d_output, - middle, - size, - compare_op, - stream, - debug_synchronous)); + input = test_utils::get_random_data( + size, + test_utils::numeric_limits::min(), + test_utils::numeric_limits::max(), + seed_value); + output_original = test_utils::get_random_data( + size, + test_utils::numeric_limits::min(), + test_utils::numeric_limits::max(), + seed_value + 1); + } + + key_type* d_input; + HIP_CHECK(test_common_utils::hipMallocHelper(&d_input, size * sizeof(key_type))); + HIP_CHECK(hipMemcpy(d_input, + input.data(), + size * sizeof(key_type), + hipMemcpyHostToDevice)); + + key_type* d_output; + + HIP_CHECK(test_common_utils::hipMallocHelper(&d_output, size * sizeof(key_type))); + HIP_CHECK(hipMemcpy(d_output, + output_original.data(), + size * sizeof(key_type), + hipMemcpyHostToDevice)); + + const auto input_it + = test_utils::wrap_in_indirect_iterator(d_input); + + compare_function compare_op; + + // Allocate temporary storage + size_t temp_storage_size_bytes{}; + + HIP_CHECK(rocprim::partial_sort_copy(nullptr, + temp_storage_size_bytes, + input_it, + d_output, + middle, + size, + compare_op, + stream, + debug_synchronous)); + + ASSERT_GT(temp_storage_size_bytes, 0); + void* d_temp_storage{}; + HIP_CHECK( + test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); + + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); } + HIP_CHECK(rocprim::partial_sort_copy(d_temp_storage, + temp_storage_size_bytes, + input_it, + d_output, + middle, + size, + compare_op, + stream, + debug_synchronous)); + HIP_CHECK(hipGetLastError()); hipGraphExec_t graph_instance; @@ -306,21 +498,16 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); } - // The algorithm sorted [first, middle). Since the order of [middle, last) is not specified, - // sort [middle, last) to compare with expected values. std::vector output(size); HIP_CHECK(hipMemcpy(output.data(), d_output, size * sizeof(key_type), hipMemcpyDeviceToHost)); - compare(input, output, middle, compare_op); + compare_partial_sort_copy(input, output, output_original, middle, compare_op); HIP_CHECK(hipFree(d_input)); - if(!in_place) - { - HIP_CHECK(hipFree(d_output)); - } + HIP_CHECK(hipFree(d_output)); HIP_CHECK(hipFree(d_temp_storage)); if(TestFixture::use_graphs) @@ -328,8 +515,6 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) test_utils::cleanupGraphHelper(graph, graph_instance); HIP_CHECK(hipStreamDestroy(stream)); } - - in_place = !in_place; } } } From b2c60bb4d232433df3d43215365a330d4fcddd8f Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Thu, 20 Jun 2024 09:19:17 +0000 Subject: [PATCH 13/14] Added benchmark for partial_sort --- benchmark/CMakeLists.txt | 1 + benchmark/benchmark_device_partial_sort.hpp | 87 +++++---- .../benchmark_device_partial_sort_copy.cpp | 123 ++++++++++++ .../benchmark_device_partial_sort_copy.hpp | 179 ++++++++++++++++++ 4 files changed, 349 insertions(+), 41 deletions(-) create mode 100644 benchmark/benchmark_device_partial_sort_copy.cpp create mode 100644 benchmark/benchmark_device_partial_sort_copy.hpp diff --git a/benchmark/CMakeLists.txt b/benchmark/CMakeLists.txt index 638b7a16b..03cada66c 100644 --- a/benchmark/CMakeLists.txt +++ b/benchmark/CMakeLists.txt @@ -143,6 +143,7 @@ add_rocprim_benchmark(benchmark_device_merge_sort_block_sort.cpp) add_rocprim_benchmark(benchmark_device_merge_sort_block_merge.cpp) add_rocprim_benchmark(benchmark_device_nth_element.cpp) add_rocprim_benchmark(benchmark_device_partial_sort.cpp) +add_rocprim_benchmark(benchmark_device_partial_sort_copy.cpp) add_rocprim_benchmark(benchmark_device_partition.cpp) add_rocprim_benchmark(benchmark_device_radix_sort.cpp) add_rocprim_benchmark(benchmark_device_radix_sort_block_sort.cpp) diff --git a/benchmark/benchmark_device_partial_sort.hpp b/benchmark/benchmark_device_partial_sort.hpp index d16c9fb1c..525c68c9f 100644 --- a/benchmark/benchmark_device_partial_sort.hpp +++ b/benchmark/benchmark_device_partial_sort.hpp @@ -53,7 +53,7 @@ struct device_partial_sort_benchmark : public config_autotune_interface { using namespace std::string_literals; return bench_naming::format_name( - "{lvl:device,algo:partial_sort,nth:" + (small_n ? "small"s : "large"s) + "{lvl:device,algo:partial_sort,nth:" + (small_n ? "small"s : "half"s) + ",key_type:" + std::string(Traits::name()) + ",cfg:default_config}"); } @@ -92,11 +92,11 @@ struct device_partial_sort_benchmark : public config_autotune_interface } key_type* d_keys_input; - key_type* d_keys_output; + key_type* d_keys_new_data; HIP_CHECK(hipMalloc(&d_keys_input, size * sizeof(*d_keys_input))); - HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(*d_keys_output))); + HIP_CHECK(hipMalloc(&d_keys_new_data, size * sizeof(*d_keys_new_data))); - HIP_CHECK(hipMemcpy(d_keys_input, + HIP_CHECK(hipMemcpy(d_keys_new_data, keys_input.data(), size * sizeof(*d_keys_input), hipMemcpyHostToDevice)); @@ -104,30 +104,32 @@ struct device_partial_sort_benchmark : public config_autotune_interface rocprim::less lesser_op; void* d_temporary_storage = nullptr; size_t temporary_storage_bytes = 0; - HIP_CHECK(rocprim::partial_sort_copy(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - middle, - size, - lesser_op, - stream, - false)); + HIP_CHECK(rocprim::partial_sort(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + middle, + size, + lesser_op, + stream, + false)); HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); // Warm-up for(size_t i = 0; i < warmup_size; i++) { - HIP_CHECK(rocprim::partial_sort_copy(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - middle, - size, - lesser_op, - stream, - false)); + HIP_CHECK(hipMemcpy(d_keys_input, + d_keys_new_data, + size * sizeof(*d_keys_input), + hipMemcpyDeviceToDevice)); + HIP_CHECK(rocprim::partial_sort(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + middle, + size, + lesser_op, + stream, + false)); } HIP_CHECK(hipDeviceSynchronize()); @@ -138,28 +140,31 @@ struct device_partial_sort_benchmark : public config_autotune_interface for(auto _ : state) { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - + float elapsed_mseconds = 0; for(size_t i = 0; i < batch_size; i++) { - HIP_CHECK(rocprim::partial_sort_copy(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - middle, - size, - lesser_op, - stream, - false)); + HIP_CHECK(hipMemcpy(d_keys_input, + d_keys_new_data, + size * sizeof(*d_keys_input), + hipMemcpyDeviceToDevice)); + // Record start event + HIP_CHECK(hipEventRecord(start, stream)); + HIP_CHECK(rocprim::partial_sort(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + middle, + size, + lesser_op, + stream, + false)); + // Record stop event and wait until it completes + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + float elapsed_mseconds_current; + HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds_current, start, stop)); + elapsed_mseconds += elapsed_mseconds_current; } - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); state.SetIterationTime(elapsed_mseconds / 1000); } @@ -172,7 +177,7 @@ struct device_partial_sort_benchmark : public config_autotune_interface HIP_CHECK(hipFree(d_temporary_storage)); HIP_CHECK(hipFree(d_keys_input)); - HIP_CHECK(hipFree(d_keys_output)); + HIP_CHECK(hipFree(d_keys_new_data)); } }; diff --git a/benchmark/benchmark_device_partial_sort_copy.cpp b/benchmark/benchmark_device_partial_sort_copy.cpp new file mode 100644 index 000000000..e8097b635 --- /dev/null +++ b/benchmark/benchmark_device_partial_sort_copy.cpp @@ -0,0 +1,123 @@ +// MIT License +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include "benchmark_device_partial_sort_copy.hpp" +#include "benchmark_utils.hpp" + +// CmdParser +#include "cmdparser.hpp" + +// Google Benchmark +#include + +// HIP API +#include + +#include +#include + +#ifndef DEFAULT_N +const size_t DEFAULT_N = 1024 * 1024 * 32; +#endif + +#define CREATE_BENCHMARK_PARTIAL_SORT_COPY(TYPE, SMALL_N) \ + { \ + const device_partial_sort_copy_benchmark instance(SMALL_N); \ + REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + } + +#define CREATE_BENCHMARK(TYPE) \ + { \ + CREATE_BENCHMARK_PARTIAL_SORT_COPY(TYPE, true) \ + CREATE_BENCHMARK_PARTIAL_SORT_COPY(TYPE, false) \ + } + +int main(int argc, char* argv[]) +{ + cli::Parser parser(argc, argv); + parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("trials", "trials", -1, "number of iterations"); + parser.set_optional("name_format", + "name_format", + "human", + "either: json,human,txt"); + parser.set_optional("seed", "seed", "random", get_seed_message()); + parser.run_and_exit_if_error(); + + // Parse argv + benchmark::Initialize(&argc, argv); + const size_t size = parser.get("size"); + const int trials = parser.get("trials"); + bench_naming::set_format(parser.get("name_format")); + const std::string seed_type = parser.get("seed"); + const managed_seed seed(seed_type); + + // HIP + hipStream_t stream = 0; // default + + // Benchmark info + add_common_benchmark_info(); + benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("seed", seed_type); + + // Add benchmarks + std::vector benchmarks{}; + CREATE_BENCHMARK(int) + CREATE_BENCHMARK(long long) + CREATE_BENCHMARK(int8_t) + CREATE_BENCHMARK(uint8_t) + CREATE_BENCHMARK(rocprim::half) + CREATE_BENCHMARK(short) + CREATE_BENCHMARK(float) + + using custom_float2 = custom_type; + using custom_double2 = custom_type; + using custom_int2 = custom_type; + using custom_char_double = custom_type; + using custom_longlong_double = custom_type; + + CREATE_BENCHMARK(custom_float2) + CREATE_BENCHMARK(custom_double2) + CREATE_BENCHMARK(custom_int2) + CREATE_BENCHMARK(custom_char_double) + CREATE_BENCHMARK(custom_longlong_double) + + // Use manual timing + for(auto& b : benchmarks) + { + b->UseManualTime(); + b->Unit(benchmark::kMillisecond); + } + + // Force number of iterations + if(trials > 0) + { + for(auto& b : benchmarks) + { + b->Iterations(trials); + } + } + + // Run benchmarks + benchmark::RunSpecifiedBenchmarks(); + return 0; +} diff --git a/benchmark/benchmark_device_partial_sort_copy.hpp b/benchmark/benchmark_device_partial_sort_copy.hpp new file mode 100644 index 000000000..7510c5089 --- /dev/null +++ b/benchmark/benchmark_device_partial_sort_copy.hpp @@ -0,0 +1,179 @@ +// MIT License +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#ifndef ROCPRIM_BENCHMARK_DEVICE_PARTIAL_SORT_COPY_PARALLEL_HPP_ +#define ROCPRIM_BENCHMARK_DEVICE_PARTIAL_SORT_COPY_PARALLEL_HPP_ + +#include "benchmark_utils.hpp" + +// Google Benchmark +#include + +// HIP API +#include + +// rocPRIM +#include + +#include +#include + +#include + +template +struct device_partial_sort_copy_benchmark : public config_autotune_interface +{ + bool small_n = false; + + device_partial_sort_copy_benchmark(bool SmallN) + { + small_n = SmallN; + } + + std::string name() const override + { + using namespace std::string_literals; + return bench_naming::format_name( + "{lvl:device,algo:partial_sort_copy,nth:" + (small_n ? "small"s : "half"s) + + ",key_type:" + std::string(Traits::name()) + ",cfg:default_config}"); + } + + static constexpr unsigned int batch_size = 10; + static constexpr unsigned int warmup_size = 5; + + void run(benchmark::State& state, + size_t size, + const managed_seed& seed, + hipStream_t stream) const override + { + using key_type = Key; + + size_t middle = 10; + + if(!small_n) + { + middle = size / 2; + } + + // Generate data + std::vector keys_input; + if(std::is_floating_point::value) + { + keys_input = get_random_data(size, + static_cast(-1000), + static_cast(1000), + seed.get_0()); + } + else + { + keys_input = get_random_data(size, + std::numeric_limits::min(), + std::numeric_limits::max(), + seed.get_0()); + } + + key_type* d_keys_input; + key_type* d_keys_output; + HIP_CHECK(hipMalloc(&d_keys_input, size * sizeof(*d_keys_input))); + HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(*d_keys_output))); + + HIP_CHECK(hipMemcpy(d_keys_input, + keys_input.data(), + size * sizeof(*d_keys_input), + hipMemcpyHostToDevice)); + + rocprim::less lesser_op; + void* d_temporary_storage = nullptr; + size_t temporary_storage_bytes = 0; + HIP_CHECK(rocprim::partial_sort_copy(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + middle, + size, + lesser_op, + stream, + false)); + + HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + + // Warm-up + for(size_t i = 0; i < warmup_size; i++) + { + HIP_CHECK(rocprim::partial_sort_copy(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + middle, + size, + lesser_op, + stream, + false)); + } + HIP_CHECK(hipDeviceSynchronize()); + + // HIP events creation + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + for(auto _ : state) + { + // Record start event + HIP_CHECK(hipEventRecord(start, stream)); + + for(size_t i = 0; i < batch_size; i++) + { + HIP_CHECK(rocprim::partial_sort_copy(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + middle, + size, + lesser_op, + stream, + false)); + } + + // Record stop event and wait until it completes + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float elapsed_mseconds; + HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); + state.SetIterationTime(elapsed_mseconds / 1000); + } + + // Destroy HIP events + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(*d_keys_input)); + state.SetItemsProcessed(state.iterations() * batch_size * size); + + HIP_CHECK(hipFree(d_temporary_storage)); + HIP_CHECK(hipFree(d_keys_input)); + HIP_CHECK(hipFree(d_keys_output)); + } +}; + +#endif // ROCPRIM_BENCHMARK_DEVICE_PARTIAL_SORT_COPY_PARALLEL_HPP_ From 88cb742d42254fe3b720bbb56975deabfa0710d1 Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Mon, 24 Jun 2024 06:31:11 +0000 Subject: [PATCH 14/14] Fixed bug with inplicit casting in partial sort --- .../rocprim/device/device_partial_sort.hpp | 125 ++++++++++++------ test/rocprim/indirect_iterator.hpp | 6 +- 2 files changed, 86 insertions(+), 45 deletions(-) diff --git a/rocprim/include/rocprim/device/device_partial_sort.hpp b/rocprim/include/rocprim/device/device_partial_sort.hpp index e9bacae0d..ca8e86d44 100644 --- a/rocprim/include/rocprim/device/device_partial_sort.hpp +++ b/rocprim/include/rocprim/device/device_partial_sort.hpp @@ -58,18 +58,22 @@ namespace detail } \ while(0) -template -hipError_t partial_sort_impl(void* temporary_storage, - size_t& storage_size, - KeysIterator keys_in, - KeysIterator keys_out, - size_t middle, - size_t size, - BinaryFunction compare_function, - hipStream_t stream, - bool debug_synchronous) +template +hipError_t partial_sort_impl(void* temporary_storage, + size_t& storage_size, + KeysInputIterator keys_in, + KeysOutputIterator keys_out, + size_t middle, + size_t size, + BinaryFunction compare_function, + hipStream_t stream, + bool debug_synchronous) { - using key_type = typename std::iterator_traits::value_type; + using key_type = typename std::iterator_traits::value_type; using config = default_or_custom_config>; using config_merge_sort = typename config::merge_sort; using config_nth_element = typename config::nth_element; @@ -148,16 +152,30 @@ hipError_t partial_sort_impl(void* temporary_storage, if(!full_sort) { - RETURN_ON_ERROR( - nth_element_impl(temporary_storage_nth_element, - storage_size_nth_element, - inplace ? keys_in : keys_output_nth_element, - middle, - size, - compare_function, - stream, - debug_synchronous, - keys_buffer)); + if(inplace) + { + RETURN_ON_ERROR(nth_element_impl(temporary_storage_nth_element, + storage_size_nth_element, + keys_in, + middle, + size, + compare_function, + stream, + debug_synchronous, + keys_buffer)); + } + else + { + RETURN_ON_ERROR(nth_element_impl(temporary_storage_nth_element, + storage_size_nth_element, + keys_output_nth_element, + middle, + size, + compare_function, + stream, + debug_synchronous, + keys_buffer)); + } } if(middle == 0) @@ -174,18 +192,38 @@ hipError_t partial_sort_impl(void* temporary_storage, return hipSuccess; } - return merge_sort_impl(temporary_storage_merge_sort, - storage_size_merge_sort, - inplace ? keys_in : keys_output_nth_element, - keys_out, - static_cast(nullptr), // values_input - static_cast(nullptr), // values_output - (!inplace || full_sort) ? middle + 1 : middle, - compare_function, - stream, - debug_synchronous, - keys_buffer, // keys_buffer - static_cast(nullptr)); // values_buffer + if(inplace) + { + return merge_sort_impl( + temporary_storage_merge_sort, + storage_size_merge_sort, + keys_in, + keys_out, + static_cast(nullptr), // values_input + static_cast(nullptr), // values_output + full_sort ? middle + 1 : middle, + compare_function, + stream, + debug_synchronous, + keys_buffer, // keys_buffer + static_cast(nullptr)); // values_buffer + } + else + { + return merge_sort_impl( + temporary_storage_merge_sort, + storage_size_merge_sort, + keys_output_nth_element, + keys_out, + static_cast(nullptr), // values_input + static_cast(nullptr), // values_output + middle + 1, + compare_function, + stream, + debug_synchronous, + keys_buffer, // keys_buffer + static_cast(nullptr)); // values_buffer + } } } // namespace detail @@ -283,16 +321,17 @@ hipError_t partial_sort_copy(void* temporary_storage, typename std::iterator_traits::value_type>::value, "KeysInputIterator and KeysOutputIterator must have the same value_type"); - return detail::partial_sort_impl( - temporary_storage, - storage_size, - keys_input, - keys_output, - middle, - size, - compare_function, - stream, - debug_synchronous); + return detail:: + partial_sort_impl( + temporary_storage, + storage_size, + keys_input, + keys_output, + middle, + size, + compare_function, + stream, + debug_synchronous); } /// \brief Rearranges elements such that the range [0, middle) contains the sorted middle smallest elements in the range [0, size). diff --git a/test/rocprim/indirect_iterator.hpp b/test/rocprim/indirect_iterator.hpp index 68328f7bb..5a98d4022 100644 --- a/test/rocprim/indirect_iterator.hpp +++ b/test/rocprim/indirect_iterator.hpp @@ -78,11 +78,13 @@ class indirect_iterator using iterator_category = std::random_access_iterator_tag; - ROCPRIM_HOST_DEVICE inline indirect_iterator(T* ptr) : ptr_(ptr) {} + ROCPRIM_HOST_DEVICE inline explicit indirect_iterator(T* ptr) : ptr_(ptr) {} ROCPRIM_HOST_DEVICE inline ~indirect_iterator() = default; - ROCPRIM_HOST_DEVICE inline indirect_iterator& operator++() + ROCPRIM_HOST_DEVICE + inline indirect_iterator& + operator++() { ++ptr_; return *this;