From 14779f76ec95ac1a925730f11d6bada622586c91 Mon Sep 17 00:00:00 2001 From: Bence Parajdi <148080361+parbenc@users.noreply.github.com> Date: Wed, 3 Apr 2024 21:55:20 +0200 Subject: [PATCH] Upstream staging 2024 03 22 (#541) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Modify adjacent_difference_kernel_impl so if output_type is void the input_type is used instead * fix(match_result_type.hpp): deprecate 'match_result_type.hpp' instead of removing it * change device_batch_memcpy to have an IsMemCpy template param started implementing test for device_copy * separate batch_memcpy and batch_cpy into different API calls and header files create test for batch_copy * add benchmark for batch_copy add separate config for batch_copy * update docs * removed unused variable * fix review comments * use warp_size properly instead of hardcoded number * merge batch_copy and batch_memcpy tests * merge batch_copy and batch_memcpy benchmarks * fix unused parameters * add event destroys to the benchmark * fix(device_scan.hpp): throw compiler warning on gfx 11 (navi 3x) for incorrect results due to compiler bug This change should be reverted once the compiler bug is fixed. * Resolve "tests for `block_adjacent_difference` and `block_discontinuity` don't compile with `rocprim::half`" * Update device_adjacent_difference fix for void output type to match CUB * Update CHANGELOG * revert: fix(device_scan.hpp): throw compiler warning on gfx 11 (navi 3x) for incorrect results due to compiler bug This reverts commit 3f375e458e741d10e46c61119ea93ea1012b7946. * fix(lookback_scan_state.hpp): fix sporadic failure in device scan algorithm on navi 3 gpus The compiler does not emit a 'buffer_gl0_inv'-instruction that is required on gfx11. We emit it manually only on this architecture to ensure correct results. 
* add tests for supported data types * clarified block_histogram documentation fixed block_histogram test, to work with int8_t data * fix bfloat16 tests * update changelog * add comment explainers for commented test cases * Abstracted bit_cast away * Implemented in-place radix sort * Added decomposer arguments and decomposer checking * Sliced compilation of test_block_radix_sort * Implemented sort_comparator for custom test type * Testing custom test types * Moved out identity_decomposer from detail namespace * Added custom type benchmarks * Updated changelog * Fixed typo in changelog * Fixed formatting * Separately testing radix key codec * Fixed documentation * Fixed test utils comparator for custom_float_type * Instantiate test block_radix_sort custom_test_type only when int128 is supported * Fixed compilation on Windows * Fixed tests for signed integral keys * Added extra radix_key_codec test cases * Fixed device_adjacent_difference kernel config selection and shared memory usage * Documentation fix after rebase * Added the memset to the graph for DeviceAdjacentDifferenceLargeTests * feat(predicate_iterator.hpp): added rocprim::predicate_iterator * fix(predicate_iterator.hpp): fix msvc build error due to implicit deletion of copy assignment * docs(predicate_iterator.hpp): fix doxygen build and improve consistency * refactor(predicate_iterator.hpp): improve naming consistency with other rocprim iterators and algorithms * fix(predicate_iterator.hpp): add missing out-of-class definition for operator+ * fix(predicate_iterator.hpp): fix predicate_iterator not reading value when underlying iterator dereferences to a non-reference type * test(benchmark_predicate_iterator.cpp): add benchmark for predicate iterator * refactor(predicate_iterator.hpp): remove unneeded << operator for std::ostream * docs(predicate_iterator.hpp): fix spelling * fix(predicate_iterator.hpp): drop 'const' from 'predicate_' as a class with const members cannot have an implicit copy assign * 
refactor(predicate_iterator.hpp): clean up * Replaced in `warp_sort_shuffle` where possible the `warp_shuffle_xor` for `warp_swizzle` * Replaced in `block_sort_bitonic` where possible the `warp_shuffle_xor` for `warp_swizzle` * Updated changelog * Generalize warp_swizzle_shuffle function for both block_sort_bitonic and warp_sort_shuffle * Fix docs warp_swizzle_shuffle * fix(predicate_iterator.hpp): drop constness of dereference operator and derivatives This is required to relax the requirement of functions passed as predicate. * fix(predicate_iterator.hpp): derive proxy capture type from dereference operator instead of relying on iterator trait * test(test_predicate_iterator.cpp): extend predicate iterator type tests * remove old workaround comment * Fixed descending device_radix_sort for bool keys * Bool tests in test_device_radix_sort * Bool tests in test_block_radix_sort * Replaced std::getenv for Windows with _dupenv_s to prevent MSVC deprecation warning * Removed malloc from linux version of the __get_env * Mark << operator as deprecated for iterators * device_radix_sort uses identity decomposer * Testing device_radix_sort with custom type [WIP] * Sorted overloads and updated docs for device_radix_sort * device_radix_sort public decomposer APIs * Fixed device_radix_sort with custom decomposer and added tests * Updated docs and changelog * Fixed building device_segmented_radix_sort * Added and tested additional device_radix_sort decomposers overloads * Enforce begin/end bit being default for floating point radix sort * Compile time dispatch and cleanup for radix_merge_compare * Added custom_key benchmarks for device_radix_sort * Fixed comparator dispatch in device_segmented_radix_sort * Removed duplicate test_device_radix_sort case * Iterating from MSB to LSB in decomposed radix_merge_compare * add decomposer argument * Added warp_exchange optimization with template recursion * Added tests to warp_exchange that can use the optimization * Warp exchange 
optimization using integer_sequence * Changed the shuffle warp_exchange to make use of swizzle and use a temp array. * Added extra benchmarks for warp_exchange where warpsize equals items per thread * Added non in place tests for warp exchange * Added documentation for optimized blocked_to_striped_shuffle and striped_to_blocked_shuffle functions of warp_exchange * Added some comments to warp_exchange * Use primes for tuning device adjacent difference * Removed accidental restriction of build targets in build:benchmark job * remove workaround for old compiler bug * Removed hotfix for double to double to __half conversion bug * style(test_utils_hipgraphs.hpp): add used includes * test: remove superfluous graph calls from tests All graph capturing and launching of host only rocprim calls are unneeded as they don't invoke device code. * fix(test_device_adjacent_difference.cpp): re-enable use of identity_operator on device_adjacent_difference tests * feat(test_device_adjacent_difference.cpp): add tests for void value_type in device_adjacent_difference * Update algorithms descriptions with non-bit-wise reproducibility * Added RadixBitsPerPass as parameter for the block_radix_sort * Add static_assert for RadixBitsPerPass from block_radix_sort * Added tests for block_radix_sort with different number of radix_bits_per_pass * Update CHANGELOG * Added some benchmarks with different radix bits per pass * Remove benchmarks with RadixBitsPerPass equal to 1 * Added check to partition kernel if size is smaller than items_per_block * Fixed benchmark_device_adjacent_difference formatting * Remove unnecessary rocprim headers for more specific includes * Ordering includes of includes of benchmark files * Added detail includes only for direct usage of detail functions in benchmarks * declare shared memory at kernel level as workaround for non-optimized builds taking too long * increase build parallelism * add debug build and run to ci * fix leftover instance * fix copyright dates * debug 
benchmark builds * Deprecate TwiddleIn/TwiddleOut * Match radix_key_codec with radix_key_codec_inplace * Remove radix_key_codec_inplace * Make radix_key_codec part of the public API * Add radix_key_codec to sphinx docs * Add radix_key_codec tests for encode/decode/extract_digit consistency * Add static assert to ensure non-fundamental typed keys do not get an identity_decomposer * Add ROCPRIM_PRAGMA_MESSAGE to warn about radix_sort.hpp functionality migration * Add test cases for block_histogram * Add test cases for block_exchange * Fix test_block_exchange to avoid UB * fix(thread_load.hpp): combine asm statements to fix broken behavior in debug builds 'rocprim::thread_load' used two consecutive asm declarations (a load and a wait) which allowed the compiler to insert code between the two instructions. This bug was only observed when compiling with '-O0'. By joining the two asm declarations, the compiler can no longer insert instructions between the load and wait, which would cause incomplete data to be used when it was dependent on the data being loaded. 
--------- Co-authored-by: Beatriz Navidad Vilches Co-authored-by: Nara Prasetya Co-authored-by: Lőrinc Serfőző Co-authored-by: Nick Breed Co-authored-by: Nol Moonen Co-authored-by: Balint Soproni --- .gitlab-ci.yml | 85 +- CHANGELOG.md | 33 + benchmark/CMakeLists.txt | 1 + .../benchmark_block_adjacent_difference.cpp | 19 +- benchmark/benchmark_block_discontinuity.cpp | 28 +- benchmark/benchmark_block_exchange.cpp | 26 +- benchmark/benchmark_block_histogram.cpp | 28 +- benchmark/benchmark_block_radix_rank.cpp | 21 +- benchmark/benchmark_block_radix_sort.cpp | 232 +- benchmark/benchmark_block_reduce.cpp | 26 +- .../benchmark_block_run_length_decode.cpp | 9 +- benchmark/benchmark_block_scan.cpp | 26 +- benchmark/benchmark_block_sort.cpp | 20 +- benchmark/benchmark_block_sort.parallel.hpp | 16 +- benchmark/benchmark_config_dispatch.cpp | 5 +- .../benchmark_device_adjacent_difference.cpp | 14 +- ...rk_device_adjacent_difference.parallel.hpp | 43 +- benchmark/benchmark_device_batch_memcpy.cpp | 267 +- benchmark/benchmark_device_binary_search.cpp | 27 +- ...enchmark_device_binary_search.parallel.hpp | 19 +- benchmark/benchmark_device_histogram.cpp | 23 +- .../benchmark_device_histogram.parallel.hpp | 15 +- benchmark/benchmark_device_memory.cpp | 22 +- benchmark/benchmark_device_merge.cpp | 26 +- benchmark/benchmark_device_merge_sort.cpp | 19 +- benchmark/benchmark_device_merge_sort.hpp | 13 +- ...enchmark_device_merge_sort_block_merge.cpp | 18 +- ...device_merge_sort_block_merge.parallel.hpp | 13 +- ...benchmark_device_merge_sort_block_sort.cpp | 18 +- ..._device_merge_sort_block_sort.parallel.hpp | 13 +- benchmark/benchmark_device_partition.cpp | 31 +- benchmark/benchmark_device_radix_sort.cpp | 14 +- benchmark/benchmark_device_radix_sort.hpp | 229 +- ...benchmark_device_radix_sort_block_sort.cpp | 16 +- ..._device_radix_sort_block_sort.parallel.hpp | 59 +- .../benchmark_device_radix_sort_onesweep.cpp | 13 +- ...rk_device_radix_sort_onesweep.parallel.hpp | 80 +- 
benchmark/benchmark_device_reduce.cpp | 17 +- .../benchmark_device_reduce.parallel.hpp | 14 +- benchmark/benchmark_device_reduce_by_key.cpp | 23 +- .../benchmark_device_run_length_encode.cpp | 23 +- benchmark/benchmark_device_scan.cpp | 17 +- benchmark/benchmark_device_scan.parallel.hpp | 14 +- benchmark/benchmark_device_scan_by_key.cpp | 17 +- .../benchmark_device_scan_by_key.parallel.hpp | 14 +- ...hmark_device_segmented_radix_sort_keys.cpp | 23 +- ...ice_segmented_radix_sort_keys.parallel.hpp | 14 +- ...mark_device_segmented_radix_sort_pairs.cpp | 21 +- ...ce_segmented_radix_sort_pairs.parallel.hpp | 14 +- .../benchmark_device_segmented_reduce.cpp | 23 +- benchmark/benchmark_device_select.cpp | 27 +- benchmark/benchmark_device_transform.cpp | 26 +- benchmark/benchmark_predicate_iterator.cpp | 243 ++ benchmark/benchmark_utils.hpp | 37 +- benchmark/benchmark_warp_exchange.cpp | 98 +- benchmark/benchmark_warp_reduce.cpp | 27 +- benchmark/benchmark_warp_scan.cpp | 24 +- benchmark/benchmark_warp_sort.cpp | 28 +- benchmark/cmdparser.hpp | 6 +- docs/block_ops/ops_classes/sort.rst | 7 +- docs/device_ops/device_copy.rst | 18 + docs/device_ops/index.rst | 1 + docs/device_ops/sort.rst | 50 +- docs/doxygen/Doxyfile | 1 + docs/doxygen/threadmodule.dox | 11 + docs/index.rst | 1 + docs/reference/iterators.rst | 14 + docs/reference/reference.rst | 1 + docs/sphinx/_toc.yml.in | 5 + docs/thread_ops/index.rst | 11 + docs/thread_ops/radix_key_codec.rst | 12 + .../include/rocprim/block/block_histogram.hpp | 10 +- .../rocprim/block/block_radix_rank.hpp | 4 +- .../rocprim/block/block_radix_sort.hpp | 404 +- .../block/detail/block_radix_rank_basic.hpp | 7 +- .../block/detail/block_radix_rank_match.hpp | 6 +- .../block/detail/block_sort_bitonic.hpp | 349 +- .../rocprim/detail/match_result_type.hpp | 50 + rocprim/include/rocprim/detail/radix_sort.hpp | 321 +- rocprim/include/rocprim/detail/various.hpp | 19 + .../include/rocprim/device/config_types.hpp | 2 +- 
.../config/device_adjacent_difference.hpp | 170 +- .../device_adjacent_difference_inplace.hpp | 172 +- .../detail/device_adjacent_difference.hpp | 9 +- .../device/detail/device_batch_memcpy.hpp | 354 +- .../device/detail/device_radix_sort.hpp | 346 +- .../detail/device_segmented_radix_sort.hpp | 70 +- .../device/device_adjacent_difference.hpp | 8 +- .../include/rocprim/device/device_copy.hpp | 150 + .../rocprim/device/device_copy_config.hpp | 67 + .../include/rocprim/device/device_memcpy.hpp | 219 +- .../rocprim/device/device_radix_sort.hpp | 3339 ++++++++++++++--- .../rocprim/device/device_reduce_by_key.hpp | 7 +- .../include/rocprim/device/device_scan.hpp | 14 +- .../rocprim/device/device_scan_by_key.hpp | 14 +- .../device/device_segmented_radix_sort.hpp | 4 +- .../device_radix_block_sort.hpp | 34 +- .../device_radix_merge_sort.hpp | 187 +- rocprim/include/rocprim/intrinsics/thread.hpp | 9 +- .../rocprim/intrinsics/warp_shuffle.hpp | 69 +- rocprim/include/rocprim/iterator.hpp | 3 +- .../rocprim/iterator/arg_index_iterator.hpp | 5 +- .../rocprim/iterator/constant_iterator.hpp | 4 +- .../rocprim/iterator/counting_iterator.hpp | 4 +- .../rocprim/iterator/discard_iterator.hpp | 5 +- .../rocprim/iterator/predicate_iterator.hpp | 300 ++ .../iterator/texture_cache_iterator.hpp | 5 +- .../rocprim/iterator/transform_iterator.hpp | 3 +- .../include/rocprim/iterator/zip_iterator.hpp | 4 +- rocprim/include/rocprim/rocprim.hpp | 5 +- .../rocprim/thread/radix_key_codec.hpp | 726 ++++ .../include/rocprim/thread/thread_load.hpp | 40 +- rocprim/include/rocprim/type_traits.hpp | 117 +- rocprim/include/rocprim/types.hpp | 8 +- .../rocprim/warp/detail/warp_sort_shuffle.hpp | 473 +-- .../include/rocprim/warp/warp_exchange.hpp | 219 +- test/common_test_header.hpp | 79 +- test/hipgraph/test_hipgraph_algs.cpp | 19 +- test/hipgraph/test_hipgraph_basic.cpp | 3 +- test/rocprim/CMakeLists.txt | 6 +- .../test_block_adjacent_difference.cpp.in | 9 +- 
...test_block_adjacent_difference.kernels.hpp | 121 +- ...ty.cpp => test_block_discontinuity.cpp.in} | 21 +- .../test_block_discontinuity.kernels.hpp | 99 +- test/rocprim/test_block_exchange.kernels.hpp | 40 +- test/rocprim/test_block_histogram.kernels.hpp | 11 +- .../rocprim/test_block_load_store.kernels.hpp | 180 +- ..._sort.cpp => test_block_radix_sort.cpp.in} | 35 +- .../rocprim/test_block_radix_sort.kernels.hpp | 326 +- test/rocprim/test_block_reduce.hpp | 13 +- test/rocprim/test_block_run_length_decode.cpp | 4 +- test/rocprim/test_block_scan.hpp | 26 +- test/rocprim/test_block_sort.hpp | 6 +- .../test_device_adjacent_difference.cpp | 70 +- test/rocprim/test_device_batch_memcpy.cpp | 303 +- test/rocprim/test_device_binary_search.cpp | 95 +- test/rocprim/test_device_histogram.cpp | 216 +- test/rocprim/test_device_merge.cpp | 92 +- test/rocprim/test_device_merge_sort.cpp | 83 +- test/rocprim/test_device_partition.cpp | 278 +- test/rocprim/test_device_radix_sort.cpp.in | 15 +- test/rocprim/test_device_radix_sort.hpp | 1047 ++++-- test/rocprim/test_device_reduce.cpp | 183 +- test/rocprim/test_device_reduce_by_key.cpp | 125 +- .../rocprim/test_device_run_length_encode.cpp | 13 +- test/rocprim/test_device_scan.cpp | 236 +- test/rocprim/test_device_segmented_reduce.cpp | 24 +- test/rocprim/test_device_segmented_scan.cpp | 145 +- test/rocprim/test_device_select.cpp | 305 +- test/rocprim/test_device_transform.cpp | 79 +- test/rocprim/test_predicate_iterator.cpp | 288 ++ test/rocprim/test_radix_key_codec.cpp | 451 +++ .../test_temporary_storage_partitioning.cpp | 3 +- test/rocprim/test_utils.hpp | 57 +- test/rocprim/test_utils_assertions.hpp | 162 +- test/rocprim/test_utils_custom_float_type.hpp | 4 +- test/rocprim/test_utils_data_generation.hpp | 77 +- test/rocprim/test_utils_hipgraphs.hpp | 6 +- test/rocprim/test_utils_sort_comparator.hpp | 254 +- test/rocprim/test_utils_types.hpp | 126 +- test/rocprim/test_warp_exchange.cpp | 273 +- test/rocprim/test_warp_load.cpp | 
41 +- test/rocprim/test_warp_reduce.hpp | 28 +- test/rocprim/test_warp_scan.hpp | 47 +- test/rocprim/test_warp_store.cpp | 41 +- 165 files changed, 11654 insertions(+), 5295 deletions(-) create mode 100644 benchmark/benchmark_predicate_iterator.cpp create mode 100644 docs/device_ops/device_copy.rst create mode 100644 docs/doxygen/threadmodule.dox create mode 100644 docs/thread_ops/index.rst create mode 100644 docs/thread_ops/radix_key_codec.rst create mode 100644 rocprim/include/rocprim/detail/match_result_type.hpp create mode 100644 rocprim/include/rocprim/device/device_copy.hpp create mode 100644 rocprim/include/rocprim/device/device_copy_config.hpp create mode 100644 rocprim/include/rocprim/iterator/predicate_iterator.hpp create mode 100644 rocprim/include/rocprim/thread/radix_key_codec.hpp rename test/rocprim/{test_block_discontinuity.cpp => test_block_discontinuity.cpp.in} (83%) rename test/rocprim/{test_block_radix_sort.cpp => test_block_radix_sort.cpp.in} (72%) create mode 100644 test/rocprim/test_predicate_iterator.cpp create mode 100644 test/rocprim/test_radix_key_codec.cpp diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 3bf7882b8..bbd9ceedc 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -169,10 +169,9 @@ build:cmake-minimum-apt: -G Ninja -D CMAKE_CXX_COMPILER="$AMDCLANG" -D CMAKE_CXX_FLAGS="-Wall -Wextra -Werror" - -D CMAKE_BUILD_TYPE=Release - -D BUILD_TEST=ON + -D CMAKE_BUILD_TYPE="$BUILD_TYPE" + -D BUILD_$BUILD_TARGET=ON -D BUILD_EXAMPLE=ON - -D BUILD_BENCHMARK=OFF -D GPU_TARGETS=$GPU_TARGETS -D AMDGPU_TEST_TARGETS=$GPU_TARGETS -S $CI_PROJECT_DIR @@ -180,14 +179,16 @@ build:cmake-minimum-apt: - cmake --build $BUILD_DIR artifacts: paths: - - $BUILD_DIR/test/test_* - - $BUILD_DIR/test/rocprim/test_* - - $BUILD_DIR/test/CTestTestfile.cmake - - $BUILD_DIR/test/rocprim/CTestTestfile.cmake - - $BUILD_DIR/gtest/ - - $BUILD_DIR/CMakeCache.txt - $BUILD_DIR/.ninja_log + - $BUILD_DIR/benchmark/* + - $BUILD_DIR/CMakeCache.txt - 
$BUILD_DIR/CTestTestfile.cmake + - $BUILD_DIR/deps/googlebenchmark/ + - $BUILD_DIR/gtest/ + - $BUILD_DIR/test/CTestTestfile.cmake + - $BUILD_DIR/test/rocprim/CTestTestfile.cmake + - $BUILD_DIR/test/rocprim/test_* + - $BUILD_DIR/test/test_* expire_in: 2 weeks build:cmake-latest: @@ -196,13 +197,19 @@ build:cmake-latest: extends: - .cmake-latest - .build:common + variables: + BUILD_TYPE: Release + BUILD_TARGET: TEST build:cmake-minimum: - stage: build needs: [] extends: - .cmake-minimum - .build:common + parallel: + matrix: + - BUILD_TYPE: [Debug, Release] + BUILD_TARGET: [BENCHMARK, TEST] build:package: stage: build @@ -229,39 +236,6 @@ build:package: - $PACKAGE_DIR/rocprim*.zip expire_in: 2 weeks -build:benchmark: - stage: build - needs: [] - tags: - - build - extends: - - .cmake-minimum - - .gpus:rocm-gpus - - .rules:build - script: -# If we have a custom config created by autotune:create-config - - "[ -d ${AUTOTUNE_CONFIG_DIR} ] && cp -r -f ${AUTOTUNE_CONFIG_DIR}/* ${CI_PROJECT_DIR}/" - - mkdir -p $BUILD_DIR - - cd $BUILD_DIR - - cmake - -B $BUILD_DIR - -S $CI_PROJECT_DIR - -G Ninja - -D CMAKE_CXX_COMPILER="$AMDCLANG" - -D CMAKE_CXX_FLAGS="-Wall -Wextra -Werror -Wno-#pragma-messages" - -D CMAKE_BUILD_TYPE=Release - -D BUILD_TEST=OFF - -D BUILD_EXAMPLE=OFF - -D BUILD_BENCHMARK=ON - -D GPU_TARGETS=$GPU_TARGETS - - cmake --build . - artifacts: - paths: - - $BUILD_DIR/benchmark/* - - $BUILD_DIR/.ninja_log - - $BUILD_DIR/deps/googlebenchmark/ - expire_in: 2 weeks - build:windows: stage: build needs: [] @@ -272,14 +246,9 @@ build:windows: - .deps:visual-studio-devshell parallel: matrix: - - BUILD_TYPE: - # Disabled due to extensive link times. - # This is tracked in issue 679 - #- Debug - - Release - BUILD_TARGET: - - BENCHMARK - - TEST + # Debug is disabled due to extensive link times, tracked in issue 679. 
+ - BUILD_TYPE: [Release] + BUILD_TARGET: [BENCHMARK, TEST] script: - mkdir -p $CI_PROJECT_DIR/build - cmake -G Ninja @@ -360,7 +329,11 @@ test: - .rules:test - .gpus:rocm needs: - - build:cmake-minimum + - job: build:cmake-minimum + parallel: + matrix: + - BUILD_TYPE: Release + BUILD_TARGET: TEST script: - cd $BUILD_DIR - cmake @@ -475,7 +448,11 @@ test:docs: benchmark: needs: - - build:benchmark + - job: build:cmake-minimum + parallel: + matrix: + - BUILD_TYPE: Release + BUILD_TARGET: BENCHMARK extends: - .cmake-minimum - .gpus:rocm @@ -560,8 +537,6 @@ autotune:execute-tuning: paths: - ${AUTOTUNE_RESULT_DIR}/*.json script: - # Exclude benchmark that is known to fail on gfx906 - # On ROCm 5.7 or later, check if this can be removed - the presumption is that the failure is caused by a compiler issue. - > cd "${CI_PROJECT_DIR}" - | diff --git a/CHANGELOG.md b/CHANGELOG.md index 09a057021..2def66b97 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,12 +12,45 @@ Documentation for rocPRIM is available at * The default accumulator type is still the value type of the input iterator (inclusive scan) or the initial value's type (exclusive scan). This is the same behaviour as before this change. * New overload for `device_adjacent_difference_inplace` that allows separate input and output iterators, but allows them to point to the same element. +* New public API for deriving resulting type on device-only functions: + * `rocprim::invoke_result` + * `rocprim::invoke_result_t` + * `rocprim::invoke_result_binary_op` + * `rocprim::invoke_result_binary_op_t` +* New `rocprim::batch_copy` function added. Similar to `rocprim::batch_memcpy`, but copies by element, not with memcpy. +* Added more test cases, to better cover supported data types. +* Updated some tests to work with supported data types. +* An optional `decomposer` argument for all member functions of `rocprim::block_radix_sort` and all functions of `device_radix_sort`. 
+ To sort keys of a user-defined type, a decomposer functor should be passed. The decomposer should produce a `rocprim::tuple` + of references to arithmetic types from the key. +* New `rocprim::predicate_iterator` which acts as a proxy for an underlying iterator based on a predicate. + It iterates over proxies that hold the references to the underlying values, but only allow reading and writing if the predicate is `true`. + It can be instantiated with: + * `rocprim::make_predicate_iterator` + * `rocprim::make_mask_iterator` +* Added custom radix sizes as the last parameter for `block_radix_sort`. The default value is 4, it can be a number between 0 and 32. +* New `rocprim::radix_key_codec`, which allows the encoding/decoding of keys for radix-based sorts. For user-defined key types, a decomposer functor should be passed. + +### Optimizations + +* Improved the performance of `warp_sort_shuffle` and `block_sort_bitonic`. +* Created an optimized version of the `warp_exchange` functions `blocked_to_striped_shuffle` and `striped_to_blocked_shuffle` when the warpsize is equal to the items per thread. ### Fixes * Fixed incorrect results of `warp_exchange::blocked_to_striped_shuffle` and `warp_exchange::striped_to_blocked_shuffle` when the block size is larger than the logical warp size. The test suite has been updated with such cases. * Fixed incorrect results returned when calling device `unique_by_key` with overlapping `values_input` and `values_output`. +* Fixed incorrect output type used in `device_adjacent_difference`. +* Hotfix for incorrect results on the GFX10 (Navi 10/RDNA1, Navi 20/RDNA2) ISA and GFX11 ISA (Navi 30 GPUs) on device scan algorithms `rocprim::inclusive_scan(_by_key)` and `rocprim::exclusive_scan(_by_key)` with large input types. +* `device_adjacent_difference` now considers both the input and the output type for selecting the appropriate kernel launch config. 
Previously only the input type was considered, which could result in compilation errors due to excessive shared memory usage. +* Fixed incorrect data being loaded with `rocprim::thread_load` when compiling with `-O0`. + +### Deprecations + +* The internal header `detail/match_result_type.hpp` has been deprecated. +* `TwiddleIn` and `TwiddleOut` have been deprecated in favor of `radix_key_codec`. +* The internal `::rocprim::detail::radix_key_codec` has been deprecated in favor of the new public utility with the same name. ## Unreleased rocPRIM-3.1.0 for ROCm 6.1.0 diff --git a/benchmark/CMakeLists.txt b/benchmark/CMakeLists.txt index 5bde961ed..d4e110c41 100644 --- a/benchmark/CMakeLists.txt +++ b/benchmark/CMakeLists.txt @@ -151,6 +151,7 @@ add_rocprim_benchmark(benchmark_device_segmented_radix_sort_keys.cpp) add_rocprim_benchmark(benchmark_device_segmented_radix_sort_pairs.cpp) add_rocprim_benchmark(benchmark_device_segmented_reduce.cpp) add_rocprim_benchmark(benchmark_device_transform.cpp) +add_rocprim_benchmark(benchmark_predicate_iterator.cpp) add_rocprim_benchmark(benchmark_warp_exchange.cpp) add_rocprim_benchmark(benchmark_warp_reduce.cpp) add_rocprim_benchmark(benchmark_warp_scan.cpp) diff --git a/benchmark/benchmark_block_adjacent_difference.cpp b/benchmark/benchmark_block_adjacent_difference.cpp index b2f74c6fa..f94e74091 100644 --- a/benchmark/benchmark_block_adjacent_difference.cpp +++ b/benchmark/benchmark_block_adjacent_difference.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,19 +20,22 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. 
-#include -#include - -// Google Benchmark -#include "benchmark/benchmark.h" - +#include "benchmark_utils.hpp" // CmdParser #include "cmdparser.hpp" -#include "benchmark_utils.hpp" + +// Google Benchmark +#include // HIP API #include +// rocPRIM +#include +#include +#include + +#include #include #include #include diff --git a/benchmark/benchmark_block_discontinuity.cpp b/benchmark/benchmark_block_discontinuity.cpp index 5f572fc49..4483ef47b 100644 --- a/benchmark/benchmark_block_discontinuity.cpp +++ b/benchmark/benchmark_block_discontinuity.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,24 +20,28 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#include -#include -#include -#include -#include -#include - -// Google Benchmark -#include "benchmark/benchmark.h" // CmdParser -#include "cmdparser.hpp" #include "benchmark_utils.hpp" +#include "cmdparser.hpp" + +// Google Benchmark +#include // HIP API #include // rocPRIM -#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 128; diff --git a/benchmark/benchmark_block_exchange.cpp b/benchmark/benchmark_block_exchange.cpp index 4225c4a3f..d82be888a 100644 --- a/benchmark/benchmark_block_exchange.cpp +++ b/benchmark/benchmark_block_exchange.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,24 +20,28 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#include -#include -#include -#include -#include -#include - -// Google Benchmark -#include "benchmark/benchmark.h" // CmdParser #include "cmdparser.hpp" #include "benchmark_utils.hpp" +// Google Benchmark +#include + // HIP API #include // rocPRIM -#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; diff --git a/benchmark/benchmark_block_histogram.cpp b/benchmark/benchmark_block_histogram.cpp index f184590f4..676845f67 100644 --- a/benchmark/benchmark_block_histogram.cpp +++ b/benchmark/benchmark_block_histogram.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,24 +20,28 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. 
-#include -#include -#include -#include -#include -#include - -// Google Benchmark -#include "benchmark/benchmark.h" +#include "benchmark_utils.hpp" // CmdParser #include "cmdparser.hpp" -#include "benchmark_utils.hpp" + +// Google Benchmark +#include // HIP API #include // rocPRIM -#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 128; diff --git a/benchmark/benchmark_block_radix_rank.cpp b/benchmark/benchmark_block_radix_rank.cpp index a494d0d19..a2e28eb2e 100644 --- a/benchmark/benchmark_block_radix_rank.cpp +++ b/benchmark/benchmark_block_radix_rank.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,21 +20,24 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#include -#include -#include - -// Google Benchmark -#include "benchmark/benchmark.h" -// CmdParser #include "benchmark_utils.hpp" +// CmdParser #include "cmdparser.hpp" +// Google Benchmark +#include + // HIP API #include // rocPRIM -#include +#include +#include +#include + +#include +#include +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 128; diff --git a/benchmark/benchmark_block_radix_sort.cpp b/benchmark/benchmark_block_radix_sort.cpp index 7594cb39f..de7c3a541 100644 --- a/benchmark/benchmark_block_radix_sort.cpp +++ b/benchmark/benchmark_block_radix_sort.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,24 +20,29 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#include -#include -#include -#include -#include -#include - -// Google Benchmark -#include "benchmark/benchmark.h" +#include "benchmark_utils.hpp" // CmdParser #include "cmdparser.hpp" -#include "benchmark_utils.hpp" + +// Google Benchmark +#include // HIP API #include // rocPRIM -#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 128; @@ -51,15 +56,16 @@ enum class benchmark_kinds namespace rp = rocprim; -template< - class T, - unsigned int BlockSize, - unsigned int ItemsPerThread, - unsigned int Trials -> -__global__ -__launch_bounds__(BlockSize) -void sort_keys_kernel(const T * input, T * output) +template +using select_decomposer_t = std:: + conditional_t::value, custom_type_decomposer, rp::identity_decomposer>; + +template +__global__ __launch_bounds__(BlockSize) void sort_keys_kernel(const T* input, T* output) { const unsigned int lid = threadIdx.x; const unsigned int block_offset = blockIdx.x * ItemsPerThread * BlockSize; @@ -70,22 +76,26 @@ void sort_keys_kernel(const T * input, T * output) ROCPRIM_NO_UNROLL for(unsigned int trial = 0; trial < Trials; trial++) { - rp::block_radix_sort sort; - sort.sort(keys); + rp::block_radix_sort + sort; + sort.sort(keys, 0, sizeof(T) * 8, select_decomposer_t{}); } rp::block_store_direct_striped(lid, output + block_offset, keys); } -template< - class T, - unsigned int BlockSize, - unsigned int ItemsPerThread, - unsigned int Trials -> -__global__ -__launch_bounds__(BlockSize) -void sort_pairs_kernel(const T * input, T * output) +template +__global__ __launch_bounds__(BlockSize) void sort_pairs_kernel(const T* input, T* output) { 
const unsigned int lid = threadIdx.x; const unsigned int block_offset = blockIdx.x * ItemsPerThread * BlockSize; @@ -101,8 +111,8 @@ void sort_pairs_kernel(const T * input, T * output) ROCPRIM_NO_UNROLL for(unsigned int trial = 0; trial < Trials; trial++) { - rp::block_radix_sort sort; - sort.sort(keys, values); + rp::block_radix_sort sort; + sort.sort(keys, values, 0, sizeof(T) * 8, select_decomposer_t{}); } for(unsigned int i = 0; i < ItemsPerThread; i++) @@ -112,13 +122,15 @@ void sort_pairs_kernel(const T * input, T * output) rp::block_store_direct_striped(lid, output + block_offset, keys); } -template< - class T, - unsigned int BlockSize, - unsigned int ItemsPerThread, - unsigned int Trials = 10 -> -void run_benchmark(benchmark::State& state, benchmark_kinds benchmark_kind, hipStream_t stream, size_t N) +template +void run_benchmark(benchmark::State& state, + benchmark_kinds benchmark_kind, + hipStream_t stream, + size_t N) { constexpr auto items_per_block = BlockSize * ItemsPerThread; const auto size = items_per_block * ((N + items_per_block - 1)/items_per_block); @@ -162,18 +174,26 @@ void run_benchmark(benchmark::State& state, benchmark_kinds benchmark_kind, hipS if(benchmark_kind == benchmark_kinds::sort_keys) { hipLaunchKernelGGL( - HIP_KERNEL_NAME(sort_keys_kernel), - dim3(size/items_per_block), dim3(BlockSize), 0, stream, - d_input, d_output - ); + HIP_KERNEL_NAME( + sort_keys_kernel), + dim3(size / items_per_block), + dim3(BlockSize), + 0, + stream, + d_input, + d_output); } else if(benchmark_kind == benchmark_kinds::sort_pairs) { hipLaunchKernelGGL( - HIP_KERNEL_NAME(sort_pairs_kernel), - dim3(size/items_per_block), dim3(BlockSize), 0, stream, - d_input, d_output - ); + HIP_KERNEL_NAME( + sort_pairs_kernel), + dim3(size / items_per_block), + dim3(BlockSize), + 0, + stream, + d_input, + d_output); } HIP_CHECK(hipGetLastError()); @@ -197,22 +217,20 @@ void run_benchmark(benchmark::State& state, benchmark_kinds benchmark_kind, hipS 
HIP_CHECK(hipFree(d_output)); } -#define CREATE_BENCHMARK(T, BS, IPT) \ +#define CREATE_BENCHMARK(T, BS, RB, IPT) \ benchmark::RegisterBenchmark( \ bench_naming::format_name("{lvl:block,algo:radix_sort,key_type:" #T ",subalgo:" + name \ - + ",cfg:{bs:" #BS ",ipt:" #IPT "}}") \ + + ",cfg:{bs:" #BS ",rb:" #RB ",ipt:" #IPT "}}") \ .c_str(), \ - run_benchmark, \ + run_benchmark, \ benchmark_kind, \ stream, \ size) -#define BENCHMARK_TYPE(type, block) \ - CREATE_BENCHMARK(type, block, 1), \ - CREATE_BENCHMARK(type, block, 2), \ - CREATE_BENCHMARK(type, block, 3), \ - CREATE_BENCHMARK(type, block, 4), \ - CREATE_BENCHMARK(type, block, 8) +#define BENCHMARK_TYPE(type, block, radix_bits) \ + CREATE_BENCHMARK(type, block, radix_bits, 1), CREATE_BENCHMARK(type, block, radix_bits, 2), \ + CREATE_BENCHMARK(type, block, radix_bits, 3), \ + CREATE_BENCHMARK(type, block, radix_bits, 4), CREATE_BENCHMARK(type, block, radix_bits, 8) void add_benchmarks(benchmark_kinds benchmark_kind, const std::string& name, @@ -220,42 +238,68 @@ void add_benchmarks(benchmark_kinds benchmark_kind, hipStream_t stream, size_t size) { - std::vector bs = - { - BENCHMARK_TYPE(int, 64), - BENCHMARK_TYPE(int, 128), - BENCHMARK_TYPE(int, 192), - BENCHMARK_TYPE(int, 256), - BENCHMARK_TYPE(int, 320), - BENCHMARK_TYPE(int, 512), - - BENCHMARK_TYPE(int8_t, 64), - BENCHMARK_TYPE(int8_t, 128), - BENCHMARK_TYPE(int8_t, 192), - BENCHMARK_TYPE(int8_t, 256), - BENCHMARK_TYPE(int8_t, 320), - BENCHMARK_TYPE(int8_t, 512), - - BENCHMARK_TYPE(uint8_t, 64), - BENCHMARK_TYPE(uint8_t, 128), - BENCHMARK_TYPE(uint8_t, 192), - BENCHMARK_TYPE(uint8_t, 256), - BENCHMARK_TYPE(uint8_t, 320), - BENCHMARK_TYPE(uint8_t, 512), - - BENCHMARK_TYPE(rocprim::half, 64), - BENCHMARK_TYPE(rocprim::half, 128), - BENCHMARK_TYPE(rocprim::half, 192), - BENCHMARK_TYPE(rocprim::half, 256), - BENCHMARK_TYPE(rocprim::half, 320), - BENCHMARK_TYPE(rocprim::half, 512), - - BENCHMARK_TYPE(long long, 64), - BENCHMARK_TYPE(long long, 128), - 
BENCHMARK_TYPE(long long, 192), - BENCHMARK_TYPE(long long, 256), - BENCHMARK_TYPE(long long, 320), - BENCHMARK_TYPE(long long, 512), + using custom_int_type = custom_type; + + std::vector bs = { + BENCHMARK_TYPE(int, 64, 3), + BENCHMARK_TYPE(int, 512, 3), + + BENCHMARK_TYPE(int, 64, 4), + BENCHMARK_TYPE(int, 128, 4), + BENCHMARK_TYPE(int, 192, 4), + BENCHMARK_TYPE(int, 256, 4), + BENCHMARK_TYPE(int, 320, 4), + BENCHMARK_TYPE(int, 512, 4), + + BENCHMARK_TYPE(int8_t, 64, 3), + BENCHMARK_TYPE(int8_t, 512, 3), + + BENCHMARK_TYPE(int8_t, 64, 4), + BENCHMARK_TYPE(int8_t, 128, 4), + BENCHMARK_TYPE(int8_t, 192, 4), + BENCHMARK_TYPE(int8_t, 256, 4), + BENCHMARK_TYPE(int8_t, 320, 4), + BENCHMARK_TYPE(int8_t, 512, 4), + + BENCHMARK_TYPE(uint8_t, 64, 3), + BENCHMARK_TYPE(uint8_t, 512, 3), + + BENCHMARK_TYPE(uint8_t, 64, 4), + BENCHMARK_TYPE(uint8_t, 128, 4), + BENCHMARK_TYPE(uint8_t, 192, 4), + BENCHMARK_TYPE(uint8_t, 256, 4), + BENCHMARK_TYPE(uint8_t, 320, 4), + BENCHMARK_TYPE(uint8_t, 512, 4), + + BENCHMARK_TYPE(rocprim::half, 64, 3), + BENCHMARK_TYPE(rocprim::half, 512, 3), + + BENCHMARK_TYPE(rocprim::half, 64, 4), + BENCHMARK_TYPE(rocprim::half, 128, 4), + BENCHMARK_TYPE(rocprim::half, 192, 4), + BENCHMARK_TYPE(rocprim::half, 256, 4), + BENCHMARK_TYPE(rocprim::half, 320, 4), + BENCHMARK_TYPE(rocprim::half, 512, 4), + + BENCHMARK_TYPE(long long, 64, 3), + BENCHMARK_TYPE(long long, 512, 3), + + BENCHMARK_TYPE(long long, 64, 4), + BENCHMARK_TYPE(long long, 128, 4), + BENCHMARK_TYPE(long long, 192, 4), + BENCHMARK_TYPE(long long, 256, 4), + BENCHMARK_TYPE(long long, 320, 4), + BENCHMARK_TYPE(long long, 512, 4), + + BENCHMARK_TYPE(custom_int_type, 64, 3), + BENCHMARK_TYPE(custom_int_type, 512, 3), + + BENCHMARK_TYPE(custom_int_type, 64, 4), + BENCHMARK_TYPE(custom_int_type, 128, 4), + BENCHMARK_TYPE(custom_int_type, 192, 4), + BENCHMARK_TYPE(custom_int_type, 256, 4), + BENCHMARK_TYPE(custom_int_type, 320, 4), + BENCHMARK_TYPE(custom_int_type, 512, 4), }; 
benchmarks.insert(benchmarks.end(), bs.begin(), bs.end()); diff --git a/benchmark/benchmark_block_reduce.cpp b/benchmark/benchmark_block_reduce.cpp index d6fa3de1e..446649040 100644 --- a/benchmark/benchmark_block_reduce.cpp +++ b/benchmark/benchmark_block_reduce.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,24 +20,26 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#include -#include -#include -#include -#include -#include - -// Google Benchmark -#include "benchmark/benchmark.h" +#include "benchmark_utils.hpp" // CmdParser #include "cmdparser.hpp" -#include "benchmark_utils.hpp" + +// Google Benchmark +#include // HIP API #include // rocPRIM -#include +#include + +#include +#include +#include +#include + +#include +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; diff --git a/benchmark/benchmark_block_run_length_decode.cpp b/benchmark/benchmark_block_run_length_decode.cpp index 34adfdc14..7d10630fd 100644 --- a/benchmark/benchmark_block_run_length_decode.cpp +++ b/benchmark/benchmark_block_run_length_decode.cpp @@ -20,13 +20,14 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. 
-#include "benchmark/benchmark.h" #include "benchmark_utils.hpp" #include "cmdparser.hpp" -#include "rocprim/block/block_load.hpp" -#include "rocprim/block/block_run_length_decode.hpp" -#include "rocprim/block/block_store.hpp" +#include + +#include +#include +#include #include #include diff --git a/benchmark/benchmark_block_scan.cpp b/benchmark/benchmark_block_scan.cpp index 84f1a27ba..071312b62 100644 --- a/benchmark/benchmark_block_scan.cpp +++ b/benchmark/benchmark_block_scan.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,24 +20,26 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#include -#include -#include -#include -#include -#include - -// Google Benchmark -#include "benchmark/benchmark.h" +#include "benchmark_utils.hpp" // CmdParser #include "cmdparser.hpp" -#include "benchmark_utils.hpp" + +// Google Benchmark +#include // HIP API #include // rocPRIM -#include +#include + +#include +#include +#include +#include + +#include +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; diff --git a/benchmark/benchmark_block_sort.cpp b/benchmark/benchmark_block_sort.cpp index 29948a364..14c734719 100644 --- a/benchmark/benchmark_block_sort.cpp +++ b/benchmark/benchmark_block_sort.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,9 +20,11 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#include -#include -#include +#include "benchmark_block_sort.parallel.hpp" +#include "benchmark_utils.hpp" + +// CmdParser +#include "cmdparser.hpp" // Google Benchmark #include @@ -30,11 +32,13 @@ // HIP API #include -// CmdParser -#include "cmdparser.hpp" +// rocPRIM +#include -#include "benchmark_block_sort.parallel.hpp" -#include "benchmark_utils.hpp" +#include +#include + +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 128; diff --git a/benchmark/benchmark_block_sort.parallel.hpp b/benchmark/benchmark_block_sort.parallel.hpp index dfdc6e5b9..e7138d197 100644 --- a/benchmark/benchmark_block_sort.parallel.hpp +++ b/benchmark/benchmark_block_sort.parallel.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2019-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2019-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -23,9 +23,7 @@ #ifndef ROCPRIM_BENCHMARK_BLOCK_SORT_PARALLEL_HPP_ #define ROCPRIM_BENCHMARK_BLOCK_SORT_PARALLEL_HPP_ -#include -#include -#include +#include "benchmark_utils.hpp" // Google Benchmark #include @@ -34,9 +32,15 @@ #include // rocPRIM -#include +#include +#include +#include +#include -#include "benchmark_utils.hpp" +#include +#include + +#include template #include @@ -6,8 +9,6 @@ #include -#include "benchmark_utils.hpp" - enum class stream_kind { default_stream, diff --git a/benchmark/benchmark_device_adjacent_difference.cpp b/benchmark/benchmark_device_adjacent_difference.cpp index c2fbe9ee2..5b64955f3 100644 --- a/benchmark/benchmark_device_adjacent_difference.cpp +++ b/benchmark/benchmark_device_adjacent_difference.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,9 +20,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. 
-#include -#include -#include +#include "benchmark_device_adjacent_difference.parallel.hpp" +#include "benchmark_utils.hpp" // Google Benchmark #include @@ -31,14 +30,15 @@ #include // rocPRIM -#include #include // CmdParser #include "cmdparser.hpp" -#include "benchmark_device_adjacent_difference.parallel.hpp" -#include "benchmark_utils.hpp" +#include +#include + +#include #ifndef DEFAULT_N constexpr std::size_t DEFAULT_N = 1024 * 1024 * 128; diff --git a/benchmark/benchmark_device_adjacent_difference.parallel.hpp b/benchmark/benchmark_device_adjacent_difference.parallel.hpp index 51f8cdfda..5dfcc2499 100644 --- a/benchmark/benchmark_device_adjacent_difference.parallel.hpp +++ b/benchmark/benchmark_device_adjacent_difference.parallel.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2019-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2019-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -23,9 +23,7 @@ #ifndef ROCPRIM_BENCHMARK_DEVICE_ADJACENT_DIFFERENCE_PARALLEL_HPP_ #define ROCPRIM_BENCHMARK_DEVICE_ADJACENT_DIFFERENCE_PARALLEL_HPP_ -#include -#include -#include +#include "benchmark_utils.hpp" // Google Benchmark #include @@ -34,10 +32,13 @@ #include // rocPRIM -#include #include +#include -#include "benchmark_utils.hpp" +#include +#include + +#include template std::string config_name() @@ -234,28 +235,38 @@ struct device_adjacent_difference_benchmark : public config_autotune_interface template struct device_adjacent_difference_benchmark_generator { + static constexpr unsigned int min_items_per_thread = 0; + static constexpr unsigned int max_items_per_thread_arg + = TUNING_SHARED_MEMORY_MAX / (BlockSize * sizeof(T) * 2 + sizeof(T)); - template + template struct create_ipt { - using generated_config - = rocprim::adjacent_difference_config; + // Device Adjacent 
difference uses block_load/store_transpose to coalesc memory transaction to global memory + // However it accesses shared memory with a stride of items per thread, which leads to reduced performance if power + // of two is used for small types. Experiments shown that primes are the best choice for performance. + static constexpr int primes[] = {1, 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, + 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97}; + static constexpr uint ipt_num = primes[IptValueIndex]; + using generated_config = rocprim::adjacent_difference_config; void operator()(std::vector>& storage) { - storage.emplace_back( - std::make_unique< - device_adjacent_difference_benchmark>()); + if(ipt_num < max_items_per_thread_arg) + { + storage.emplace_back( + std::make_unique>()); + } } }; static void create(std::vector>& storage) { - static constexpr unsigned int min_items_per_thread = 1; - static constexpr unsigned int max_items_per_thread_arg - = TUNING_SHARED_MEMORY_MAX / (BlockSize * sizeof(T) * 2 + sizeof(T)); static constexpr unsigned int max_items_per_thread - = rocprim::Log2::VALUE - 1; + = rocprim::Log2::VALUE; static_for_each, create_ipt>(storage); } diff --git a/benchmark/benchmark_device_batch_memcpy.cpp b/benchmark/benchmark_device_batch_memcpy.cpp index fde851cdd..b22b13a87 100644 --- a/benchmark/benchmark_device_batch_memcpy.cpp +++ b/benchmark/benchmark_device_batch_memcpy.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,14 +20,18 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. 
-#include "benchmark/benchmark.h" #include "benchmark_utils.hpp" #include "cmdparser.hpp" -#include - +#include #include +// rocPRIM +#include +#include +#include +#include + #include #include #include @@ -84,6 +88,98 @@ std::vector shuffled_exclusive_scan(const std::vector& input, RandomGenera using offset_type = size_t; +template::type = 0> +void init_input(ContainerMemCpy& h_input_for_memcpy, + ContainerCopy& /*h_input_for_copy*/, + std::mt19937_64& rng, + offset_type total_num_bytes) +{ + std::independent_bits_engine bits_engine{rng}; + + const size_t num_ints = rocprim::detail::ceiling_div(total_num_bytes, sizeof(uint64_t)); + h_input_for_memcpy = std::vector(num_ints * sizeof(uint64_t)); + + // generate_n for uninitialized memory, pragmatically use placement-new, since there are no + // uint64_t objects alive yet in the storage. + std::for_each( + reinterpret_cast(h_input_for_memcpy.data()), + reinterpret_cast(h_input_for_memcpy.data() + num_ints * sizeof(uint64_t)), + [&bits_engine](uint64_t& elem) { ::new(&elem) uint64_t{bits_engine()}; }); +} + +template::type = 0> +void init_input(ContainerMemCpy& /*h_input_for_memcpy*/, + ContainerCopy& h_input_for_copy, + std::mt19937_64& rng, + byte_offset_type total_num_bytes) +{ + using value_type = typename ContainerCopy::value_type; + + std::independent_bits_engine bits_engine{rng}; + + const size_t num_ints = rocprim::detail::ceiling_div(total_num_bytes, sizeof(uint64_t)); + const size_t num_of_elements + = rocprim::detail::ceiling_div(num_ints * sizeof(uint64_t), sizeof(value_type)); + h_input_for_copy = std::vector(num_of_elements); + + // generate_n for uninitialized memory, pragmatically use placement-new, since there are no + // uint64_t objects alive yet in the storage. 
+ std::for_each(reinterpret_cast(h_input_for_copy.data()), + reinterpret_cast(h_input_for_copy.data()) + num_ints, + [&bits_engine](uint64_t& elem) { ::new(&elem) uint64_t{bits_engine()}; }); +} + +template::type = 0> +void batch_copy(void* temporary_storage, + size_t& storage_size, + InputBufferItType sources, + OutputBufferItType destinations, + BufferSizeItType sizes, + uint32_t num_copies, + hipStream_t stream) +{ + HIP_CHECK(rocprim::batch_memcpy(temporary_storage, + storage_size, + sources, + destinations, + sizes, + num_copies, + stream)); +} + +template::type = 0> +void batch_copy(void* temporary_storage, + size_t& storage_size, + InputBufferItType sources, + OutputBufferItType destinations, + BufferSizeItType sizes, + uint32_t num_copies, + hipStream_t stream) +{ + HIP_CHECK(rocprim::batch_copy(temporary_storage, + storage_size, + sources, + destinations, + sizes, + num_copies, + stream)); +} + template struct BatchMemcpyData { @@ -134,7 +230,7 @@ struct BatchMemcpyData } }; -template +template BatchMemcpyData prepare_data(const int32_t num_tlev_buffers = 1024, const int32_t num_wlev_buffers = 1024, const int32_t num_blev_buffers = 1024) @@ -175,16 +271,12 @@ BatchMemcpyData prepare_data(const int32_t num_tlev_b result.total_num_elements = std::accumulate(h_buffer_num_elements.begin(), h_buffer_num_elements.end(), size_t{0}); - // Generate data. 
- std::independent_bits_engine bits_engine{rng}; - - const size_t num_ints - = rocprim::detail::ceiling_div(result.total_num_bytes(), sizeof(uint64_t)); - auto h_input = std::make_unique(num_ints * sizeof(uint64_t)); - - std::for_each(reinterpret_cast(h_input.get()), - reinterpret_cast(h_input.get() + num_ints * sizeof(uint64_t)), - [&bits_engine](uint64_t& elem) { ::new(&elem) uint64_t{bits_engine()}; }); + std::vector h_input_for_memcpy; + std::vector h_input_for_copy; + init_input(h_input_for_memcpy, + h_input_for_copy, + rng, + result.total_num_elements * sizeof(ValueType)); HIP_CHECK(hipMalloc(&result.d_input, result.total_num_bytes())); HIP_CHECK(hipMalloc(&result.d_output, result.total_num_bytes())); @@ -228,8 +320,28 @@ BatchMemcpyData prepare_data(const int32_t num_tlev_b } // Prepare the batch memcpy. - HIP_CHECK( - hipMemcpy(result.d_input, h_input.get(), result.total_num_bytes(), hipMemcpyHostToDevice)); + if(IsMemCpy) + { + HIP_CHECK(hipMemcpy(result.d_input, + h_input_for_memcpy.data(), + result.total_num_bytes(), + hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(result.d_buffer_sizes, + h_buffer_num_bytes.data(), + h_buffer_num_bytes.size() * sizeof(BufferSizeType), + hipMemcpyHostToDevice)); + } + else + { + HIP_CHECK(hipMemcpy(result.d_input, + h_input_for_copy.data(), + result.total_num_bytes(), + hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(result.d_buffer_sizes, + h_buffer_num_elements.data(), + h_buffer_num_elements.size() * sizeof(BufferSizeType), + hipMemcpyHostToDevice)); + } HIP_CHECK(hipMemcpy(result.d_buffer_srcs, h_buffer_srcs.data(), h_buffer_srcs.size() * sizeof(ValueType*), @@ -238,15 +350,11 @@ BatchMemcpyData prepare_data(const int32_t num_tlev_b h_buffer_dsts.data(), h_buffer_dsts.size() * sizeof(ValueType*), hipMemcpyHostToDevice)); - HIP_CHECK(hipMemcpy(result.d_buffer_sizes, - h_buffer_num_bytes.data(), - h_buffer_num_bytes.size() * sizeof(BufferSizeType), - hipMemcpyHostToDevice)); return result; } -template +template void 
run_benchmark(benchmark::State& state, hipStream_t stream, const int32_t num_tlev_buffers = 1024, @@ -257,30 +365,31 @@ void run_benchmark(benchmark::State& state, size_t temp_storage_bytes = 0; BatchMemcpyData data; - HIP_CHECK(rocprim::batch_memcpy(nullptr, - temp_storage_bytes, - data.d_buffer_srcs, - data.d_buffer_dsts, - data.d_buffer_sizes, - num_buffers)); + batch_copy(nullptr, + temp_storage_bytes, + data.d_buffer_srcs, + data.d_buffer_dsts, + data.d_buffer_sizes, + num_buffers, + stream); void* d_temp_storage = nullptr; HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_bytes)); - data = prepare_data(num_tlev_buffers, - num_wlev_buffers, - num_blev_buffers); + data = prepare_data(num_tlev_buffers, + num_wlev_buffers, + num_blev_buffers); // Warm-up for(size_t i = 0; i < warmup_size; i++) { - HIP_CHECK(rocprim::batch_memcpy(d_temp_storage, - temp_storage_bytes, - data.d_buffer_srcs, - data.d_buffer_dsts, - data.d_buffer_sizes, - num_buffers, - stream)); + batch_copy(d_temp_storage, + temp_storage_bytes, + data.d_buffer_srcs, + data.d_buffer_dsts, + data.d_buffer_sizes, + num_buffers, + stream); } HIP_CHECK(hipDeviceSynchronize()); @@ -294,13 +403,13 @@ void run_benchmark(benchmark::State& state, // Record start event HIP_CHECK(hipEventRecord(start, stream)); - HIP_CHECK(rocprim::batch_memcpy(d_temp_storage, - temp_storage_bytes, - data.d_buffer_srcs, - data.d_buffer_dsts, - data.d_buffer_sizes, - num_buffers, - stream)); + batch_copy(d_temp_storage, + temp_storage_bytes, + data.d_buffer_srcs, + data.d_buffer_dsts, + data.d_buffer_sizes, + num_buffers, + stream); // Record stop event and wait until it completes HIP_CHECK(hipEventRecord(stop, stream)); @@ -313,6 +422,9 @@ void run_benchmark(benchmark::State& state, state.SetBytesProcessed(state.iterations() * data.total_num_bytes()); state.SetItemsProcessed(state.iterations() * data.total_num_elements); + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + 
HIP_CHECK(hipFree(d_temp_storage)); } @@ -356,7 +468,7 @@ __launch_bounds__(BlockSize) __global__ } } -template +template void run_naive_benchmark(benchmark::State& state, hipStream_t stream, const int32_t num_tlev_buffers = 1024, @@ -365,9 +477,9 @@ void run_naive_benchmark(benchmark::State& state, { const size_t num_buffers = num_tlev_buffers + num_wlev_buffers + num_blev_buffers; - const auto data = prepare_data(num_tlev_buffers, - num_wlev_buffers, - num_blev_buffers); + const auto data = prepare_data(num_tlev_buffers, + num_wlev_buffers, + num_blev_buffers); // Warm-up for(size_t i = 0; i < warmup_size; i++) @@ -404,28 +516,28 @@ void run_naive_benchmark(benchmark::State& state, } state.SetBytesProcessed(state.iterations() * data.total_num_bytes()); state.SetItemsProcessed(state.iterations() * data.total_num_elements); + + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); } - #define CREATE_NAIVE_BENCHMARK(item_size, \ - item_alignment, \ - size_type, \ - num_tlev, \ - num_wlev, \ - num_blev) \ - benchmark::RegisterBenchmark( \ - bench_naming::format_name( \ - "{lvl:device,item_size:" #item_size ",item_alignment:" #item_alignment \ - ",size_type:" #size_type ",algo:naive_memcpy,num_tlev:" #num_tlev \ - ",num_wlev:" #num_wlev ",num_blev:" #num_blev ",cfg:default_config}") \ - .c_str(), \ - [=](benchmark::State& state) \ - { \ - run_naive_benchmark, size_type>( \ - state, \ - stream, \ - num_tlev, \ - num_wlev, \ - num_blev); \ + #define CREATE_NAIVE_BENCHMARK(item_size, \ + item_alignment, \ + size_type, \ + num_tlev, \ + num_wlev, \ + num_blev) \ + benchmark::RegisterBenchmark( \ + bench_naming::format_name( \ + "{lvl:device,item_size:" #item_size ",item_alignment:" #item_alignment \ + ",size_type:" #size_type ",algo:naive_memcpy,num_tlev:" #num_tlev \ + ",num_wlev:" #num_wlev ",num_blev:" #num_blev ",cfg:default_config}") \ + .c_str(), \ + [=](benchmark::State& state) \ + { \ + run_naive_benchmark, \ + size_type, \ + true>(state, stream, 
num_tlev, num_wlev, num_blev); \ }) #endif @@ -439,11 +551,18 @@ void run_naive_benchmark(benchmark::State& state, .c_str(), \ [=](benchmark::State& state) \ { \ - run_benchmark, size_type>(state, \ - stream, \ - num_tlev, \ - num_wlev, \ - num_blev); \ + run_benchmark, size_type, true>( \ + state, \ + stream, \ + num_tlev, \ + num_wlev, \ + num_blev); \ + run_benchmark, size_type, false>( \ + state, \ + stream, \ + num_tlev, \ + num_wlev, \ + num_blev); \ }) #ifndef BENCHMARK_BATCH_MEMCPY_NAIVE diff --git a/benchmark/benchmark_device_binary_search.cpp b/benchmark/benchmark_device_binary_search.cpp index 77353b6a1..e9a859c27 100644 --- a/benchmark/benchmark_device_binary_search.cpp +++ b/benchmark/benchmark_device_binary_search.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2019-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2019-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,27 +20,28 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. 
-#include -#include -#include -#include -#include -#include +#include "benchmark_device_binary_search.parallel.hpp" -// Google Benchmark -#include "benchmark/benchmark.h" +#include "benchmark_utils.hpp" // CmdParser #include "cmdparser.hpp" -#include "benchmark_utils.hpp" + +// Google Benchmark +#include // HIP API #include // rocPRIM -#include +#include -#include "benchmark_device_binary_search.parallel.hpp" -#include "rocprim/device/config_types.hpp" +#include +#include +#include +#include + +#include +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; diff --git a/benchmark/benchmark_device_binary_search.parallel.hpp b/benchmark/benchmark_device_binary_search.parallel.hpp index cb9ff3c79..384349ddf 100644 --- a/benchmark/benchmark_device_binary_search.parallel.hpp +++ b/benchmark/benchmark_device_binary_search.parallel.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -23,16 +23,19 @@ #ifndef ROCPRIM_BENCHMARK_BINARY_SEARCH_PARALLEL_HPP_ #define ROCPRIM_BENCHMARK_BINARY_SEARCH_PARALLEL_HPP_ -#include -#include -#include - #include "benchmark_utils.hpp" -#include "rocprim/device/config_types.hpp" -#include "rocprim/device/detail/device_config_helper.hpp" + +#include +#include + #include + #include -#include + +#include +#include + +#include struct binary_search_subalgorithm { diff --git a/benchmark/benchmark_device_histogram.cpp b/benchmark/benchmark_device_histogram.cpp index 2218c4556..079471c46 100644 --- a/benchmark/benchmark_device_histogram.cpp +++ b/benchmark/benchmark_device_histogram.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,25 +20,26 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#include -#include -#include -#include -#include +#include "benchmark_device_histogram.parallel.hpp" +#include "benchmark_utils.hpp" -// Google Benchmark -#include "benchmark/benchmark.h" // CmdParser #include "cmdparser.hpp" +// Google Benchmark +#include + // HIP API #include // rocPRIM -#include +#include -#include "benchmark_device_histogram.parallel.hpp" -#include "benchmark_utils.hpp" +#include +#include +#include +#include +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; diff --git a/benchmark/benchmark_device_histogram.parallel.hpp b/benchmark/benchmark_device_histogram.parallel.hpp index 146c1cc40..1f18d4503 100644 --- a/benchmark/benchmark_device_histogram.parallel.hpp +++ b/benchmark/benchmark_device_histogram.parallel.hpp @@ -23,11 +23,7 @@ #ifndef ROCPRIM_BENCHMARK_DEVICE_HISTOGRAM_PARALLEL_HPP_ #define ROCPRIM_BENCHMARK_DEVICE_HISTOGRAM_PARALLEL_HPP_ -#include -#include -#include -#include -#include +#include "benchmark_utils.hpp" // Google Benchmark #include @@ -36,10 +32,15 @@ #include // rocPRIM -#include +#include #include -#include "benchmark_utils.hpp" +#include +#include +#include + +#include +#include template std::vector generate(size_t size, int entropy_reduction, int lower_level, int upper_level) diff --git a/benchmark/benchmark_device_memory.cpp b/benchmark/benchmark_device_memory.cpp index b4ed1bc1c..47e999e90 100644 --- a/benchmark/benchmark_device_memory.cpp +++ b/benchmark/benchmark_device_memory.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2018-2022 Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,19 +20,21 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#include -#include -#include -#include - -// Google Benchmark -#include "benchmark/benchmark.h" +#include "benchmark_utils.hpp" // CmdParser #include "cmdparser.hpp" + +// Google Benchmark +#include // rocPRIM -#include +#include +#include -#include "benchmark_utils.hpp" +#include +#include + +#include +#include enum memory_operation_method { diff --git a/benchmark/benchmark_device_merge.cpp b/benchmark/benchmark_device_merge.cpp index b218ac930..710ddf909 100644 --- a/benchmark/benchmark_device_merge.cpp +++ b/benchmark/benchmark_device_merge.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,24 +20,26 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. 
-#include -#include -#include -#include -#include -#include - -// Google Benchmark -#include "benchmark/benchmark.h" +#include "benchmark_utils.hpp" // CmdParser #include "cmdparser.hpp" -#include "benchmark_utils.hpp" + +// Google Benchmark +#include // HIP API #include // rocPRIM -#include +#include + +#include +#include +#include +#include + +#include +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; diff --git a/benchmark/benchmark_device_merge_sort.cpp b/benchmark/benchmark_device_merge_sort.cpp index 6ea034ffc..9e0bc0ed0 100644 --- a/benchmark/benchmark_device_merge_sort.cpp +++ b/benchmark/benchmark_device_merge_sort.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,8 +20,11 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#include -#include +#include "benchmark_device_merge_sort.hpp" +#include "benchmark_utils.hpp" + +// CmdParser +#include "cmdparser.hpp" // Google Benchmark #include @@ -29,14 +32,8 @@ // HIP API #include -// rocPRIM -#include - -// CmdParser -#include "cmdparser.hpp" - -#include "benchmark_device_merge_sort.hpp" -#include "benchmark_utils.hpp" +#include +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; diff --git a/benchmark/benchmark_device_merge_sort.hpp b/benchmark/benchmark_device_merge_sort.hpp index 24d31b053..3d603898f 100644 --- a/benchmark/benchmark_device_merge_sort.hpp +++ b/benchmark/benchmark_device_merge_sort.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -23,9 +23,7 @@ #ifndef ROCPRIM_BENCHMARK_DEVICE_MERGE_SORT_PARALLEL_HPP_ #define ROCPRIM_BENCHMARK_DEVICE_MERGE_SORT_PARALLEL_HPP_ -#include -#include -#include +#include "benchmark_utils.hpp" // Google Benchmark #include @@ -34,9 +32,12 @@ #include // rocPRIM -#include +#include -#include "benchmark_utils.hpp" +#include +#include + +#include namespace rp = rocprim; diff --git a/benchmark/benchmark_device_merge_sort_block_merge.cpp b/benchmark/benchmark_device_merge_sort_block_merge.cpp index 1ec2679e9..fadb43f50 100644 --- a/benchmark/benchmark_device_merge_sort_block_merge.cpp +++ b/benchmark/benchmark_device_merge_sort_block_merge.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,8 +20,11 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. 
-#include -#include +#include "benchmark_device_merge_sort_block_merge.parallel.hpp" +#include "benchmark_utils.hpp" + +// CmdParser +#include "cmdparser.hpp" // Google Benchmark #include @@ -29,14 +32,9 @@ // HIP API #include -// rocPRIM -#include - -// CmdParser -#include "cmdparser.hpp" +#include -#include "benchmark_device_merge_sort_block_merge.parallel.hpp" -#include "benchmark_utils.hpp" +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; diff --git a/benchmark/benchmark_device_merge_sort_block_merge.parallel.hpp b/benchmark/benchmark_device_merge_sort_block_merge.parallel.hpp index 98348dac8..5c51249a1 100644 --- a/benchmark/benchmark_device_merge_sort_block_merge.parallel.hpp +++ b/benchmark/benchmark_device_merge_sort_block_merge.parallel.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -23,9 +23,7 @@ #ifndef ROCPRIM_BENCHMARK_DETAIL_BENCHMARK_DEVICE_MERGE_SORT_BLOCK_MERGE_PARALLEL_HPP_ #define ROCPRIM_BENCHMARK_DETAIL_BENCHMARK_DEVICE_MERGE_SORT_BLOCK_MERGE_PARALLEL_HPP_ -#include -#include -#include +#include "benchmark_utils.hpp" // Google Benchmark #include @@ -34,9 +32,12 @@ #include // rocPRIM -#include +#include -#include "benchmark_utils.hpp" +#include +#include + +#include namespace rp = rocprim; diff --git a/benchmark/benchmark_device_merge_sort_block_sort.cpp b/benchmark/benchmark_device_merge_sort_block_sort.cpp index 605b138a7..bb0c72165 100644 --- a/benchmark/benchmark_device_merge_sort_block_sort.cpp +++ b/benchmark/benchmark_device_merge_sort_block_sort.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,8 +20,11 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#include -#include +#include "benchmark_device_merge_sort_block_sort.parallel.hpp" +#include "benchmark_utils.hpp" + +// CmdParser +#include "cmdparser.hpp" // Google Benchmark #include @@ -29,14 +32,9 @@ // HIP API #include -// rocPRIM -#include - -// CmdParser -#include "cmdparser.hpp" +#include -#include "benchmark_device_merge_sort_block_sort.parallel.hpp" -#include "benchmark_utils.hpp" +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; diff --git a/benchmark/benchmark_device_merge_sort_block_sort.parallel.hpp b/benchmark/benchmark_device_merge_sort_block_sort.parallel.hpp index 57ca12ae3..05b7cedd8 100644 --- a/benchmark/benchmark_device_merge_sort_block_sort.parallel.hpp +++ b/benchmark/benchmark_device_merge_sort_block_sort.parallel.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -23,9 +23,7 @@ #ifndef ROCPRIM_BENCHMARK_DETAIL_BENCHMARK_DEVICE_MERGE_SORT_BLOCK_SORT_PARALLEL_HPP_ #define ROCPRIM_BENCHMARK_DETAIL_BENCHMARK_DEVICE_MERGE_SORT_BLOCK_SORT_PARALLEL_HPP_ -#include -#include -#include +#include "benchmark_utils.hpp" // Google Benchmark #include @@ -34,9 +32,12 @@ #include // rocPRIM -#include +#include -#include "benchmark_utils.hpp" +#include +#include + +#include namespace rp = rocprim; diff --git a/benchmark/benchmark_device_partition.cpp b/benchmark/benchmark_device_partition.cpp index 37af6b875..dca9d8db2 100644 --- a/benchmark/benchmark_device_partition.cpp +++ b/benchmark/benchmark_device_partition.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,25 +20,28 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. 
-#include -#include -#include -#include -#include -#include -#include -#include - -// Google Benchmark -#include "benchmark/benchmark.h" +#include "benchmark_utils.hpp" // CmdParser #include "cmdparser.hpp" -#include "benchmark_utils.hpp" + +// Google Benchmark +#include // HIP API #include -#include +// rocPRIM +#include + +#include +#include +#include +#include +#include +#include + +#include +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; diff --git a/benchmark/benchmark_device_radix_sort.cpp b/benchmark/benchmark_device_radix_sort.cpp index f66870cc8..e30e34363 100644 --- a/benchmark/benchmark_device_radix_sort.cpp +++ b/benchmark/benchmark_device_radix_sort.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,8 +20,10 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#include -#include +#include "benchmark_device_radix_sort.hpp" +#include "benchmark_utils.hpp" +// CmdParser +#include "cmdparser.hpp" // Google Benchmark #include @@ -29,11 +31,9 @@ // HIP API #include -// CmdParser -#include "cmdparser.hpp" +#include -#include "benchmark_device_radix_sort.hpp" -#include "benchmark_utils.hpp" +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; diff --git a/benchmark/benchmark_device_radix_sort.hpp b/benchmark/benchmark_device_radix_sort.hpp index b6f12faca..648bfae6f 100644 --- a/benchmark/benchmark_device_radix_sort.hpp +++ b/benchmark/benchmark_device_radix_sort.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -23,9 +23,7 @@ #ifndef ROCPRIM_BENCHMARK_DEVICE_RADIX_SORT_PARALLEL_HPP_ #define ROCPRIM_BENCHMARK_DEVICE_RADIX_SORT_PARALLEL_HPP_ -#include -#include -#include +#include "benchmark_utils.hpp" // Google Benchmark #include @@ -34,9 +32,13 @@ #include // rocPRIM -#include +#include -#include "benchmark_utils.hpp" +#include +#include +#include + +#include namespace rp = rocprim; @@ -77,8 +79,8 @@ struct device_radix_sort_benchmark : public config_autotune_interface // keys benchmark template - auto do_run(benchmark::State& state, size_t size, const hipStream_t stream) const -> - typename std::enable_if::value, void>::type + auto do_run(benchmark::State& state, size_t size, const hipStream_t stream) const + -> std::enable_if_t::value, void> { auto keys_input = generate_keys(size); @@ -96,15 +98,14 @@ struct device_radix_sort_benchmark : public config_autotune_interface void* d_temporary_storage = nullptr; size_t temporary_storage_bytes = 0; - HIP_CHECK(rp::radix_sort_keys(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - size, - 0, - sizeof(key_type) * 8, - stream, - false)); + HIP_CHECK(invoke_radix_sort(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + static_cast(nullptr), + static_cast(nullptr), + size, + stream)); HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); HIP_CHECK(hipDeviceSynchronize()); @@ -112,15 +113,14 @@ struct device_radix_sort_benchmark : public config_autotune_interface // Warm-up for(size_t i = 0; i < warmup_size; i++) { - HIP_CHECK(rp::radix_sort_keys(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - size, - 0, - sizeof(key_type) * 8, - stream, - false)); + HIP_CHECK(invoke_radix_sort(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, 
+ static_cast(nullptr), + static_cast(nullptr), + size, + stream)); } HIP_CHECK(hipDeviceSynchronize()); @@ -136,15 +136,14 @@ struct device_radix_sort_benchmark : public config_autotune_interface for(size_t i = 0; i < batch_size; i++) { - HIP_CHECK(rp::radix_sort_keys(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - size, - 0, - sizeof(key_type) * 8, - stream, - false)); + HIP_CHECK(invoke_radix_sort(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + static_cast(nullptr), + static_cast(nullptr), + size, + stream)); } // Record stop event and wait until it completes @@ -170,8 +169,8 @@ struct device_radix_sort_benchmark : public config_autotune_interface // pairs benchmark template - auto do_run(benchmark::State& state, size_t size, const hipStream_t stream) const -> - typename std::enable_if::value, void>::type + auto do_run(benchmark::State& state, size_t size, const hipStream_t stream) const + -> std::enable_if_t::value, void> { auto keys_input = generate_keys(size); @@ -204,17 +203,14 @@ struct device_radix_sort_benchmark : public config_autotune_interface void* d_temporary_storage = nullptr; size_t temporary_storage_bytes = 0; - HIP_CHECK(rp::radix_sort_pairs(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - d_values_input, - d_values_output, - size, - 0, - sizeof(key_type) * 8, - stream, - false)); + HIP_CHECK(invoke_radix_sort(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + size, + stream)); HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); HIP_CHECK(hipDeviceSynchronize()); @@ -222,17 +218,14 @@ struct device_radix_sort_benchmark : public config_autotune_interface // Warm-up for(size_t i = 0; i < warmup_size; i++) { - HIP_CHECK(rp::radix_sort_pairs(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - d_values_input, - d_values_output, - 
size, - 0, - sizeof(key_type) * 8, - stream, - false)); + HIP_CHECK(invoke_radix_sort(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + size, + stream)); } HIP_CHECK(hipDeviceSynchronize()); @@ -248,17 +241,14 @@ struct device_radix_sort_benchmark : public config_autotune_interface for(size_t i = 0; i < batch_size; i++) { - HIP_CHECK(rp::radix_sort_pairs(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - d_values_input, - d_values_output, - size, - 0, - sizeof(key_type) * 8, - stream, - false)); + HIP_CHECK(invoke_radix_sort(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + size, + stream)); } // Record stop event and wait until it completes @@ -289,6 +279,101 @@ struct device_radix_sort_benchmark : public config_autotune_interface { do_run(state, size, stream); } + +private: + template + static auto invoke_radix_sort(void* d_temporary_storage, + size_t& temp_storage_bytes, + K* keys_input, + K* keys_output, + V* values_input, + V* values_output, + size_t size, + hipStream_t stream) + -> std::enable_if_t::value && std::is_same::value, + hipError_t> + { + (void)values_input; + (void)values_output; + return rp::radix_sort_keys(d_temporary_storage, + temp_storage_bytes, + keys_input, + keys_output, + size, + 0, + sizeof(K) * 8, + stream); + } + + template + static auto invoke_radix_sort(void* d_temporary_storage, + size_t& temp_storage_bytes, + K* keys_input, + K* keys_output, + V* values_input, + V* values_output, + size_t size, + hipStream_t stream) + -> std::enable_if_t::value && std::is_same::value, + hipError_t> + { + (void)values_input; + (void)values_output; + return rp::radix_sort_keys(d_temporary_storage, + temp_storage_bytes, + keys_input, + keys_output, + size, + custom_type_decomposer{}, + stream); + } + + template + static auto invoke_radix_sort(void* d_temporary_storage, + size_t& 
temp_storage_bytes, + K* keys_input, + K* keys_output, + V* values_input, + V* values_output, + size_t size, + hipStream_t stream) + -> std::enable_if_t::value && !std::is_same::value, + hipError_t> + { + return rp::radix_sort_pairs(d_temporary_storage, + temp_storage_bytes, + keys_input, + keys_output, + values_input, + values_output, + size, + 0, + sizeof(K) * 8, + stream); + } + + template + static auto invoke_radix_sort(void* d_temporary_storage, + size_t& temp_storage_bytes, + K* keys_input, + K* keys_output, + V* values_input, + V* values_output, + size_t size, + hipStream_t stream) + -> std::enable_if_t::value && !std::is_same::value, + hipError_t> + { + return rp::radix_sort_pairs(d_temporary_storage, + temp_storage_bytes, + keys_input, + keys_output, + values_input, + values_output, + size, + custom_type_decomposer{}, + stream); + } }; #define CREATE_RADIX_SORT_BENCHMARK(...) \ @@ -301,6 +386,7 @@ inline void add_sort_keys_benchmarks(std::vector; CREATE_RADIX_SORT_BENCHMARK(int) CREATE_RADIX_SORT_BENCHMARK(float) CREATE_RADIX_SORT_BENCHMARK(long long) @@ -308,6 +394,7 @@ inline void add_sort_keys_benchmarks(std::vector& benchmarks, @@ -316,6 +403,7 @@ inline void add_sort_pairs_benchmarks(std::vector; using custom_double2 = custom_type; + using custom_key = custom_type; CREATE_RADIX_SORT_BENCHMARK(int, float) CREATE_RADIX_SORT_BENCHMARK(int, double) @@ -333,6 +421,7 @@ inline void add_sort_pairs_benchmarks(std::vector -#include +// CmdParser +#include "cmdparser.hpp" // Google Benchmark #include @@ -29,15 +29,13 @@ // HIP API #include -// rocPRIM -#include - -// CmdParser -#include "cmdparser.hpp" - #include "benchmark_device_radix_sort_block_sort.parallel.hpp" #include "benchmark_utils.hpp" +#include + +#include + #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; #endif diff --git a/benchmark/benchmark_device_radix_sort_block_sort.parallel.hpp b/benchmark/benchmark_device_radix_sort_block_sort.parallel.hpp index f030b9d7a..21755a0e4 100644 --- 
a/benchmark/benchmark_device_radix_sort_block_sort.parallel.hpp +++ b/benchmark/benchmark_device_radix_sort_block_sort.parallel.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -23,9 +23,7 @@ #ifndef ROCPRIM_BENCHMARK_DETAIL_BENCHMARK_DEVICE_RADIX_SORT_BLOCK_SORT_PARALLEL_HPP_ #define ROCPRIM_BENCHMARK_DETAIL_BENCHMARK_DEVICE_RADIX_SORT_BLOCK_SORT_PARALLEL_HPP_ -#include -#include -#include +#include "benchmark_utils.hpp" // Google Benchmark #include @@ -34,9 +32,12 @@ #include // rocPRIM -#include +#include -#include "benchmark_utils.hpp" +#include +#include + +#include namespace rp = rocprim; @@ -112,6 +113,7 @@ struct device_radix_sort_block_sort_benchmark : public config_autotune_interface values_ptr, size, items_per_block, + rp::identity_decomposer{}, 0, sizeof(key_type) * 8, stream, @@ -131,16 +133,18 @@ struct device_radix_sort_block_sort_benchmark : public config_autotune_interface for(size_t i = 0; i < batch_size; i++) { - HIP_CHECK((rp::detail::radix_sort_block_sort(d_keys_input, - d_keys_output, - values_ptr, - values_ptr, - size, - items_per_block, - 0, - sizeof(key_type) * 8, - stream, - false))); + HIP_CHECK( + (rp::detail::radix_sort_block_sort(d_keys_input, + d_keys_output, + values_ptr, + values_ptr, + size, + items_per_block, + rp::identity_decomposer{}, + 0, + sizeof(key_type) * 8, + stream, + false))); } // Record stop event and wait until it completes @@ -223,6 +227,7 @@ struct device_radix_sort_block_sort_benchmark : public config_autotune_interface d_values_output, size, items_per_block, + rp::identity_decomposer{}, 0, sizeof(key_type) * 8, stream, @@ -242,16 +247,18 @@ struct device_radix_sort_block_sort_benchmark : public 
config_autotune_interface for(size_t i = 0; i < batch_size; i++) { - HIP_CHECK((rp::detail::radix_sort_block_sort(d_keys_input, - d_keys_output, - d_values_input, - d_values_output, - size, - items_per_block, - 0, - sizeof(key_type) * 8, - stream, - false))); + HIP_CHECK( + (rp::detail::radix_sort_block_sort(d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + size, + items_per_block, + rp::identity_decomposer{}, + 0, + sizeof(key_type) * 8, + stream, + false))); } // Record stop event and wait until it completes diff --git a/benchmark/benchmark_device_radix_sort_onesweep.cpp b/benchmark/benchmark_device_radix_sort_onesweep.cpp index 19666b033..85506b257 100644 --- a/benchmark/benchmark_device_radix_sort_onesweep.cpp +++ b/benchmark/benchmark_device_radix_sort_onesweep.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,8 +20,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. 
-#include -#include +// CmdParser +#include "cmdparser.hpp" // Google Benchmark #include @@ -29,12 +29,13 @@ // HIP API #include -// CmdParser -#include "cmdparser.hpp" - #include "benchmark_device_radix_sort_onesweep.parallel.hpp" #include "benchmark_utils.hpp" +#include + +#include + #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; #endif diff --git a/benchmark/benchmark_device_radix_sort_onesweep.parallel.hpp b/benchmark/benchmark_device_radix_sort_onesweep.parallel.hpp index 02320e17b..5cf89aeea 100644 --- a/benchmark/benchmark_device_radix_sort_onesweep.parallel.hpp +++ b/benchmark/benchmark_device_radix_sort_onesweep.parallel.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -23,9 +23,7 @@ #ifndef ROCPRIM_BENCHMARK_DEVICE_RADIX_SORT_ONESWEEP_PARALLEL_HPP_ #define ROCPRIM_BENCHMARK_DEVICE_RADIX_SORT_ONESWEEP_PARALLEL_HPP_ -#include -#include -#include +#include "benchmark_utils.hpp" // Google Benchmark #include @@ -34,9 +32,12 @@ #include // rocPRIM -#include +#include -#include "benchmark_utils.hpp" +#include +#include + +#include namespace rp = rocprim; @@ -141,6 +142,7 @@ struct device_radix_sort_onesweep_benchmark : public config_autotune_interface d_values_ptr, size, is_result_in_output, + rp::identity_decomposer{}, 0, sizeof(key_type) * 8, stream, @@ -152,20 +154,22 @@ struct device_radix_sort_onesweep_benchmark : public config_autotune_interface // Warm-up for(size_t i = 0; i < warmup_size; i++) { - HIP_CHECK((rp::detail::radix_sort_onesweep_impl(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - nullptr, - d_keys_output, - d_values_ptr, - nullptr, - d_values_ptr, - size, - is_result_in_output, - 0, - sizeof(key_type) * 
8, - stream, - false))); + HIP_CHECK( + (rp::detail::radix_sort_onesweep_impl(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + nullptr, + d_keys_output, + d_values_ptr, + nullptr, + d_values_ptr, + size, + is_result_in_output, + rp::identity_decomposer{}, + 0, + sizeof(key_type) * 8, + stream, + false))); } HIP_CHECK(hipDeviceSynchronize()); @@ -192,6 +196,7 @@ struct device_radix_sort_onesweep_benchmark : public config_autotune_interface d_values_ptr, size, is_result_in_output, + rp::identity_decomposer{}, 0, sizeof(key_type) * 8, stream, @@ -267,6 +272,7 @@ struct device_radix_sort_onesweep_benchmark : public config_autotune_interface d_values_output, size, is_result_in_output, + rp::identity_decomposer{}, 0, sizeof(key_type) * 8, stream, @@ -278,20 +284,22 @@ struct device_radix_sort_onesweep_benchmark : public config_autotune_interface // Warm-up for(size_t i = 0; i < warmup_size; i++) { - HIP_CHECK((rp::detail::radix_sort_onesweep_impl(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - nullptr, - d_keys_output, - d_values_input, - nullptr, - d_values_output, - size, - is_result_in_output, - 0, - sizeof(key_type) * 8, - stream, - false))); + HIP_CHECK( + (rp::detail::radix_sort_onesweep_impl(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + nullptr, + d_keys_output, + d_values_input, + nullptr, + d_values_output, + size, + is_result_in_output, + rp::identity_decomposer{}, + 0, + sizeof(key_type) * 8, + stream, + false))); } HIP_CHECK(hipDeviceSynchronize()); @@ -318,6 +326,7 @@ struct device_radix_sort_onesweep_benchmark : public config_autotune_interface d_values_output, size, is_result_in_output, + rp::identity_decomposer{}, 0, sizeof(key_type) * 8, stream, @@ -373,7 +382,8 @@ struct device_radix_sort_onesweep_benchmark_generator ItemsPerThread, RadixBits, false, - RadixRankAlgorithm>::storage_type; + RadixRankAlgorithm, + rp::identity_decomposer>::storage_type; return sizeof(sharedmem_storage) < 
TUNING_SHARED_MEMORY_MAX; } diff --git a/benchmark/benchmark_device_reduce.cpp b/benchmark/benchmark_device_reduce.cpp index 0a3458a6e..24bc2c069 100644 --- a/benchmark/benchmark_device_reduce.cpp +++ b/benchmark/benchmark_device_reduce.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,8 +20,10 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#include -#include +#include "benchmark_device_reduce.parallel.hpp" +#include "benchmark_utils.hpp" +// CmdParser +#include "cmdparser.hpp" // Google Benchmark #include @@ -29,14 +31,9 @@ // HIP API #include -// rocPRIM HIP API -#include - -// CmdParser -#include "cmdparser.hpp" +#include -#include "benchmark_utils.hpp" -#include "benchmark_device_reduce.parallel.hpp" +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 128; diff --git a/benchmark/benchmark_device_reduce.parallel.hpp b/benchmark/benchmark_device_reduce.parallel.hpp index 0e57367c0..6b128b9ef 100644 --- a/benchmark/benchmark_device_reduce.parallel.hpp +++ b/benchmark/benchmark_device_reduce.parallel.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -23,9 +23,7 @@ #ifndef ROCPRIM_BENCHMARK_DEVICE_REDUCE_PARALLEL_HPP_ #define ROCPRIM_BENCHMARK_DEVICE_REDUCE_PARALLEL_HPP_ -#include -#include -#include +#include "benchmark_utils.hpp" // Google Benchmark #include @@ -34,9 +32,13 @@ #include // rocPRIM HIP API -#include +#include +#include -#include "benchmark_utils.hpp" +#include +#include + +#include constexpr const char* get_reduce_method_name(rocprim::block_reduce_algorithm alg) { diff --git a/benchmark/benchmark_device_reduce_by_key.cpp b/benchmark/benchmark_device_reduce_by_key.cpp index 285492eca..242bd4e2b 100644 --- a/benchmark/benchmark_device_reduce_by_key.cpp +++ b/benchmark/benchmark_device_reduce_by_key.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,23 +20,24 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. 
-#include -#include -#include -#include -#include - -// Google Benchmark -#include "benchmark/benchmark.h" +#include "benchmark_utils.hpp" // CmdParser #include "cmdparser.hpp" -#include "benchmark_utils.hpp" + +// Google Benchmark +#include // HIP API #include // rocPRIM -#include +#include + +#include +#include +#include +#include +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; diff --git a/benchmark/benchmark_device_run_length_encode.cpp b/benchmark/benchmark_device_run_length_encode.cpp index f0c4753f9..473d2f871 100644 --- a/benchmark/benchmark_device_run_length_encode.cpp +++ b/benchmark/benchmark_device_run_length_encode.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,23 +20,24 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#include -#include -#include -#include -#include - -// Google Benchmark -#include "benchmark/benchmark.h" +#include "benchmark_utils.hpp" // CmdParser #include "cmdparser.hpp" -#include "benchmark_utils.hpp" + +// Google Benchmark +#include // HIP API #include // rocPRIM -#include +#include + +#include +#include +#include +#include +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; diff --git a/benchmark/benchmark_device_scan.cpp b/benchmark/benchmark_device_scan.cpp index 133a8ef86..2486706bf 100644 --- a/benchmark/benchmark_device_scan.cpp +++ b/benchmark/benchmark_device_scan.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,8 +20,10 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#include -#include +#include "benchmark_device_scan.parallel.hpp" +#include "benchmark_utils.hpp" +// CmdParser +#include "cmdparser.hpp" // Google Benchmark #include @@ -29,14 +31,9 @@ // HIP API #include -// rocPRIM -#include - -// CmdParser -#include "cmdparser.hpp" +#include -#include "benchmark_device_scan.parallel.hpp" -#include "benchmark_utils.hpp" +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; diff --git a/benchmark/benchmark_device_scan.parallel.hpp b/benchmark/benchmark_device_scan.parallel.hpp index 4f976d4c4..f31b02af6 100644 --- a/benchmark/benchmark_device_scan.parallel.hpp +++ b/benchmark/benchmark_device_scan.parallel.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -23,9 +23,7 @@ #ifndef ROCPRIM_BENCHMARK_DEVICE_SCAN_PARALLEL_HPP_ #define ROCPRIM_BENCHMARK_DEVICE_SCAN_PARALLEL_HPP_ -#include -#include -#include +#include "benchmark_utils.hpp" // Google Benchmark #include @@ -34,9 +32,13 @@ #include // rocPRIM -#include +#include +#include -#include "benchmark_utils.hpp" +#include +#include + +#include template std::string config_name() diff --git a/benchmark/benchmark_device_scan_by_key.cpp b/benchmark/benchmark_device_scan_by_key.cpp index 18cb25c03..7e8cfface 100644 --- a/benchmark/benchmark_device_scan_by_key.cpp +++ b/benchmark/benchmark_device_scan_by_key.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,8 +20,10 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. 
-#include -#include +#include "benchmark_device_scan_by_key.parallel.hpp" +#include "benchmark_utils.hpp" +// CmdParser +#include "cmdparser.hpp" // Google Benchmark #include @@ -29,14 +31,9 @@ // HIP API #include -// rocPRIM -#include - -// CmdParser -#include "cmdparser.hpp" +#include -#include "benchmark_device_scan_by_key.parallel.hpp" -#include "benchmark_utils.hpp" +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; diff --git a/benchmark/benchmark_device_scan_by_key.parallel.hpp b/benchmark/benchmark_device_scan_by_key.parallel.hpp index e4748901a..f97c61adf 100644 --- a/benchmark/benchmark_device_scan_by_key.parallel.hpp +++ b/benchmark/benchmark_device_scan_by_key.parallel.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -23,9 +23,7 @@ #ifndef ROCPRIM_BENCHMARK_DEVICE_SCAN_BY_KEY_PARALLEL_HPP_ #define ROCPRIM_BENCHMARK_DEVICE_SCAN_BY_KEY_PARALLEL_HPP_ -#include -#include -#include +#include "benchmark_utils.hpp" // Google Benchmark #include @@ -34,9 +32,13 @@ #include // rocPRIM -#include +#include +#include -#include "benchmark_utils.hpp" +#include +#include + +#include template std::string config_name() diff --git a/benchmark/benchmark_device_segmented_radix_sort_keys.cpp b/benchmark/benchmark_device_segmented_radix_sort_keys.cpp index dfd7e14e9..828a1d918 100644 --- a/benchmark/benchmark_device_segmented_radix_sort_keys.cpp +++ b/benchmark/benchmark_device_segmented_radix_sort_keys.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,23 +20,24 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#include -#include -#include -#include -#include - -// Google Benchmark -#include "benchmark/benchmark.h" -// CmdParser #include "benchmark_utils.hpp" +// CmdParser #include "cmdparser.hpp" +// Google Benchmark +#include + // HIP API #include // rocPRIM -#include +#include + +#include +#include +#include +#include +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; diff --git a/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.hpp b/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.hpp index 4227d223a..d9e471b9f 100644 --- a/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.hpp +++ b/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -23,9 +23,7 @@ #ifndef ROCPRIM_BENCHMARK_DEVICE_SEGMENTED_RADIX_SORT_KEYS_PARALLEL_HPP_ #define ROCPRIM_BENCHMARK_DEVICE_SEGMENTED_RADIX_SORT_KEYS_PARALLEL_HPP_ -#include -#include -#include +#include "benchmark_utils.hpp" // Google Benchmark #include @@ -34,9 +32,13 @@ #include // rocPRIM -#include +#include +#include -#include "benchmark_utils.hpp" +#include +#include + +#include template std::string warp_sort_config_name(T const& warp_sort_config) diff --git a/benchmark/benchmark_device_segmented_radix_sort_pairs.cpp b/benchmark/benchmark_device_segmented_radix_sort_pairs.cpp index 4aba2b436..03d63724a 100644 --- a/benchmark/benchmark_device_segmented_radix_sort_pairs.cpp +++ b/benchmark/benchmark_device_segmented_radix_sort_pairs.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,23 +20,24 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. 
-#include -#include -#include -#include -#include +#include "benchmark_utils.hpp" +// CmdParser +#include "cmdparser.hpp" // Google Benchmark #include "benchmark/benchmark.h" -// CmdParser -#include "benchmark_utils.hpp" -#include "cmdparser.hpp" // HIP API #include // rocPRIM -#include +#include + +#include +#include +#include +#include +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; diff --git a/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.hpp b/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.hpp index 917af3a25..891be65bb 100644 --- a/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.hpp +++ b/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -23,9 +23,7 @@ #ifndef ROCPRIM_BENCHMARK_DEVICE_SEGMENTED_RADIX_SORT_PAIRS_PARALLEL_HPP_ #define ROCPRIM_BENCHMARK_DEVICE_SEGMENTED_RADIX_SORT_PAIRS_PARALLEL_HPP_ -#include -#include -#include +#include "benchmark_utils.hpp" // Google Benchmark #include @@ -34,9 +32,13 @@ #include // rocPRIM -#include +#include +#include -#include "benchmark_utils.hpp" +#include +#include + +#include template std::string warp_sort_config_name(T const& warp_sort_config) diff --git a/benchmark/benchmark_device_segmented_reduce.cpp b/benchmark/benchmark_device_segmented_reduce.cpp index 0631a55e5..e75959891 100644 --- a/benchmark/benchmark_device_segmented_reduce.cpp +++ b/benchmark/benchmark_device_segmented_reduce.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,23 +20,24 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#include -#include -#include -#include -#include - -// Google Benchmark -#include "benchmark/benchmark.h" +#include "benchmark_utils.hpp" // CmdParser #include "cmdparser.hpp" -#include "benchmark_utils.hpp" + +// Google Benchmark +#include // HIP API #include // rocPRIM -#include +#include + +#include +#include +#include +#include +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; diff --git a/benchmark/benchmark_device_select.cpp b/benchmark/benchmark_device_select.cpp index a0cff17cc..3dbb1d76a 100644 --- a/benchmark/benchmark_device_select.cpp +++ b/benchmark/benchmark_device_select.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,24 +20,27 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. 
-#include -#include -#include -#include -#include -#include -#include +#include "benchmark_utils.hpp" +// CmdParser +#include "cmdparser.hpp" // Google Benchmark #include "benchmark/benchmark.h" -// CmdParser -#include "cmdparser.hpp" -#include "benchmark_utils.hpp" // HIP API #include -#include +// rocPRIM +#include + +#include +#include +#include +#include +#include + +#include +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; diff --git a/benchmark/benchmark_device_transform.cpp b/benchmark/benchmark_device_transform.cpp index 88ec3bf6a..79646bff3 100644 --- a/benchmark/benchmark_device_transform.cpp +++ b/benchmark/benchmark_device_transform.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,24 +20,26 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#include -#include -#include -#include -#include -#include - -// Google Benchmark -#include "benchmark/benchmark.h" +#include "benchmark_utils.hpp" // CmdParser #include "cmdparser.hpp" -#include "benchmark_utils.hpp" + +// Google Benchmark +#include // HIP API #include // rocPRIM -#include +#include + +#include +#include +#include +#include + +#include +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 128; diff --git a/benchmark/benchmark_predicate_iterator.cpp b/benchmark/benchmark_predicate_iterator.cpp new file mode 100644 index 000000000..c2368dc30 --- /dev/null +++ b/benchmark/benchmark_predicate_iterator.cpp @@ -0,0 +1,243 @@ +// MIT License +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +#include "benchmark_utils.hpp" +#include "cmdparser.hpp" + +#include + +#include + +// rocPRIM +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + +#ifndef DEFAULT_N +const size_t DEFAULT_N = 1024 * 1024 * 128; +#endif + +const unsigned int batch_size = 10; +const unsigned int warmup_size = 5; + +template +struct identity +{ + __device__ T operator()(T value) + { + return value; + } +}; + +template +struct less_than +{ + __device__ bool operator()(T value) const + { + return value < T{C}; + } +}; + +template +struct increment +{ + __device__ T operator()(T value) const + { + return value + T{I}; + } +}; + +template +struct transform_it +{ + using value_type = T; + + void operator()(T* d_input, T* d_output, const size_t size, const hipStream_t stream) + { + auto t_it = rocprim::make_transform_iterator( + d_input, + [&] __device__(T v) { return Predicate{}(v) ? Transform{}(v) : v; }); + HIP_CHECK(rocprim::transform(t_it, d_output, size, identity{}, stream)); + } +}; + +template +struct read_predicate_it +{ + using value_type = T; + + void operator()(T* d_input, T* d_output, const size_t size, const hipStream_t stream) + { + auto t_it = rocprim::make_transform_iterator(d_input, Transform{}); + auto r_it = rocprim::make_predicate_iterator(t_it, d_input, Predicate{}); + HIP_CHECK(rocprim::transform(r_it, d_output, size, identity{}, stream)); + } +}; + +template +struct write_predicate_it +{ + using value_type = T; + + void operator()(T* d_input, T* d_output, const size_t size, const hipStream_t stream) + { + auto t_it = rocprim::make_transform_iterator(d_input, Transform{}); + auto w_it = rocprim::make_predicate_iterator(d_output, d_input, Predicate{}); + HIP_CHECK(rocprim::transform(t_it, w_it, size, identity{}, stream)); + } +}; + +template +void run_benchmark(benchmark::State& state, size_t size, const hipStream_t stream) +{ + using T = typename IteratorBenchmark::value_type; + + std::vector input = get_random_data(size, 
T(0), T(99)); + T* d_input; + T* d_output; + HIP_CHECK(hipMalloc(reinterpret_cast(&d_input), size * sizeof(T))); + HIP_CHECK(hipMalloc(reinterpret_cast(&d_output), size * sizeof(T))); + HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(T), hipMemcpyHostToDevice)); + HIP_CHECK(hipDeviceSynchronize()); + + // Warm-up + for(size_t i = 0; i < warmup_size; i++) + { + IteratorBenchmark{}(d_input, d_output, size, stream); + } + HIP_CHECK(hipDeviceSynchronize()); + + // HIP events creation + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + for(auto _ : state) + { + // Record start event + HIP_CHECK(hipEventRecord(start, stream)); + + for(size_t i = 0; i < batch_size; i++) + { + IteratorBenchmark{}(d_input, d_output, size, stream); + } + + // Record stop event and wait until it completes + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float elapsed_mseconds; + HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); + state.SetIterationTime(elapsed_mseconds / 1000); + } + + // Destroy HIP events + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(T)); + state.SetItemsProcessed(state.iterations() * batch_size * size); + + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); +} + +#define CREATE_BENCHMARK(B, T, C) \ + benchmark::RegisterBenchmark(bench_naming::format_name("{lvl:device,algo:" #B ",p:p" #C \ + ",key_type:" #T ",cfg:default_config}") \ + .c_str(), \ + run_benchmark, increment>>, \ + size, \ + stream) + +#define CREATE_TYPED_BENCHMARK(T) \ + CREATE_BENCHMARK(transform_it, T, 0), CREATE_BENCHMARK(read_predicate_it, T, 0), \ + CREATE_BENCHMARK(write_predicate_it, T, 0), CREATE_BENCHMARK(transform_it, T, 25), \ + CREATE_BENCHMARK(read_predicate_it, T, 25), CREATE_BENCHMARK(write_predicate_it, T, 25), \ + CREATE_BENCHMARK(transform_it, T, 50), 
CREATE_BENCHMARK(read_predicate_it, T, 50), \ + CREATE_BENCHMARK(write_predicate_it, T, 50), CREATE_BENCHMARK(transform_it, T, 75), \ + CREATE_BENCHMARK(read_predicate_it, T, 75), CREATE_BENCHMARK(write_predicate_it, T, 75), \ + CREATE_BENCHMARK(transform_it, T, 100), CREATE_BENCHMARK(read_predicate_it, T, 100), \ + CREATE_BENCHMARK(write_predicate_it, T, 100) + +int main(int argc, char* argv[]) +{ + cli::Parser parser(argc, argv); + parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("trials", "trials", -1, "number of iterations"); + parser.set_optional("name_format", + "name_format", + "human", + "either: json,human,txt"); + parser.run_and_exit_if_error(); + + // Parse argv + benchmark::Initialize(&argc, argv); + const size_t size = parser.get("size"); + const int trials = parser.get("trials"); + bench_naming::set_format(parser.get("name_format")); + + // HIP + hipStream_t stream = 0; // default + + // Benchmark info + add_common_benchmark_info(); + benchmark::AddCustomContext("size", std::to_string(size)); + + using custom_128 = custom_type; + + // Add benchmarks + std::vector benchmarks = {CREATE_TYPED_BENCHMARK(int8_t), + CREATE_TYPED_BENCHMARK(int16_t), + CREATE_TYPED_BENCHMARK(int32_t), + CREATE_TYPED_BENCHMARK(int64_t), + CREATE_TYPED_BENCHMARK(custom_128)}; + + // Use manual timing + for(auto& b : benchmarks) + { + b->UseManualTime(); + b->Unit(benchmark::kMillisecond); + } + + // Force number of iterations + if(trials > 0) + { + for(auto& b : benchmarks) + { + b->Iterations(trials); + } + } + + // Run benchmarks + benchmark::RunSpecifiedBenchmarks(); + + return 0; +} \ No newline at end of file diff --git a/benchmark/benchmark_utils.hpp b/benchmark/benchmark_utils.hpp index e5863e758..ef901f4be 100644 --- a/benchmark/benchmark_utils.hpp +++ b/benchmark/benchmark_utils.hpp @@ -21,6 +21,13 @@ #ifndef ROCPRIM_BENCHMARK_UTILS_HPP_ #define ROCPRIM_BENCHMARK_UTILS_HPP_ +#include + +// rocPRIM +#include +#include +#include + 
#include #include #include @@ -33,9 +40,6 @@ #include #include -#include "benchmark/benchmark.h" -#include - #define HIP_CHECK(condition) \ { \ hipError_t error = condition; \ @@ -231,6 +235,13 @@ struct custom_type { return x == rhs.x && y == rhs.y; } + + ROCPRIM_HOST_DEVICE custom_type& operator+=(const custom_type& rhs) + { + this->x += rhs.x; + this->y += rhs.y; + return *this; + } }; template @@ -239,6 +250,21 @@ struct is_custom_type : std::false_type {}; template struct is_custom_type> : std::true_type {}; +template +struct custom_type_decomposer +{ + static_assert(is_custom_type::value, + "custom_type_decomposer can only be used with instantiations of custom_type"); + + using T = typename CustomType::first_type; + using U = typename CustomType::second_type; + + __host__ __device__ ::rocprim::tuple operator()(CustomType& key) const + { + return ::rocprim::tuple{key.x, key.y}; + } +}; + template inline auto generate_random_data_n(OutputIterator it, size_t size, @@ -711,6 +737,11 @@ inline const char* Traits>::name() return "custom_type"; } template<> +inline const char* Traits>::name() +{ + return "custom_type"; +} +template<> inline const char* Traits::name() { return "empty_type"; diff --git a/benchmark/benchmark_warp_exchange.cpp b/benchmark/benchmark_warp_exchange.cpp index 64a0c65a5..7ddb48728 100644 --- a/benchmark/benchmark_warp_exchange.cpp +++ b/benchmark/benchmark_warp_exchange.cpp @@ -20,24 +20,26 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. 
-#include - -#include -#include -#include -#include -#include - -// Google Benchmark -#include "benchmark/benchmark.h" +#include "benchmark_utils.hpp" // CmdParser #include "cmdparser.hpp" -#include "benchmark_utils.hpp" + +// Google Benchmark +#include // HIP API #include +#include #include +#include +#include +#include +#include + +#include +#include + #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; #endif @@ -309,67 +311,73 @@ int main(int argc, char *argv[]) // Add benchmarks std::vector benchmarks{ - CREATE_BENCHMARK(int, 256, 1, 16, BlockedToStripedOp), - CREATE_BENCHMARK(int, 256, 1, 32, BlockedToStripedOp), - CREATE_BENCHMARK(int, 256, 4, 16, BlockedToStripedOp), - CREATE_BENCHMARK(int, 256, 4, 32, BlockedToStripedOp), + CREATE_BENCHMARK(int, 256, 1, 16, BlockedToStripedOp), + CREATE_BENCHMARK(int, 256, 1, 32, BlockedToStripedOp), + CREATE_BENCHMARK(int, 256, 4, 16, BlockedToStripedOp), + CREATE_BENCHMARK(int, 256, 4, 32, BlockedToStripedOp), CREATE_BENCHMARK(int, 256, 16, 16, BlockedToStripedOp), CREATE_BENCHMARK(int, 256, 16, 32, BlockedToStripedOp), + CREATE_BENCHMARK(int, 256, 32, 32, BlockedToStripedOp), - CREATE_BENCHMARK(int, 256, 1, 16, StripedToBlockedOp), - CREATE_BENCHMARK(int, 256, 1, 32, StripedToBlockedOp), - CREATE_BENCHMARK(int, 256, 4, 16, StripedToBlockedOp), - CREATE_BENCHMARK(int, 256, 4, 32, StripedToBlockedOp), + CREATE_BENCHMARK(int, 256, 1, 16, StripedToBlockedOp), + CREATE_BENCHMARK(int, 256, 1, 32, StripedToBlockedOp), + CREATE_BENCHMARK(int, 256, 4, 16, StripedToBlockedOp), + CREATE_BENCHMARK(int, 256, 4, 32, StripedToBlockedOp), CREATE_BENCHMARK(int, 256, 16, 16, StripedToBlockedOp), CREATE_BENCHMARK(int, 256, 16, 32, StripedToBlockedOp), + CREATE_BENCHMARK(int, 256, 32, 32, StripedToBlockedOp), - CREATE_BENCHMARK(int, 256, 1, 16, BlockedToStripedShuffleOp), - CREATE_BENCHMARK(int, 256, 1, 32, BlockedToStripedShuffleOp), - CREATE_BENCHMARK(int, 256, 4, 16, BlockedToStripedShuffleOp), - CREATE_BENCHMARK(int, 256, 4, 
32, BlockedToStripedShuffleOp), + CREATE_BENCHMARK(int, 256, 1, 16, BlockedToStripedShuffleOp), + CREATE_BENCHMARK(int, 256, 1, 32, BlockedToStripedShuffleOp), + CREATE_BENCHMARK(int, 256, 4, 16, BlockedToStripedShuffleOp), + CREATE_BENCHMARK(int, 256, 4, 32, BlockedToStripedShuffleOp), CREATE_BENCHMARK(int, 256, 16, 16, BlockedToStripedShuffleOp), CREATE_BENCHMARK(int, 256, 16, 32, BlockedToStripedShuffleOp), + CREATE_BENCHMARK(int, 256, 32, 32, BlockedToStripedShuffleOp), - CREATE_BENCHMARK(int, 256, 1, 16, StripedToBlockedShuffleOp), - CREATE_BENCHMARK(int, 256, 1, 32, StripedToBlockedShuffleOp), - CREATE_BENCHMARK(int, 256, 4, 16, StripedToBlockedShuffleOp), - CREATE_BENCHMARK(int, 256, 4, 32, StripedToBlockedShuffleOp), + CREATE_BENCHMARK(int, 256, 1, 16, StripedToBlockedShuffleOp), + CREATE_BENCHMARK(int, 256, 1, 32, StripedToBlockedShuffleOp), + CREATE_BENCHMARK(int, 256, 4, 16, StripedToBlockedShuffleOp), + CREATE_BENCHMARK(int, 256, 4, 32, StripedToBlockedShuffleOp), CREATE_BENCHMARK(int, 256, 16, 16, StripedToBlockedShuffleOp), CREATE_BENCHMARK(int, 256, 16, 32, StripedToBlockedShuffleOp), + CREATE_BENCHMARK(int, 256, 32, 32, StripedToBlockedShuffleOp), - CREATE_BENCHMARK(int, 256, 1, 16, ScatterToStripedOp), - CREATE_BENCHMARK(int, 256, 1, 32, ScatterToStripedOp), - CREATE_BENCHMARK(int, 256, 4, 16, ScatterToStripedOp), - CREATE_BENCHMARK(int, 256, 4, 32, ScatterToStripedOp), + CREATE_BENCHMARK(int, 256, 1, 16, ScatterToStripedOp), + CREATE_BENCHMARK(int, 256, 1, 32, ScatterToStripedOp), + CREATE_BENCHMARK(int, 256, 4, 16, ScatterToStripedOp), + CREATE_BENCHMARK(int, 256, 4, 32, ScatterToStripedOp), CREATE_BENCHMARK(int, 256, 16, 16, ScatterToStripedOp), - CREATE_BENCHMARK(int, 256, 16, 32, ScatterToStripedOp) - }; + CREATE_BENCHMARK(int, 256, 16, 32, ScatterToStripedOp)}; int hip_device = 0; HIP_CHECK(::rocprim::detail::get_device_from_stream(stream, hip_device)); if(is_warp_size_supported(64, hip_device)) { std::vector additional_benchmarks{ - 
CREATE_BENCHMARK(int, 256, 1, 64, BlockedToStripedOp), - CREATE_BENCHMARK(int, 256, 4, 64, BlockedToStripedOp), + CREATE_BENCHMARK(int, 256, 1, 64, BlockedToStripedOp), + CREATE_BENCHMARK(int, 256, 4, 64, BlockedToStripedOp), CREATE_BENCHMARK(int, 256, 16, 64, BlockedToStripedOp), + CREATE_BENCHMARK(int, 256, 64, 64, BlockedToStripedOp), - CREATE_BENCHMARK(int, 256, 1, 64, StripedToBlockedOp), - CREATE_BENCHMARK(int, 256, 4, 64, StripedToBlockedOp), + CREATE_BENCHMARK(int, 256, 1, 64, StripedToBlockedOp), + CREATE_BENCHMARK(int, 256, 4, 64, StripedToBlockedOp), CREATE_BENCHMARK(int, 256, 16, 64, StripedToBlockedOp), + CREATE_BENCHMARK(int, 256, 64, 64, StripedToBlockedOp), - CREATE_BENCHMARK(int, 256, 1, 64, BlockedToStripedShuffleOp), - CREATE_BENCHMARK(int, 256, 4, 64, BlockedToStripedShuffleOp), + CREATE_BENCHMARK(int, 256, 1, 64, BlockedToStripedShuffleOp), + CREATE_BENCHMARK(int, 256, 4, 64, BlockedToStripedShuffleOp), CREATE_BENCHMARK(int, 256, 16, 64, BlockedToStripedShuffleOp), + CREATE_BENCHMARK(int, 256, 64, 64, BlockedToStripedShuffleOp), - CREATE_BENCHMARK(int, 256, 1, 64, StripedToBlockedShuffleOp), - CREATE_BENCHMARK(int, 256, 4, 64, StripedToBlockedShuffleOp), + CREATE_BENCHMARK(int, 256, 1, 64, StripedToBlockedShuffleOp), + CREATE_BENCHMARK(int, 256, 4, 64, StripedToBlockedShuffleOp), CREATE_BENCHMARK(int, 256, 16, 64, StripedToBlockedShuffleOp), + CREATE_BENCHMARK(int, 256, 64, 64, StripedToBlockedShuffleOp), - CREATE_BENCHMARK(int, 256, 1, 64, ScatterToStripedOp), - CREATE_BENCHMARK(int, 256, 4, 64, ScatterToStripedOp), - CREATE_BENCHMARK(int, 256, 16, 64, ScatterToStripedOp) - }; + CREATE_BENCHMARK(int, 256, 1, 64, ScatterToStripedOp), + CREATE_BENCHMARK(int, 256, 4, 64, ScatterToStripedOp), + CREATE_BENCHMARK(int, 256, 16, 64, ScatterToStripedOp)}; benchmarks.insert( benchmarks.end(), additional_benchmarks.begin(), diff --git a/benchmark/benchmark_warp_reduce.cpp b/benchmark/benchmark_warp_reduce.cpp index 921c4105f..14dd57f4c 100644 --- 
a/benchmark/benchmark_warp_reduce.cpp +++ b/benchmark/benchmark_warp_reduce.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,25 +20,26 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#include - -#include -#include -#include -#include -#include - -// Google Benchmark -#include "benchmark/benchmark.h" +#include "benchmark_utils.hpp" // CmdParser #include "cmdparser.hpp" -#include "benchmark_utils.hpp" + +// Google Benchmark +#include // HIP API #include // rocPRIM -#include +#include + +#include +#include +#include +#include + +#include +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; diff --git a/benchmark/benchmark_warp_scan.cpp b/benchmark/benchmark_warp_scan.cpp index daee015ae..c12066fd5 100644 --- a/benchmark/benchmark_warp_scan.cpp +++ b/benchmark/benchmark_warp_scan.cpp @@ -20,24 +20,24 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. 
-#include - -#include -#include -#include -#include -#include - -// Google Benchmark -#include "benchmark/benchmark.h" +#include "benchmark_utils.hpp" // CmdParser #include "cmdparser.hpp" + +// Google Benchmark +#include // HIP API #include // rocPRIM -#include +#include -#include "benchmark_utils.hpp" +#include +#include +#include +#include + +#include +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; diff --git a/benchmark/benchmark_warp_sort.cpp b/benchmark/benchmark_warp_sort.cpp index 5d6eedea0..aae44e2a1 100644 --- a/benchmark/benchmark_warp_sort.cpp +++ b/benchmark/benchmark_warp_sort.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,24 +20,26 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -#include - -#include -#include -#include -#include -#include - -// Google Benchmark -#include "benchmark/benchmark.h" +#include "benchmark_utils.hpp" // CmdParser #include "cmdparser.hpp" + +// Google Benchmark +#include // HIP API #include // rocPRIM -#include +#include +#include +#include -#include "benchmark_utils.hpp" +#include +#include +#include +#include + +#include +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; diff --git a/benchmark/cmdparser.hpp b/benchmark/cmdparser.hpp index ffee10ecb..7ac2e75da 100644 --- a/benchmark/cmdparser.hpp +++ b/benchmark/cmdparser.hpp @@ -1,7 +1,7 @@ // The MIT License (MIT) // // Copyright (c) 2015 - 2016 Florian Rappl -// Modifications Copyright (c) 2019, Advanced Micro Devices, Inc. All rights reserved. +// Modifications Copyright (c) 2019-2024, Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -27,12 +27,12 @@ */ #pragma once +#include #include +#include #include #include #include -#include -#include namespace cli { struct CallbackArgs { diff --git a/docs/block_ops/ops_classes/sort.rst b/docs/block_ops/ops_classes/sort.rst index a71d94b6c..42688fb0b 100644 --- a/docs/block_ops/ops_classes/sort.rst +++ b/docs/block_ops/ops_classes/sort.rst @@ -8,16 +8,15 @@ Sort ******************************************************************** -generic -========= - +Generic Block Sort +================== .. doxygenclass:: rocprim::block_sort :members: .. doxygenenum:: rocprim::block_sort_algorithm -radix sort +Radix sort =========== .. doxygenclass:: rocprim::block_radix_sort diff --git a/docs/device_ops/device_copy.rst b/docs/device_ops/device_copy.rst new file mode 100644 index 000000000..4de86aa72 --- /dev/null +++ b/docs/device_ops/device_copy.rst @@ -0,0 +1,18 @@ +.. meta:: + :description: rocPRIM documentation and API reference library + :keywords: rocPRIM, ROCm, API, documentation + +.. _dev-device_copy: + +DeviceCopy +---------- + +Configuring the kernel +~~~~~~~~~~~~~~~~~~~~~~ + +.. doxygenstruct:: rocprim::batch_copy_config + +batch_copy +~~~~~~~~~~~~ + +.. 
doxygenfunction:: rocprim::batch_copy(void* temporary_storage, size_t& storage_size, InputBufferItType sources, OutputBufferItType destinations, BufferSizeItType sizes, uint32_t num_copies, hipStream_t stream = hipStreamDefault, bool debug_synchronous = false) diff --git a/docs/device_ops/index.rst b/docs/device_ops/index.rst index 1f01a7414..068b572b1 100644 --- a/docs/device_ops/index.rst +++ b/docs/device_ops/index.rst @@ -21,4 +21,5 @@ * :ref:`dev-adjacent_difference` * :ref:`dev-binary_search` * :ref:`dev-histogram` + * :ref:`dev-device_copy` * :ref:`dev-memcpy` diff --git a/docs/device_ops/sort.rst b/docs/device_ops/sort.rst index 7ffa1cf7c..312db469b 100644 --- a/docs/device_ops/sort.rst +++ b/docs/device_ops/sort.rst @@ -31,46 +31,66 @@ merge_sort radix_sort_keys ================ -ascending ----------- +Ascending Sort +-------------- +.. doxygenfunction:: rocprim::radix_sort_keys(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, Size size, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false) +.. doxygenfunction:: rocprim::radix_sort_keys(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, Size size, Decomposer decomposer, unsigned int begin_bit, unsigned int end_bit, hipStream_t stream=0, bool debug_synchronous=false) +.. doxygenfunction:: rocprim::radix_sort_keys(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, Size size, Decomposer decomposer, hipStream_t stream=0, bool debug_synchronous=false) .. doxygenfunction:: rocprim::radix_sort_keys(void *temporary_storage, size_t &storage_size, double_buffer< Key > &keys, Size size, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false) +.. 
doxygenfunction:: rocprim::radix_sort_keys(void *temporary_storage, size_t &storage_size, double_buffer< Key > &keys, Size size, Decomposer decomposer, unsigned int begin_bit, unsigned int end_bit, hipStream_t stream=0, bool debug_synchronous=false) +.. doxygenfunction:: rocprim::radix_sort_keys(void *temporary_storage, size_t &storage_size, double_buffer< Key > &keys, Size size, Decomposer decomposer, hipStream_t stream=0, bool debug_synchronous=false) -descending ------------ +Descending Sort +--------------- +.. doxygenfunction:: rocprim::radix_sort_keys_desc(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, Size size, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false) +.. doxygenfunction:: rocprim::radix_sort_keys_desc(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, Size size, Decomposer decomposer, unsigned int begin_bit, unsigned int end_bit, hipStream_t stream=0, bool debug_synchronous=false) +.. doxygenfunction:: rocprim::radix_sort_keys_desc(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, Size size, Decomposer decomposer, hipStream_t stream=0, bool debug_synchronous=false) .. doxygenfunction:: rocprim::radix_sort_keys_desc(void *temporary_storage, size_t &storage_size, double_buffer< Key > &keys, Size size, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false) +.. doxygenfunction:: rocprim::radix_sort_keys_desc(void *temporary_storage, size_t &storage_size, double_buffer< Key > &keys, Size size, Decomposer decomposer, unsigned int begin_bit, unsigned int end_bit, hipStream_t stream=0, bool debug_synchronous=false) +.. 
doxygenfunction:: rocprim::radix_sort_keys_desc(void *temporary_storage, size_t &storage_size, double_buffer< Key > &keys, Size size, Decomposer decomposer, hipStream_t stream=0, bool debug_synchronous=false) -segmented, ascending ------------------------ +Segmented Ascending Sort +------------------------ .. doxygenfunction:: rocprim::segmented_radix_sort_keys(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, unsigned int size, unsigned int segments, OffsetIterator begin_offsets, OffsetIterator end_offsets, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false) -segmented, descending ------------------------ +Segmented Descending Sort +------------------------- .. doxygenfunction:: rocprim::segmented_radix_sort_keys_desc(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, unsigned int size, unsigned int segments, OffsetIterator begin_offsets, OffsetIterator end_offsets, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false) radix_sort_pairs ==================== -ascending ------------ +Ascending Sort +-------------- .. doxygenfunction:: rocprim::radix_sort_pairs(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, ValuesInputIterator values_input, ValuesOutputIterator values_output, Size size, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false) +.. doxygenfunction:: rocprim::radix_sort_pairs(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, ValuesInputIterator values_input, ValuesOutputIterator values_output, Size size, Decomposer decomposer, unsigned int begin_bit, unsigned int end_bit, hipStream_t stream=0, bool debug_synchronous=false) +.. 
doxygenfunction:: rocprim::radix_sort_pairs(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, ValuesInputIterator values_input, ValuesOutputIterator values_output, Size size, Decomposer decomposer, hipStream_t stream=0, bool debug_synchronous=false) +.. doxygenfunction:: rocprim::radix_sort_pairs(void *temporary_storage, size_t &storage_size, double_buffer< Key > &keys, double_buffer< Value > &values, Size size, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false) +.. doxygenfunction:: rocprim::radix_sort_pairs(void *temporary_storage, size_t &storage_size, double_buffer< Key > &keys, double_buffer< Value > &values, Size size, Decomposer decomposer, unsigned int begin_bit, unsigned int end_bit, hipStream_t stream=0, bool debug_synchronous=false) +.. doxygenfunction:: rocprim::radix_sort_pairs(void *temporary_storage, size_t &storage_size, double_buffer< Key > &keys, double_buffer< Value > &values, Size size, Decomposer decomposer, hipStream_t stream=0, bool debug_synchronous=false) -descending ----------------- +Descending Sort +--------------- .. doxygenfunction:: rocprim::radix_sort_pairs_desc(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, ValuesInputIterator values_input, ValuesOutputIterator values_output, Size size, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false) +.. doxygenfunction:: rocprim::radix_sort_pairs_desc(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, ValuesInputIterator values_input, ValuesOutputIterator values_output, Size size, Decomposer decomposer, unsigned int begin_bit, unsigned int end_bit, hipStream_t stream=0, bool debug_synchronous=false) +.. 
doxygenfunction:: rocprim::radix_sort_pairs_desc(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, ValuesInputIterator values_input, ValuesOutputIterator values_output, Size size, Decomposer decomposer, hipStream_t stream=0, bool debug_synchronous=false) +.. doxygenfunction:: rocprim::radix_sort_pairs_desc(void *temporary_storage, size_t &storage_size, double_buffer< Key > &keys, double_buffer< Value > &values, Size size, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false) +.. doxygenfunction:: rocprim::radix_sort_pairs_desc(void *temporary_storage, size_t &storage_size, double_buffer< Key > &keys, double_buffer< Value > &values, Size size, Decomposer decomposer, unsigned int begin_bit, unsigned int end_bit, hipStream_t stream=0, bool debug_synchronous=false) +.. doxygenfunction:: rocprim::radix_sort_pairs_desc(void *temporary_storage, size_t &storage_size, double_buffer< Key > &keys, double_buffer< Value > &values, Size size, Decomposer decomposer, hipStream_t stream=0, bool debug_synchronous=false) -segmented, ascending +Segmented Ascending Sort ------------------------ .. doxygenfunction:: rocprim::segmented_radix_sort_pairs(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, ValuesInputIterator values_input, ValuesOutputIterator values_output, unsigned int size, unsigned int segments, OffsetIterator begin_offsets, OffsetIterator end_offsets, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false) -segmented, ascending --------------------------- +Segmented Descending Sort +------------------------- .. 
doxygenfunction:: rocprim::segmented_radix_sort_pairs_desc(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, ValuesInputIterator values_input, ValuesOutputIterator values_output, unsigned int size, unsigned int segments, OffsetIterator begin_offsets, OffsetIterator end_offsets, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false) diff --git a/docs/doxygen/Doxyfile b/docs/doxygen/Doxyfile index fe2180aa8..5dc7aa7ca 100644 --- a/docs/doxygen/Doxyfile +++ b/docs/doxygen/Doxyfile @@ -774,6 +774,7 @@ WARN_LOGFILE = INPUT = mainpage.dox \ primitivesmodule.dox \ + threadmodule.dox \ warpmodule.dox \ blockmodule.dox \ devicemodule.dox \ diff --git a/docs/doxygen/threadmodule.dox b/docs/doxygen/threadmodule.dox new file mode 100644 index 000000000..3022ee740 --- /dev/null +++ b/docs/doxygen/threadmodule.dox @@ -0,0 +1,11 @@ +/** +@brief rocPRIM Thread-level parallel primitives +@author +@file +*/ + +/** + * \defgroup threadmodule Thread-level + * \ingroup primitivesmodule + * + */ diff --git a/docs/index.rst b/docs/index.rst index 75649aeb3..e1863ab36 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -28,6 +28,7 @@ The rocPRIM documentation is structured as follows: * :ref:`dev-index` * :ref:`block-index` * :ref:`warp-index` + * :ref:`thread-index` * :ref:`thread_ops` * :ref:`iterators` * :ref:`intrinsics` diff --git a/docs/reference/iterators.rst b/docs/reference/iterators.rst index 8c88db069..ef4ee032a 100644 --- a/docs/reference/iterators.rst +++ b/docs/reference/iterators.rst @@ -53,6 +53,20 @@ Transform transform(sequence(1)) ... +Predicate +--------- + +.. doxygenclass:: rocprim::predicate_iterator + :members: + +.. note:: + ``predicate_iterator(sequence, test, predicate)`` generates the sequence:: + + predicate(test[0]) ? sequence[0] : default + predicate(test[1]) ? sequence[1] : default + predicate(test[2]) ? sequence[2] : default + ... 
+ Pairing Values with Indices ============================= diff --git a/docs/reference/reference.rst b/docs/reference/reference.rst index 1b955a7cf..5a8d1d870 100644 --- a/docs/reference/reference.rst +++ b/docs/reference/reference.rst @@ -13,6 +13,7 @@ * :ref:`dev-index` * :ref:`block-index` * :ref:`warp-index` +* :ref:`thread-index` * :ref:`thread_ops` * :ref:`iterators` * :ref:`intrinsics` diff --git a/docs/sphinx/_toc.yml.in b/docs/sphinx/_toc.yml.in index 07afd927c..00507c859 100644 --- a/docs/sphinx/_toc.yml.in +++ b/docs/sphinx/_toc.yml.in @@ -30,6 +30,7 @@ subtrees: - file: device_ops/adjacent_difference.rst - file: device_ops/binary_search.rst - file: device_ops/histogram.rst + - file: device_ops/device_copy.rst - file: device_ops/memcpy.rst - file: block_ops/index.rst subtrees: @@ -58,6 +59,10 @@ subtrees: - file: warp_ops/sort.rst - file: warp_ops/shuffle.rst - file: warp_ops/exchange.rst + - file: thread_ops/index.rst + subtrees: + - entries: + - file: thread_ops/radix_key_codec.rst - file: reference/thread_ops.rst - file: reference/iterators.rst - file: reference/intrinsics.rst diff --git a/docs/thread_ops/index.rst b/docs/thread_ops/index.rst new file mode 100644 index 000000000..380f2df4f --- /dev/null +++ b/docs/thread_ops/index.rst @@ -0,0 +1,11 @@ +.. meta:: + :description: rocPRIM documentation and API reference library + :keywords: rocPRIM, ROCm, API, documentation + +.. _thread-index: + +******************************************************************** + Thread-Level Operations +******************************************************************** + + * :ref:`radix-key-codec` diff --git a/docs/thread_ops/radix_key_codec.rst b/docs/thread_ops/radix_key_codec.rst new file mode 100644 index 000000000..718324651 --- /dev/null +++ b/docs/thread_ops/radix_key_codec.rst @@ -0,0 +1,12 @@ +.. meta:: + :description: rocPRIM documentation and API reference library + :keywords: rocPRIM, ROCm, API, documentation + +.. 
_radix-key-codec: + +******************************************************************** + Radix Key Encoder/Decoder +******************************************************************** + +.. doxygenclass:: rocprim::radix_key_codec + :members: diff --git a/rocprim/include/rocprim/block/block_histogram.hpp b/rocprim/include/rocprim/block/block_histogram.hpp index 9153b2476..f74c35d5b 100644 --- a/rocprim/include/rocprim/block/block_histogram.hpp +++ b/rocprim/include/rocprim/block/block_histogram.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -177,7 +177,7 @@ class block_histogram /// /// \tparam Counter - [inferred] counter type of histogram. /// - /// \param [in] input - reference to an array containing thread input values. + /// \param [in] input - reference to an array containing thread input values. The function expects each value to satisfy 0 <= input[i] < BINS. /// \param [out] hist - histogram bin count. /// \param [in] storage - reference to a temporary storage object of type storage_type. /// @@ -237,7 +237,7 @@ class block_histogram /// /// \tparam Counter - [inferred] counter type of histogram. /// - /// \param [in] input - reference to an array containing thread input values. + /// \param [in] input - reference to an array containing thread input values. The function expects each value to satisfy 0 <= input[i] < BINS. /// \param [out] hist - histogram bin count. template ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE @@ -252,7 +252,7 @@ class block_histogram /// /// \tparam Counter - [inferred] counter type of histogram. /// - /// \param [in] input - reference to an array containing thread input values. 
+ /// \param [in] input - reference to an array containing thread input values. The function expects each value to satisfy 0 <= input[i] < BINS. /// \param [out] hist - histogram bin count. /// \param [in] storage - reference to a temporary storage object of type storage_type. /// @@ -307,7 +307,7 @@ class block_histogram /// /// \tparam Counter - [inferred] counter type of histogram. /// - /// \param [in] input - reference to an array containing thread input values. + /// \param [in] input - reference to an array containing thread input values. The function expects each value to satisfy 0 <= input[i] < BINS. /// \param [out] hist - histogram bin count. template ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE diff --git a/rocprim/include/rocprim/block/block_radix_rank.hpp b/rocprim/include/rocprim/block/block_radix_rank.hpp index 7f9c378d7..7ffd89bcf 100644 --- a/rocprim/include/rocprim/block/block_radix_rank.hpp +++ b/rocprim/include/rocprim/block/block_radix_rank.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -24,8 +24,6 @@ #include "../config.hpp" #include "../functional.hpp" -#include "../detail/radix_sort.hpp" - #include "block_scan.hpp" #include "detail/block_radix_rank_basic.hpp" diff --git a/rocprim/include/rocprim/block/block_radix_sort.hpp b/rocprim/include/rocprim/block/block_radix_sort.hpp index 71c1f37f0..4df792a1b 100644 --- a/rocprim/include/rocprim/block/block_radix_sort.hpp +++ b/rocprim/include/rocprim/block/block_radix_sort.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -25,7 +25,7 @@ #include "../config.hpp" #include "../detail/various.hpp" -#include "../detail/radix_sort.hpp" +#include "../thread/radix_key_codec.hpp" #include "../warp/detail/warp_scan_crosslane.hpp" #include "../intrinsics.hpp" @@ -49,6 +49,7 @@ BEGIN_ROCPRIM_NAMESPACE /// \tparam ItemsPerThread - the number of items contributed by each thread. /// \tparam Value - the value type. Default type empty_type indicates /// a keys-only sort. +/// \tparam RadixBitsPerPass - number of bits to sort per pass. The default is 4. /// /// \par Overview /// * \p Key type must be an arithmetic type (that is, an integral type or a floating-point @@ -86,34 +87,36 @@ BEGIN_ROCPRIM_NAMESPACE /// } /// \endcode /// \endparblock -template< - class Key, - unsigned int BlockSizeX, - unsigned int ItemsPerThread, - class Value = empty_type, - unsigned int BlockSizeY = 1, - unsigned int BlockSizeZ = 1 -> +template class block_radix_sort { + static_assert(RadixBitsPerPass > 0 && RadixBitsPerPass < 32, + "The RadixBitsPerPass should be larger than 0 and smaller than the size " + "of an unsigned int"); + static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ; static constexpr bool with_values = !std::is_same::value; - static constexpr unsigned int radix_bits_per_pass = 4; - using bit_key_type = typename ::rocprim::detail::radix_key_codec::bit_key_type; using block_rank_type = ::rocprim::block_radix_rank; - using bit_keys_exchange_type = ::rocprim::block_exchange; + using keys_exchange_type + = ::rocprim::block_exchange; using values_exchange_type = ::rocprim::block_exchange; // Struct used for creating a raw_storage object for this primitive's temporary storage. 
union storage_type_ { - typename bit_keys_exchange_type::storage_type bit_keys_exchange; + typename keys_exchange_type::storage_type keys_exchange; typename values_exchange_type::storage_type values_exchange; typename block_rank_type::storage_type rank; }; @@ -136,6 +139,8 @@ class block_radix_sort /// \brief Performs ascending radix sort over keys partitioned across threads in a block. /// + /// \tparam Decomposer The type of the decomposer argument. Defaults to the identity decomposer. + /// /// \param [in, out] keys - reference to an array of keys provided by a thread. /// \param [in] storage - reference to a temporary storage object of type storage_type. /// \param [in] begin_bit - [optional] index of the first (least significant) bit used in @@ -143,6 +148,9 @@ class block_radix_sort /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default /// value: \p 8 * sizeof(Key). + /// \param [in] decomposer [optional] If `Key` is not an arithmetic type (integral, floating point), + /// a custom decomposer functor should be passed that produces a `::rocprim::tuple` of references to + /// fundamental types from this custom type. /// /// \par Storage reusage /// Synchronization barrier should be placed before \p storage is reused @@ -175,14 +183,15 @@ class block_radix_sort /// If the \p input values across threads in a block are {[256, 255], ..., [4, 3], [2, 1]}}, then /// then after sort they will be equal {[1, 2], [3, 4] ..., [255, 256]}. 
/// \endparblock - ROCPRIM_DEVICE ROCPRIM_INLINE - void sort(Key (&keys)[ItemsPerThread], - storage_type& storage, - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key)) + template + ROCPRIM_DEVICE ROCPRIM_INLINE void sort(Key (&keys)[ItemsPerThread], + storage_type& storage, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) { empty_type values[ItemsPerThread]; - sort_impl(keys, values, storage, begin_bit, end_bit); + sort_impl(keys, values, storage, begin_bit, end_bit, decomposer); } /// \overload @@ -197,17 +206,23 @@ class block_radix_sort /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default /// value: \p 8 * sizeof(Key). - ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE - void sort(Key (&keys)[ItemsPerThread], - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key)) + /// \param [in] decomposer [optional] If `Key` is not an arithmetic type (integral, floating point), + /// a custom decomposer functor should be passed that produces a `::rocprim::tuple` of references to + /// fundamental types from this custom type. + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void sort(Key (&keys)[ItemsPerThread], + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) { ROCPRIM_SHARED_MEMORY storage_type storage; - sort(keys, storage, begin_bit, end_bit); + sort(keys, storage, begin_bit, end_bit, decomposer); } /// \brief Performs descending radix sort over keys partitioned across threads in a block. /// + /// \tparam Decomposer The type of the decomposer argument. Defaults to the identity decomposer. + /// /// \param [in, out] keys - reference to an array of keys provided by a thread. /// \param [in] storage - reference to a temporary storage object of type storage_type. 
/// \param [in] begin_bit - [optional] index of the first (least significant) bit used in @@ -215,6 +230,9 @@ class block_radix_sort /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default /// value: \p 8 * sizeof(Key). + /// \param [in] decomposer [optional] If `Key` is not an arithmetic type (integral, floating point), + /// a custom decomposer functor should be passed that produces a `::rocprim::tuple` of references to + /// fundamental types from this custom type. /// /// \par Storage reusage /// Synchronization barrier should be placed before \p storage is reused @@ -247,14 +265,15 @@ class block_radix_sort /// If the \p input values across threads in a block are {[1, 2], [3, 4] ..., [255, 256]}, /// then after sort they will be equal {[256, 255], ..., [4, 3], [2, 1]}. /// \endparblock - ROCPRIM_DEVICE ROCPRIM_INLINE - void sort_desc(Key (&keys)[ItemsPerThread], - storage_type& storage, - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key)) + template + ROCPRIM_DEVICE ROCPRIM_INLINE void sort_desc(Key (&keys)[ItemsPerThread], + storage_type& storage, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) { empty_type values[ItemsPerThread]; - sort_impl(keys, values, storage, begin_bit, end_bit); + sort_impl(keys, values, storage, begin_bit, end_bit, decomposer); } /// \overload @@ -263,19 +282,25 @@ class block_radix_sort /// * This overload does not accept storage argument. Required shared memory is /// allocated by the method itself. /// + /// \tparam Decomposer The type of the decomposer argument. Defaults to the identity decomposer. + /// /// \param [in, out] keys - reference to an array of keys provided by a thread. /// \param [in] begin_bit - [optional] index of the first (least significant) bit used in /// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. 
/// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default /// value: \p 8 * sizeof(Key). - ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE - void sort_desc(Key (&keys)[ItemsPerThread], - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key)) + /// \param [in] decomposer [optional] If `Key` is not an arithmetic type (integral, floating point), + /// a custom decomposer functor should be passed that produces a `::rocprim::tuple` of references to + /// fundamental types from this custom type. + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void sort_desc(Key (&keys)[ItemsPerThread], + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) { ROCPRIM_SHARED_MEMORY storage_type storage; - sort_desc(keys, storage, begin_bit, end_bit); + sort_desc(keys, storage, begin_bit, end_bit, decomposer); } /// \brief Performs ascending radix sort over key-value pairs partitioned across @@ -283,6 +308,8 @@ class block_radix_sort /// /// \pre Method is enabled only if \p Value type is different than empty_type. /// + /// \tparam Decomposer The type of the decomposer argument. Defaults to the identity decomposer. + /// /// \param [in, out] keys - reference to an array of keys provided by a thread. /// \param [in, out] values - reference to an array of values provided by a thread. /// \param [in] storage - reference to a temporary storage object of type storage_type. @@ -291,6 +318,9 @@ class block_radix_sort /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default /// value: \p 8 * sizeof(Key). 
+ /// \param [in] decomposer [optional] If `Key` is not an arithmetic type (integral, floating point), + /// a custom decomposer functor should be passed that produces a `::rocprim::tuple` of references to + /// fundamental types from this custom type. /// /// \par Storage reusage /// Synchronization barrier should be placed before \p storage is reused @@ -327,15 +357,16 @@ class block_radix_sort /// will be equal {[1, 2], [3, 4] ..., [255, 256]} and the \p values will be /// equal {[128, 128], [127, 127] ..., [2, 2], [1, 1]}. /// \endparblock - template - ROCPRIM_DEVICE ROCPRIM_INLINE - void sort(Key (&keys)[ItemsPerThread], - typename std::enable_if::type (&values)[ItemsPerThread], - storage_type& storage, - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key)) + template + ROCPRIM_DEVICE ROCPRIM_INLINE void + sort(Key (&keys)[ItemsPerThread], + typename std::enable_if::type (&values)[ItemsPerThread], + storage_type& storage, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) { - sort_impl(keys, values, storage, begin_bit, end_bit); + sort_impl(keys, values, storage, begin_bit, end_bit, decomposer); } /// \overload @@ -347,6 +378,8 @@ class block_radix_sort /// /// \pre Method is enabled only if \p Value type is different than empty_type. /// + /// \tparam Decomposer The type of the decomposer argument. Defaults to the identity decomposer. + /// /// \param [in, out] keys - reference to an array of keys provided by a thread. /// \param [in, out] values - reference to an array of values provided by a thread. /// \param [in] begin_bit - [optional] index of the first (least significant) bit used in @@ -354,15 +387,19 @@ class block_radix_sort /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default /// value: \p 8 * sizeof(Key). 
- template - ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE - void sort(Key (&keys)[ItemsPerThread], - typename std::enable_if::type (&values)[ItemsPerThread], - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key)) + /// \param [in] decomposer [optional] If `Key` is not an arithmetic type (integral, floating point), + /// a custom decomposer functor should be passed that produces a `::rocprim::tuple` of references to + /// fundamental types from this custom type. + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void + sort(Key (&keys)[ItemsPerThread], + typename std::enable_if::type (&values)[ItemsPerThread], + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) { ROCPRIM_SHARED_MEMORY storage_type storage; - sort(keys, values, storage, begin_bit, end_bit); + sort(keys, values, storage, begin_bit, end_bit, decomposer); } /// \brief Performs descending radix sort over key-value pairs partitioned across @@ -370,6 +407,8 @@ class block_radix_sort /// /// \pre Method is enabled only if \p Value type is different than empty_type. /// + /// \tparam Decomposer The type of the decomposer argument. Defaults to the identity decomposer. + /// /// \param [in, out] keys - reference to an array of keys provided by a thread. /// \param [in, out] values - reference to an array of values provided by a thread. /// \param [in] storage - reference to a temporary storage object of type storage_type. @@ -378,6 +417,9 @@ class block_radix_sort /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default /// value: \p 8 * sizeof(Key). + /// \param [in] decomposer [optional] If `Key` is not an arithmetic type (integral, floating point), + /// a custom decomposer functor should be passed that produces a `::rocprim::tuple` of references to + /// fundamental types from this custom type. 
/// /// \par Storage reusage /// Synchronization barrier should be placed before \p storage is reused @@ -414,15 +456,16 @@ class block_radix_sort /// the \p keys will be equal {[256, 255], ..., [4, 3], [2, 1]} and the \p values /// will be equal {[1, 1], [2, 2] ..., [128, 128]}. /// \endparblock - template - ROCPRIM_DEVICE ROCPRIM_INLINE - void sort_desc(Key (&keys)[ItemsPerThread], - typename std::enable_if::type (&values)[ItemsPerThread], - storage_type& storage, - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key)) + template + ROCPRIM_DEVICE ROCPRIM_INLINE void + sort_desc(Key (&keys)[ItemsPerThread], + typename std::enable_if::type (&values)[ItemsPerThread], + storage_type& storage, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) { - sort_impl(keys, values, storage, begin_bit, end_bit); + sort_impl(keys, values, storage, begin_bit, end_bit, decomposer); } /// \overload @@ -434,6 +477,8 @@ class block_radix_sort /// /// \pre Method is enabled only if \p Value type is different than empty_type. /// + /// \tparam Decomposer The type of the decomposer argument. Defaults to the identity decomposer. + /// /// \param [in, out] keys - reference to an array of keys provided by a thread. /// \param [in, out] values - reference to an array of values provided by a thread. /// \param [in] begin_bit - [optional] index of the first (least significant) bit used in @@ -441,20 +486,26 @@ class block_radix_sort /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default /// value: \p 8 * sizeof(Key). 
- template - ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE - void sort_desc(Key (&keys)[ItemsPerThread], - typename std::enable_if::type (&values)[ItemsPerThread], - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key)) + /// \param [in] decomposer [optional] If `Key` is not an arithmetic type (integral, floating point), + /// a custom decomposer functor should be passed that produces a `::rocprim::tuple` of references to + /// fundamental types from this custom type. + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void + sort_desc(Key (&keys)[ItemsPerThread], + typename std::enable_if::type (&values)[ItemsPerThread], + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) { ROCPRIM_SHARED_MEMORY storage_type storage; - sort_desc(keys, values, storage, begin_bit, end_bit); + sort_desc(keys, values, storage, begin_bit, end_bit, decomposer); } /// \brief Performs ascending radix sort over keys partitioned across threads in a block, /// results are saved in a striped arrangement. /// + /// \tparam Decomposer The type of the decomposer argument. Defaults to the identity decomposer. + /// /// \param [in, out] keys - reference to an array of keys provided by a thread. /// \param [in] storage - reference to a temporary storage object of type storage_type. /// \param [in] begin_bit - [optional] index of the first (least significant) bit used in @@ -462,6 +513,9 @@ class block_radix_sort /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default /// value: \p 8 * sizeof(Key). + /// \param [in] decomposer [optional] If `Key` is not an arithmetic type (integral, floating point), + /// a custom decomposer functor should be passed that produces a `::rocprim::tuple` of references to + /// fundamental types from this custom type. 
/// /// \par Storage reusage /// Synchronization barrier should be placed before \p storage is reused @@ -494,14 +548,15 @@ class block_radix_sort /// If the \p input values across threads in a block are {[256, 255], ..., [4, 3], [2, 1]}}, then /// then after sort they will be equal {[1, 129], [2, 130] ..., [128, 256]}. /// \endparblock - ROCPRIM_DEVICE ROCPRIM_INLINE - void sort_to_striped(Key (&keys)[ItemsPerThread], - storage_type& storage, - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key)) + template + ROCPRIM_DEVICE ROCPRIM_INLINE void sort_to_striped(Key (&keys)[ItemsPerThread], + storage_type& storage, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) { empty_type values[ItemsPerThread]; - sort_impl(keys, values, storage, begin_bit, end_bit); + sort_impl(keys, values, storage, begin_bit, end_bit, decomposer); } /// \overload @@ -511,24 +566,32 @@ class block_radix_sort /// * This overload does not accept storage argument. Required shared memory is /// allocated by the method itself. /// + /// \tparam Decomposer The type of the decomposer argument. Defaults to the identity decomposer. + /// /// \param [in, out] keys - reference to an array of keys provided by a thread. /// \param [in] begin_bit - [optional] index of the first (least significant) bit used in /// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default /// value: \p 8 * sizeof(Key). 
- ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE - void sort_to_striped(Key (&keys)[ItemsPerThread], - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key)) + /// \param [in] decomposer [optional] If `Key` is not an arithmetic type (integral, floating point), + /// a custom decomposer functor should be passed that produces a `::rocprim::tuple` of references to + /// fundamental types from this custom type. + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void sort_to_striped(Key (&keys)[ItemsPerThread], + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) { ROCPRIM_SHARED_MEMORY storage_type storage; - sort_to_striped(keys, storage, begin_bit, end_bit); + sort_to_striped(keys, storage, begin_bit, end_bit, decomposer); } /// \brief Performs descending radix sort over keys partitioned across threads in a block, /// results are saved in a striped arrangement. /// + /// \tparam Decomposer The type of the decomposer argument. Defaults to the identity decomposer. + /// /// \param [in, out] keys - reference to an array of keys provided by a thread. /// \param [in] storage - reference to a temporary storage object of type storage_type. /// \param [in] begin_bit - [optional] index of the first (least significant) bit used in @@ -536,6 +599,9 @@ class block_radix_sort /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default /// value: \p 8 * sizeof(Key). + /// \param [in] decomposer [optional] If `Key` is not an arithmetic type (integral, floating point), + /// a custom decomposer functor should be passed that produces a `::rocprim::tuple` of references to + /// fundamental types from this custom type. 
/// /// \par Storage reusage /// Synchronization barrier should be placed before \p storage is reused @@ -568,14 +634,15 @@ class block_radix_sort /// If the \p input values across threads in a block are {[1, 2], [3, 4] ..., [255, 256]}, /// then after sort they will be equal {[256, 128], ..., [130, 2], [129, 1]}. /// \endparblock - ROCPRIM_DEVICE ROCPRIM_INLINE - void sort_desc_to_striped(Key (&keys)[ItemsPerThread], - storage_type& storage, - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key)) + template + ROCPRIM_DEVICE ROCPRIM_INLINE void sort_desc_to_striped(Key (&keys)[ItemsPerThread], + storage_type& storage, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) { empty_type values[ItemsPerThread]; - sort_impl(keys, values, storage, begin_bit, end_bit); + sort_impl(keys, values, storage, begin_bit, end_bit, decomposer); } /// \overload @@ -585,19 +652,26 @@ class block_radix_sort /// * This overload does not accept storage argument. Required shared memory is /// allocated by the method itself. /// + /// \tparam Decomposer The type of the decomposer argument. Defaults to the identity decomposer. + /// /// \param [in, out] keys - reference to an array of keys provided by a thread. /// \param [in] begin_bit - [optional] index of the first (least significant) bit used in /// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default /// value: \p 8 * sizeof(Key). 
- ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE - void sort_desc_to_striped(Key (&keys)[ItemsPerThread], - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key)) + /// \param [in] decomposer [optional] If `Key` is not an arithmetic type (integral, floating point), + /// a custom decomposer functor should be passed that produces a `::rocprim::tuple` of references to + /// fundamental types from this custom type. + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void sort_desc_to_striped(Key (&keys)[ItemsPerThread], + unsigned int begin_bit = 0, + unsigned int end_bit + = 8 * sizeof(Key), + Decomposer decomposer = {}) { ROCPRIM_SHARED_MEMORY storage_type storage; - sort_desc_to_striped(keys, storage, begin_bit, end_bit); + sort_desc_to_striped(keys, storage, begin_bit, end_bit, decomposer); } /// \brief Performs ascending radix sort over key-value pairs partitioned across @@ -605,6 +679,8 @@ class block_radix_sort /// /// \pre Method is enabled only if \p Value type is different than empty_type. /// + /// \tparam Decomposer The type of the decomposer argument. Defaults to the identity decomposer. + /// /// \param [in, out] keys - reference to an array of keys provided by a thread. /// \param [in, out] values - reference to an array of values provided by a thread. /// \param [in] storage - reference to a temporary storage object of type storage_type. @@ -613,6 +689,9 @@ class block_radix_sort /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default /// value: \p 8 * sizeof(Key). + /// \param [in] decomposer [optional] If `Key` is not an arithmetic type (integral, floating point), + /// a custom decomposer functor should be passed that produces a `::rocprim::tuple` of references to + /// fundamental types from this custom type. 
/// /// \par Storage reusage /// Synchronization barrier should be placed before \p storage is reused @@ -649,15 +728,16 @@ class block_radix_sort /// \p keys will be equal {[1, 5], [2, 6], [3, 7], [4, 8]} and the \p values will be /// equal {[-8, -4], [-7, -3], [-6, -2], [-5, -1]}. /// \endparblock - template - ROCPRIM_DEVICE ROCPRIM_INLINE - void sort_to_striped(Key (&keys)[ItemsPerThread], - typename std::enable_if::type (&values)[ItemsPerThread], - storage_type& storage, - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key)) + template + ROCPRIM_DEVICE ROCPRIM_INLINE void + sort_to_striped(Key (&keys)[ItemsPerThread], + typename std::enable_if::type (&values)[ItemsPerThread], + storage_type& storage, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) { - sort_impl(keys, values, storage, begin_bit, end_bit); + sort_impl(keys, values, storage, begin_bit, end_bit, decomposer); } /// \overload @@ -667,6 +747,8 @@ class block_radix_sort /// * This overload does not accept storage argument. Required shared memory is /// allocated by the method itself. /// + /// \tparam Decomposer The type of the decomposer argument. Defaults to the identity decomposer. + /// /// \param [in, out] keys - reference to an array of keys provided by a thread. /// \param [in, out] values - reference to an array of values provided by a thread. /// \param [in] begin_bit - [optional] index of the first (least significant) bit used in @@ -674,15 +756,19 @@ class block_radix_sort /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default /// value: \p 8 * sizeof(Key). 
- template - ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE - void sort_to_striped(Key (&keys)[ItemsPerThread], - typename std::enable_if::type (&values)[ItemsPerThread], - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key)) + /// \param [in] decomposer [optional] If `Key` is not an arithmetic type (integral, floating point), + /// a custom decomposer functor should be passed that produces a `::rocprim::tuple` of references to + /// fundamental types from this custom type. + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void + sort_to_striped(Key (&keys)[ItemsPerThread], + typename std::enable_if::type (&values)[ItemsPerThread], + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) { ROCPRIM_SHARED_MEMORY storage_type storage; - sort_to_striped(keys, values, storage, begin_bit, end_bit); + sort_to_striped(keys, values, storage, begin_bit, end_bit, decomposer); } /// \brief Performs descending radix sort over key-value pairs partitioned across @@ -690,6 +776,8 @@ class block_radix_sort /// /// \pre Method is enabled only if \p Value type is different than empty_type. /// + /// \tparam Decomposer The type of the decomposer argument. Defaults to the identity decomposer. + /// /// \param [in, out] keys - reference to an array of keys provided by a thread. /// \param [in, out] values - reference to an array of values provided by a thread. /// \param [in] storage - reference to a temporary storage object of type storage_type. @@ -698,6 +786,9 @@ class block_radix_sort /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default /// value: \p 8 * sizeof(Key). + /// \param [in] decomposer [optional] If `Key` is not an arithmetic type (integral, floating point), + /// a custom decomposer functor should be passed that produces a `::rocprim::tuple` of references to + /// fundamental types from this custom type. 
/// /// \par Storage reusage /// Synchronization barrier should be placed before \p storage is reused @@ -734,15 +825,16 @@ class block_radix_sort /// \p keys will be equal {[8, 4], [7, 3], [6, 2], [5, 1]} and the \p values will be /// equal {[10, 50], [20, 60], [30, 70], [40, 80]}. /// \endparblock - template - ROCPRIM_DEVICE ROCPRIM_INLINE - void sort_desc_to_striped(Key (&keys)[ItemsPerThread], - typename std::enable_if::type (&values)[ItemsPerThread], - storage_type& storage, - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key)) + template + ROCPRIM_DEVICE ROCPRIM_INLINE void sort_desc_to_striped( + Key (&keys)[ItemsPerThread], + typename std::enable_if::type (&values)[ItemsPerThread], + storage_type& storage, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) { - sort_impl(keys, values, storage, begin_bit, end_bit); + sort_impl(keys, values, storage, begin_bit, end_bit, decomposer); } /// \overload @@ -752,6 +844,8 @@ class block_radix_sort /// * This overload does not accept storage argument. Required shared memory is /// allocated by the method itself. /// + /// \tparam Decomposer The type of the decomposer argument. Defaults to the identity decomposer. + /// /// \param [in, out] keys - reference to an array of keys provided by a thread. /// \param [in, out] values - reference to an array of values provided by a thread. /// \param [in] begin_bit - [optional] index of the first (least significant) bit used in @@ -759,80 +853,83 @@ class block_radix_sort /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default /// value: \p 8 * sizeof(Key). 
- template - ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE - void sort_desc_to_striped(Key (&keys)[ItemsPerThread], - typename std::enable_if::type (&values)[ItemsPerThread], - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key)) + /// \param [in] decomposer [optional] If `Key` is not an arithmetic type (integral, floating point), + /// a custom decomposer functor should be passed that produces a `::rocprim::tuple` of references to + /// fundamental types from this custom type. + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void sort_desc_to_striped( + Key (&keys)[ItemsPerThread], + typename std::enable_if::type (&values)[ItemsPerThread], + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) { ROCPRIM_SHARED_MEMORY storage_type storage; - sort_desc_to_striped(keys, values, storage, begin_bit, end_bit); + sort_desc_to_striped(keys, values, storage, begin_bit, end_bit, decomposer); } private: - - template - ROCPRIM_DEVICE ROCPRIM_INLINE - void sort_impl(Key (&keys)[ItemsPerThread], - SortedValue (&values)[ItemsPerThread], - storage_type& storage, - unsigned int begin_bit, - unsigned int end_bit) + template + ROCPRIM_DEVICE ROCPRIM_INLINE void sort_impl(Key (&keys)[ItemsPerThread], + SortedValue (&values)[ItemsPerThread], + storage_type& storage, + unsigned int begin_bit, + unsigned int end_bit, + Decomposer decomposer) { - using key_codec = ::rocprim::detail::radix_key_codec; + using key_codec = ::rocprim::radix_key_codec; - bit_key_type bit_keys[ItemsPerThread]; ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; i++) { - bit_keys[i] = key_codec::encode(keys[i]); + key_codec::encode_inplace(keys[i], decomposer); } while(true) { - const int pass_bits = min(radix_bits_per_pass, end_bit - begin_bit); + const int pass_bits = min(RadixBitsPerPass, end_bit - begin_bit); unsigned int ranks[ItemsPerThread]; block_rank_type().rank_keys( - bit_keys, + keys, ranks, storage.get().rank, - [begin_bit, 
pass_bits](const bit_key_type& key) - { return key_codec::extract_digit(key, begin_bit, pass_bits); }); - begin_bit += radix_bits_per_pass; + [begin_bit, pass_bits, decomposer](const Key& key) mutable + { return key_codec::extract_digit(key, begin_bit, pass_bits, decomposer); }); + begin_bit += RadixBitsPerPass; - exchange_keys(storage, bit_keys, ranks); + exchange_keys(storage, keys, ranks); exchange_values(storage, values, ranks); if(begin_bit >= end_bit) + { break; + } - // Synchronization required to make bock_rank wait on the next iteration. + // Synchronization required to make block_rank wait on the next iteration. ::rocprim::syncthreads(); } if ROCPRIM_IF_CONSTEXPR(ToStriped) { - to_striped_keys(storage, bit_keys); + to_striped_keys(storage, keys); to_striped_values(storage, values); } ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; i++) { - keys[i] = key_codec::decode(bit_keys[i]); + key_codec::decode_inplace(keys[i], decomposer); } } - ROCPRIM_DEVICE ROCPRIM_INLINE - void exchange_keys(storage_type& storage, - bit_key_type (&bit_keys)[ItemsPerThread], - const unsigned int (&ranks)[ItemsPerThread]) + ROCPRIM_DEVICE ROCPRIM_INLINE void exchange_keys(storage_type& storage, + Key (&keys)[ItemsPerThread], + const unsigned int (&ranks)[ItemsPerThread]) { storage_type_& storage_ = storage.get(); ::rocprim::syncthreads(); // Storage will be reused (union), synchronization is needed - bit_keys_exchange_type().scatter_to_blocked(bit_keys, bit_keys, ranks, storage_.bit_keys_exchange); + keys_exchange_type().scatter_to_blocked(keys, keys, ranks, storage_.keys_exchange); } template @@ -856,13 +953,12 @@ class block_radix_sort (void) ranks; } - ROCPRIM_DEVICE ROCPRIM_INLINE - void to_striped_keys(storage_type& storage, - bit_key_type (&bit_keys)[ItemsPerThread]) + ROCPRIM_DEVICE ROCPRIM_INLINE void to_striped_keys(storage_type& storage, + Key (&keys)[ItemsPerThread]) { storage_type_& storage_ = storage.get(); ::rocprim::syncthreads(); // Storage will be 
reused (union), synchronization is needed - bit_keys_exchange_type().blocked_to_striped(bit_keys, bit_keys, storage_.bit_keys_exchange); + keys_exchange_type().blocked_to_striped(keys, keys, storage_.keys_exchange); } template @@ -884,18 +980,6 @@ class block_radix_sort } }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -constexpr unsigned int - block_radix_sort:: - radix_bits_per_pass; -#endif - END_ROCPRIM_NAMESPACE /// @} diff --git a/rocprim/include/rocprim/block/detail/block_radix_rank_basic.hpp b/rocprim/include/rocprim/block/detail/block_radix_rank_basic.hpp index c1ac212b3..9cc8c0fce 100644 --- a/rocprim/include/rocprim/block/detail/block_radix_rank_basic.hpp +++ b/rocprim/include/rocprim/block/detail/block_radix_rank_basic.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -24,8 +24,7 @@ #include "../../config.hpp" #include "../../detail/various.hpp" #include "../../functional.hpp" - -#include "../../detail/radix_sort.hpp" +#include "../../thread/radix_key_codec.hpp" #include "../block_scan.hpp" @@ -198,7 +197,7 @@ class block_radix_rank const unsigned int begin_bit, const unsigned int pass_bits) { - using key_codec = ::rocprim::detail::radix_key_codec; + using key_codec = ::rocprim::radix_key_codec; using bit_key_type = typename key_codec::bit_key_type; bit_key_type bit_keys[ItemsPerThread]; diff --git a/rocprim/include/rocprim/block/detail/block_radix_rank_match.hpp b/rocprim/include/rocprim/block/detail/block_radix_rank_match.hpp index abcdb3257..e3eca7808 100644 --- a/rocprim/include/rocprim/block/detail/block_radix_rank_match.hpp +++ b/rocprim/include/rocprim/block/detail/block_radix_rank_match.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. 
All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -26,7 +26,7 @@ #include "../../functional.hpp" #include "../../types.hpp" -#include "../../detail/radix_sort.hpp" +#include "../../thread/radix_key_codec.hpp" #include "../block_scan.hpp" @@ -170,7 +170,7 @@ class block_radix_rank_match const unsigned int begin_bit, const unsigned int pass_bits) { - using key_codec = ::rocprim::detail::radix_key_codec; + using key_codec = ::rocprim::radix_key_codec; using bit_key_type = typename key_codec::bit_key_type; bit_key_type bit_keys[ItemsPerThread]; diff --git a/rocprim/include/rocprim/block/detail/block_sort_bitonic.hpp b/rocprim/include/rocprim/block/detail/block_sort_bitonic.hpp index 503d742f9..be38e50c2 100644 --- a/rocprim/include/rocprim/block/detail/block_sort_bitonic.hpp +++ b/rocprim/include/rocprim/block/detail/block_sort_bitonic.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -26,8 +26,8 @@ #include "../../config.hpp" #include "../../detail/various.hpp" -#include "../../intrinsics.hpp" #include "../../functional.hpp" +#include "../../intrinsics.hpp" #include "../../warp/warp_sort.hpp" @@ -36,14 +36,12 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template< - class Key, - unsigned int BlockSizeX, - unsigned int BlockSizeY, - unsigned int BlockSizeZ, - unsigned int ItemsPerThread, - class Value -> +template class block_sort_bitonic { static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ; @@ -65,11 +63,9 @@ class block_sort_bitonic public: using storage_type = detail::raw_storage>; - template - ROCPRIM_DEVICE ROCPRIM_INLINE - void sort(Key& thread_key, - storage_type& storage, - BinaryFunction compare_function) + template + ROCPRIM_DEVICE ROCPRIM_INLINE void + sort(Key& thread_key, storage_type& storage, BinaryFunction compare_function) { this->sort_impl( ::rocprim::flat_block_thread_id(), @@ -79,10 +75,9 @@ class block_sort_bitonic } template - ROCPRIM_DEVICE ROCPRIM_INLINE - void sort(Key (&thread_keys)[ItemsPerThread], - storage_type& storage, - BinaryFunction compare_function) + ROCPRIM_DEVICE ROCPRIM_INLINE void sort(Key (&thread_keys)[ItemsPerThread], + storage_type& storage, + BinaryFunction compare_function) { this->sort_impl( ::rocprim::flat_block_thread_id(), @@ -92,29 +87,25 @@ class block_sort_bitonic } template - ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE - void sort(Key& thread_key, - BinaryFunction compare_function) + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void sort(Key& thread_key, BinaryFunction compare_function) { ROCPRIM_SHARED_MEMORY storage_type storage; this->sort(thread_key, storage, compare_function); } template - ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE - void sort(Key (&thread_keys)[ItemsPerThread], - BinaryFunction compare_function) + 
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void sort(Key (&thread_keys)[ItemsPerThread], + BinaryFunction compare_function) { ROCPRIM_SHARED_MEMORY storage_type storage; this->sort(thread_keys, storage, compare_function); } template - ROCPRIM_DEVICE ROCPRIM_INLINE - void sort(Key& thread_key, - Value& thread_value, - storage_type& storage, - BinaryFunction compare_function) + ROCPRIM_DEVICE ROCPRIM_INLINE void sort(Key& thread_key, + Value& thread_value, + storage_type& storage, + BinaryFunction compare_function) { this->sort_impl( ::rocprim::flat_block_thread_id(), @@ -139,20 +130,17 @@ class block_sort_bitonic } template - ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE - void sort(Key& thread_key, - Value& thread_value, - BinaryFunction compare_function) + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void + sort(Key& thread_key, Value& thread_value, BinaryFunction compare_function) { ROCPRIM_SHARED_MEMORY storage_type storage; this->sort(thread_key, thread_value, storage, compare_function); } template - ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE - void sort(Key (&thread_keys)[ItemsPerThread], - Value (&thread_values)[ItemsPerThread], - BinaryFunction compare_function) + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void sort(Key (&thread_keys)[ItemsPerThread], + Value (&thread_values)[ItemsPerThread], + BinaryFunction compare_function) { ROCPRIM_SHARED_MEMORY storage_type storage; this->sort(thread_keys, thread_values, storage, compare_function); @@ -215,42 +203,44 @@ class block_sort_bitonic } private: - ROCPRIM_DEVICE ROCPRIM_INLINE - void copy_to_shared(Key& k, const unsigned int flat_tid, storage_type& storage) + ROCPRIM_DEVICE ROCPRIM_INLINE void + copy_to_shared(Key& k, const unsigned int flat_tid, storage_type& storage) { storage_type_& storage_ = storage.get(); - storage_.key[flat_tid] = k; + storage_.key[flat_tid] = k; ::rocprim::syncthreads(); } - ROCPRIM_DEVICE ROCPRIM_INLINE - void copy_to_shared(Key (&k)[ItemsPerThread], const unsigned int flat_tid, storage_type& storage) { + ROCPRIM_DEVICE 
ROCPRIM_INLINE void + copy_to_shared(Key (&k)[ItemsPerThread], const unsigned int flat_tid, storage_type& storage) + { storage_type_& storage_ = storage.get(); ROCPRIM_UNROLL - for(unsigned int item = 0; item < ItemsPerThread; ++item) { + for(unsigned int item = 0; item < ItemsPerThread; ++item) + { storage_.key[item * BlockSize + flat_tid] = k[item]; } ::rocprim::syncthreads(); } - ROCPRIM_DEVICE ROCPRIM_INLINE - void copy_to_shared(Key& k, Value& v, const unsigned int flat_tid, storage_type& storage) + ROCPRIM_DEVICE ROCPRIM_INLINE void + copy_to_shared(Key& k, Value& v, const unsigned int flat_tid, storage_type& storage) { storage_type_& storage_ = storage.get(); - storage_.key[flat_tid] = k; - storage_.value[flat_tid] = v; + storage_.key[flat_tid] = k; + storage_.value[flat_tid] = v; ::rocprim::syncthreads(); } - ROCPRIM_DEVICE ROCPRIM_INLINE - void copy_to_shared(Key (&k)[ItemsPerThread], - Value (&v)[ItemsPerThread], - const unsigned int flat_tid, - storage_type& storage) + ROCPRIM_DEVICE ROCPRIM_INLINE void copy_to_shared(Key (&k)[ItemsPerThread], + Value (&v)[ItemsPerThread], + const unsigned int flat_tid, + storage_type& storage) { storage_type_& storage_ = storage.get(); ROCPRIM_UNROLL - for(unsigned int item = 0; item < ItemsPerThread; ++item) { + for(unsigned int item = 0; item < ItemsPerThread; ++item) + { storage_.key[item * BlockSize + flat_tid] = k[item]; storage_.value[item * BlockSize + flat_tid] = v[item]; } @@ -258,18 +248,18 @@ class block_sort_bitonic } template - ROCPRIM_DEVICE ROCPRIM_INLINE - void swap(Key& key, - const unsigned int flat_tid, - const unsigned int next_id, - const bool dir, - storage_type& storage, - BinaryFunction compare_function) + ROCPRIM_DEVICE ROCPRIM_INLINE void swap(Key& key, + const unsigned int flat_tid, + const unsigned int next_id, + const bool dir, + storage_type& storage, + BinaryFunction compare_function) { storage_type_& storage_ = storage.get(); - Key next_key = storage_.key[next_id]; - bool compare = 
(next_id < flat_tid) ? compare_function(key, next_key) : compare_function(next_key, key); - bool swap = compare ^ dir; + Key next_key = storage_.key[next_id]; + bool compare = (next_id < flat_tid) ? compare_function(key, next_key) + : compare_function(next_key, key); + bool swap = compare ^ dir; if(swap) { key = next_key; @@ -277,20 +267,21 @@ class block_sort_bitonic } template - ROCPRIM_DEVICE ROCPRIM_INLINE - void swap(Key (&key)[ItemsPerThread], - const unsigned int flat_tid, - const unsigned int next_id, - const bool dir, - storage_type& storage, - BinaryFunction compare_function) + ROCPRIM_DEVICE ROCPRIM_INLINE void swap(Key (&key)[ItemsPerThread], + const unsigned int flat_tid, + const unsigned int next_id, + const bool dir, + storage_type& storage, + BinaryFunction compare_function) { storage_type_& storage_ = storage.get(); ROCPRIM_UNROLL - for(unsigned int item = 0; item < ItemsPerThread; ++item) { - Key next_key = storage_.key[item * BlockSize + next_id]; - bool compare = (next_id < flat_tid) ? compare_function(key[item], next_key) : compare_function(next_key, key[item]); - bool swap = compare ^ dir; + for(unsigned int item = 0; item < ItemsPerThread; ++item) + { + Key next_key = storage_.key[item * BlockSize + next_id]; + bool compare = (next_id < flat_tid) ? 
compare_function(key[item], next_key) + : compare_function(next_key, key[item]); + bool swap = compare ^ dir; if(swap) { key[item] = next_key; @@ -299,44 +290,43 @@ class block_sort_bitonic } template - ROCPRIM_DEVICE ROCPRIM_INLINE - void swap(Key& key, - Value& value, - const unsigned int flat_tid, - const unsigned int next_id, - const bool dir, - storage_type& storage, - BinaryFunction compare_function) + ROCPRIM_DEVICE ROCPRIM_INLINE void swap(Key& key, + Value& value, + const unsigned int flat_tid, + const unsigned int next_id, + const bool dir, + storage_type& storage, + BinaryFunction compare_function) { storage_type_& storage_ = storage.get(); - Key next_key = storage_.key[next_id]; - bool b = next_id < flat_tid; + Key next_key = storage_.key[next_id]; + bool b = next_id < flat_tid; bool compare = compare_function(b ? key : next_key, b ? next_key : key); - bool swap = compare ^ dir; + bool swap = compare ^ dir; if(swap) { - key = next_key; + key = next_key; value = storage_.value[next_id]; } } template - ROCPRIM_DEVICE ROCPRIM_INLINE - void swap(Key (&key)[ItemsPerThread], - Value (&value)[ItemsPerThread], - const unsigned int flat_tid, - const unsigned int next_id, - const bool dir, - storage_type& storage, - BinaryFunction compare_function) + ROCPRIM_DEVICE ROCPRIM_INLINE void swap(Key (&key)[ItemsPerThread], + Value (&value)[ItemsPerThread], + const unsigned int flat_tid, + const unsigned int next_id, + const bool dir, + storage_type& storage, + BinaryFunction compare_function) { storage_type_& storage_ = storage.get(); ROCPRIM_UNROLL - for(unsigned int item = 0; item < ItemsPerThread; ++item) { - Key next_key = storage_.key[item * BlockSize + next_id]; - bool b = next_id < flat_tid; - bool compare = compare_function(b ? key[item] : next_key, b ? 
next_key : key[item]); - bool swap = compare ^ dir; + for(unsigned int item = 0; item < ItemsPerThread; ++item) + { + Key next_key = storage_.key[item * BlockSize + next_id]; + bool b = next_id < flat_tid; + bool compare = compare_function(b ? key[item] : next_key, b ? next_key : key[item]); + bool swap = compare ^ dir; if(swap) { key[item] = next_key; @@ -457,100 +447,94 @@ class block_sort_bitonic } } - template< - unsigned int Size, - class BinaryFunction, - class... KeyValue - > + template ROCPRIM_DEVICE ROCPRIM_INLINE - typename std::enable_if<(Size <= ::rocprim::device_warp_size())>::type - sort_power_two(const unsigned int flat_tid, - storage_type& storage, - BinaryFunction compare_function, - KeyValue&... kv) + typename std::enable_if<(Size <= ::rocprim::device_warp_size())>::type + sort_power_two(const unsigned int flat_tid, + storage_type& storage, + BinaryFunction compare_function, + KeyValue&... kv) { - (void) flat_tid; - (void) storage; + (void)flat_tid; + (void)storage; ::rocprim::warp_sort wsort; wsort.sort(kv..., compare_function); } template - ROCPRIM_DEVICE ROCPRIM_INLINE - void warp_swap(Key& k, Value& v, int mask, bool dir, BinaryFunction compare_function) + ROCPRIM_DEVICE ROCPRIM_INLINE void + warp_swap(Key& k, Value& v, int mask, bool dir, BinaryFunction compare_function) { - Key k1 = warp_shuffle_xor(k, mask); + Key k1 = warp_swizzle_shuffle(k, mask); bool swap = compare_function(dir ? k : k1, dir ? 
k1 : k); - if (swap) + if(swap) { k = k1; - v = warp_shuffle_xor(v, mask); + v = warp_swizzle_shuffle(v, mask); } } - template - ROCPRIM_DEVICE ROCPRIM_INLINE - void warp_swap(Key (&k)[ItemsPerThread], - Value (&v)[ItemsPerThread], - int mask, - bool dir, - BinaryFunction compare_function) + template + ROCPRIM_DEVICE ROCPRIM_INLINE void warp_swap(Key (&k)[ItemsPerThread], + Value (&v)[ItemsPerThread], + int mask, + bool dir, + BinaryFunction compare_function) { ROCPRIM_UNROLL - for(unsigned int item = 0; item < ItemsPerThread; ++item) { - Key k1 = warp_shuffle_xor(k[item], mask); + for(unsigned int item = 0; item < ItemsPerThread; ++item) + { + Key k1 = warp_swizzle_shuffle(k[item], mask); bool swap = compare_function(dir ? k[item] : k1, dir ? k1 : k[item]); - if (swap) + if(swap) { k[item] = k1; - v[item] = warp_shuffle_xor(v[item], mask); + v[item] = warp_swizzle_shuffle(v[item], mask); } } } template - ROCPRIM_DEVICE ROCPRIM_INLINE - void warp_swap(Key& k, int mask, bool dir, BinaryFunction compare_function) + ROCPRIM_DEVICE ROCPRIM_INLINE void + warp_swap(Key& k, int mask, bool dir, BinaryFunction compare_function) { - Key k1 = warp_shuffle_xor(k, mask); + Key k1 = warp_swizzle_shuffle(k, mask); bool swap = compare_function(dir ? k : k1, dir ? k1 : k); - if (swap) + if(swap) { k = k1; } } - template - ROCPRIM_DEVICE ROCPRIM_INLINE - void warp_swap(Key (&k)[ItemsPerThread], int mask, bool dir, BinaryFunction compare_function) + template + ROCPRIM_DEVICE ROCPRIM_INLINE void + warp_swap(Key (&k)[ItemsPerThread], int mask, bool dir, BinaryFunction compare_function) { ROCPRIM_UNROLL - for(unsigned int item = 0; item < ItemsPerThread; ++item) { - Key k1 = warp_shuffle_xor(k[item], mask); + for(unsigned int item = 0; item < ItemsPerThread; ++item) + { + Key k1 = warp_swizzle_shuffle(k[item], mask); bool swap = compare_function(dir ? k[item] : k1, dir ? 
k1 : k[item]); - if (swap) + if(swap) { k[item] = k1; } } } - template - ROCPRIM_DEVICE ROCPRIM_INLINE - typename std::enable_if<(Items < 2)>::type - thread_merge(bool /*dir*/, BinaryFunction /*compare_function*/, KeyValue&... /*kv*/) - { - } + template + ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if<(Items < 2)>::type + thread_merge(bool /*dir*/, BinaryFunction /*compare_function*/, KeyValue&... /*kv*/) + {} - template - ROCPRIM_DEVICE ROCPRIM_INLINE - void thread_swap(Key (&k)[ItemsPerThread], - Value (&v)[ItemsPerThread], - bool dir, - unsigned int i, - unsigned int j, - BinaryFunction compare_function) + template + ROCPRIM_DEVICE ROCPRIM_INLINE void thread_swap(Key (&k)[ItemsPerThread], + Value (&v)[ItemsPerThread], + bool dir, + unsigned int i, + unsigned int j, + BinaryFunction compare_function) { if(compare_function(k[i], k[j]) == dir) { @@ -562,13 +546,12 @@ class block_sort_bitonic v[j] = v_temp; } } - template - ROCPRIM_DEVICE ROCPRIM_INLINE - void thread_swap(Key (&k)[ItemsPerThread], - bool dir, - unsigned int i, - unsigned int j, - BinaryFunction compare_function) + template + ROCPRIM_DEVICE ROCPRIM_INLINE void thread_swap(Key (&k)[ItemsPerThread], + bool dir, + unsigned int i, + unsigned int j, + BinaryFunction compare_function) { if(compare_function(k[i], k[j]) == dir) { @@ -578,19 +561,15 @@ class block_sort_bitonic } } - template - ROCPRIM_DEVICE ROCPRIM_INLINE - void thread_shuffle(unsigned int offset, bool dir, BinaryFunction compare_function, KeyValue&... kv) + template + ROCPRIM_DEVICE ROCPRIM_INLINE void thread_shuffle(unsigned int offset, + bool dir, + BinaryFunction compare_function, + KeyValue&... 
kv) { ROCPRIM_UNROLL for(unsigned base = 0; base < ItemsPerThread; base += 2 * offset) { - ROCPRIM_UNROLL -// Workaround to prevent the compiler thinking this is a 'Parallel Loop' on clang 15 -// because it leads to invalid code generation with `T` = `char` and `ItemsPerthread` = 4 -#if defined(__clang_major__) && __clang_major__ >= 15 - #pragma clang loop vectorize(disable) -#endif for(unsigned i = 0; i < offset; ++i) { thread_swap(kv..., dir, base + i, base + i + offset, compare_function); @@ -598,16 +577,15 @@ class block_sort_bitonic } } - template - ROCPRIM_DEVICE ROCPRIM_INLINE - typename std::enable_if::type - thread_merge(bool dir, BinaryFunction compare_function, KeyValue&... kv) + template + ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if::type + thread_merge(bool dir, BinaryFunction compare_function, KeyValue&... kv) { ROCPRIM_UNROLL for(unsigned int k = ItemsPerThread / 2; k > 0; k /= 2) { thread_shuffle(k, dir, compare_function, kv...); - } + } } /// Bitonic sort. @@ -622,14 +600,14 @@ class block_sort_bitonic { const auto warp_id_is_even = ((flat_tid / ::rocprim::device_warp_size()) % 2) == 0; ::rocprim::warp_sort wsort; - auto compare_function2 = - [compare_function, warp_id_is_even](const Key& a, const Key& b) mutable -> bool - { - auto r = compare_function(a, b); - if(warp_id_is_even) - return r; - return !r; - }; + auto compare_function2 + = [compare_function, warp_id_is_even](const Key& a, const Key& b) mutable -> bool + { + auto r = compare_function(a, b); + if(warp_id_is_even) + return r; + return !r; + }; wsort.sort(kv..., compare_function2); ROCPRIM_UNROLL @@ -644,11 +622,14 @@ class block_sort_bitonic ::rocprim::syncthreads(); } + const unsigned int id = detail::logical_lane_id<::rocprim::device_warp_size()>(); + constexpr unsigned int s = ::rocprim::device_warp_size() / 2; + ROCPRIM_UNROLL - for(unsigned int k = ::rocprim::device_warp_size() / 2; k > 0; k /= 2) + for(unsigned int k = s; k > 0; k /= 2) { - const bool length_even = 
((detail::logical_lane_id<::rocprim::device_warp_size()>() / k ) % 2 ) == 0; - const bool local_dir = length_even ? dir : !dir; + const bool length_even = ((id / k) % 2) == 0; + const bool local_dir = length_even ? dir : !dir; warp_swap(kv..., k, local_dir, compare_function); } thread_merge(dir, compare_function, kv...); diff --git a/rocprim/include/rocprim/detail/match_result_type.hpp b/rocprim/include/rocprim/detail/match_result_type.hpp new file mode 100644 index 000000000..40b928982 --- /dev/null +++ b/rocprim/include/rocprim/detail/match_result_type.hpp @@ -0,0 +1,50 @@ +// Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DETAIL_MATCH_RESULT_TYPE_HPP_ +#define ROCPRIM_DETAIL_MATCH_RESULT_TYPE_HPP_ + +#include "../config.hpp" + +#include "../type_traits.hpp" + +ROCPRIM_PRAGMA_MESSAGE("Internal 'match_result_type.hpp'-header has been deprecated. 
Please " + "include 'rocprim/type_traits.hpp' instead!"); + +BEGIN_ROCPRIM_NAMESPACE +namespace detail +{ + +template +using invoke_result [[deprecated("Use 'rocprim::invoke_result' instead!")]] += rocprim::invoke_result; + +template +using match_result [[deprecated("Use 'rocprim::invoke_result_binary_op' instead!")]] += rocprim::invoke_result_binary_op; + +template +using match_result_type [[deprecated("Use 'rocprim::invoke_result_binary_op_t' instead!")]] += rocprim::invoke_result_binary_op_t; + +} // end namespace detail +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DETAIL_MATCH_RESULT_TYPE_HPP_ diff --git a/rocprim/include/rocprim/detail/radix_sort.hpp b/rocprim/include/rocprim/detail/radix_sort.hpp index 32ff17e87..2be7221b8 100644 --- a/rocprim/include/rocprim/detail/radix_sort.hpp +++ b/rocprim/include/rocprim/detail/radix_sort.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -21,321 +21,8 @@ #ifndef ROCPRIM_DETAIL_RADIX_SORT_HPP_ #define ROCPRIM_DETAIL_RADIX_SORT_HPP_ -#include - -#include "../config.hpp" -#include "../type_traits.hpp" - -BEGIN_ROCPRIM_NAMESPACE -namespace detail -{ - -// Encode and decode integral and floating point values for radix sort in such a way that preserves -// correct order of negative and positive keys (i.e. negative keys go before positive ones, -// which is not true for a simple reinterpetation of the key's bits). - -// Digit extractor takes into account that (+0.0 == -0.0) is true for floats, -// so both +0.0 and -0.0 are reflected into the same bit pattern for digit extraction. -// Maximum digit length is 32. 
- -template -struct radix_key_codec_integral { }; - -template -struct radix_key_codec_integral::value>::type> -{ - using bit_key_type = BitKey; - - ROCPRIM_DEVICE ROCPRIM_INLINE - static bit_key_type encode(Key key) - { - return __builtin_bit_cast(bit_key_type, key); - } - - ROCPRIM_DEVICE ROCPRIM_INLINE - static Key decode(bit_key_type bit_key) - { - return __builtin_bit_cast(Key, bit_key); - } - - template - ROCPRIM_DEVICE static unsigned int - extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) - { - unsigned int mask = (1u << length) - 1; - return static_cast(bit_key >> start) & mask; - } -}; - -template -struct radix_key_codec_integral< - Key, - BitKey, - typename std::enable_if::value>::type> -{ - using bit_key_type = BitKey; - - ROCPRIM_DEVICE ROCPRIM_INLINE static bit_key_type encode(Key key) - { - return __builtin_bit_cast(bit_key_type, key); - } - - ROCPRIM_DEVICE ROCPRIM_INLINE static Key decode(bit_key_type bit_key) - { - return __builtin_bit_cast(Key, bit_key); - } - - template - ROCPRIM_DEVICE static unsigned int - extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) - { - unsigned int mask = (1u << length) - 1; - return static_cast(bit_key >> start) & mask; - } -}; - -template -struct radix_key_codec_integral::value>::type> -{ - using bit_key_type = BitKey; - - static constexpr bit_key_type sign_bit = bit_key_type(1) << (sizeof(bit_key_type) * 8 - 1); - - ROCPRIM_DEVICE ROCPRIM_INLINE - static bit_key_type encode(Key key) - { - const bit_key_type bit_key = __builtin_bit_cast(bit_key_type, key); - return sign_bit ^ bit_key; - } - - ROCPRIM_DEVICE ROCPRIM_INLINE - static Key decode(bit_key_type bit_key) - { - bit_key ^= sign_bit; - return __builtin_bit_cast(Key, bit_key); - } - - template - ROCPRIM_DEVICE static unsigned int - extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) - { - unsigned int mask = (1u << length) - 1; - return static_cast(bit_key >> start) & mask; - } -}; - 
-template -struct radix_key_codec_integral::value>::type> -{ - using bit_key_type = BitKey; - - static constexpr bit_key_type sign_bit = bit_key_type(1) << (sizeof(bit_key_type) * 8 - 1); - - ROCPRIM_DEVICE ROCPRIM_INLINE static bit_key_type encode(Key key) - { - const bit_key_type bit_key = __builtin_bit_cast(bit_key_type, key); - return sign_bit ^ bit_key; - } - - ROCPRIM_DEVICE ROCPRIM_INLINE static Key decode(bit_key_type bit_key) - { - bit_key ^= sign_bit; - return __builtin_bit_cast(Key, bit_key); - } - - template - ROCPRIM_DEVICE static unsigned int - extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) - { - unsigned int mask = (1u << length) - 1; - return static_cast(bit_key >> start) & mask; - } -}; - -template -struct float_bit_mask; - -template<> -struct float_bit_mask -{ - static constexpr uint32_t sign_bit = 0x80000000; - static constexpr uint32_t exponent = 0x7F800000; - static constexpr uint32_t mantissa = 0x007FFFFF; - using bit_type = uint32_t; -}; - -template<> -struct float_bit_mask -{ - static constexpr uint64_t sign_bit = 0x8000000000000000; - static constexpr uint64_t exponent = 0x7FF0000000000000; - static constexpr uint64_t mantissa = 0x000FFFFFFFFFFFFF; - using bit_type = uint64_t; -}; - -template<> -struct float_bit_mask -{ - static constexpr uint16_t sign_bit = 0x8000; - static constexpr uint16_t exponent = 0x7F80; - static constexpr uint16_t mantissa = 0x007F; - using bit_type = uint16_t; -}; - -template<> -struct float_bit_mask -{ - static constexpr uint16_t sign_bit = 0x8000; - static constexpr uint16_t exponent = 0x7C00; - static constexpr uint16_t mantissa = 0x03FF; - using bit_type = uint16_t; -}; - -template -struct radix_key_codec_floating -{ - using bit_key_type = BitKey; - - static constexpr bit_key_type sign_bit = float_bit_mask::sign_bit; - - ROCPRIM_DEVICE ROCPRIM_INLINE - static bit_key_type encode(Key key) - { - bit_key_type bit_key = __builtin_bit_cast(bit_key_type, key); - bit_key ^= (sign_bit & 
bit_key) == 0 ? sign_bit : bit_key_type(-1); - return bit_key; - } - - ROCPRIM_DEVICE ROCPRIM_INLINE - static Key decode(bit_key_type bit_key) - { - bit_key ^= (sign_bit & bit_key) == 0 ? bit_key_type(-1) : sign_bit; - return __builtin_bit_cast(Key, bit_key); - } - - template - ROCPRIM_DEVICE static unsigned int - extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) - { - unsigned int mask = (1u << length) - 1; - - // radix_key_codec_floating::encode() maps 0.0 to 0x8000'0000, - // and -0.0 to 0x7FFF'FFFF. - // radix_key_codec::encode() then flips the bits if descending, yielding: - // value | descending | ascending | - // ----- | ----------- | ----------- | - // 0.0 | 0x7FFF'FFFF | 0x8000'0000 | - // -0.0 | 0x8000'0000 | 0x7FFF'FFFF | - // - // For ascending sort, both should be mapped to 0x8000'0000, - // and for descending sort, both should be mapped to 0x7FFF'FFFF. - if ROCPRIM_IF_CONSTEXPR(Descending) - { - bit_key = bit_key == sign_bit ? static_cast(~sign_bit) : bit_key; - } - else - { - bit_key = bit_key == static_cast(~sign_bit) ? 
sign_bit : bit_key; - } - return static_cast(bit_key >> start) & mask; - } -}; - -template -struct radix_key_codec_base -{ - static_assert(sizeof(Key) == 0, - "Only integral and floating point types supported as radix sort keys"); -}; - -template -struct radix_key_codec_base< - Key, - typename std::enable_if<::rocprim::is_integral::value>::type -> : radix_key_codec_integral::type> { }; - -template -struct radix_key_codec_base::value>::type> - : radix_key_codec_integral -{}; - -template -struct radix_key_codec_base::value>::type> - : radix_key_codec_integral -{}; - -template<> -struct radix_key_codec_base -{ - using bit_key_type = unsigned char; - - ROCPRIM_DEVICE ROCPRIM_INLINE - static bit_key_type encode(bool key) - { - return static_cast(key); - } - - ROCPRIM_DEVICE ROCPRIM_INLINE - static bool decode(bit_key_type bit_key) - { - return static_cast(bit_key); - } - - template - ROCPRIM_DEVICE static unsigned int - extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) - { - unsigned int mask = (1u << length) - 1; - return static_cast(bit_key >> start) & mask; - } -}; - -template<> -struct radix_key_codec_base<::rocprim::half> : radix_key_codec_floating<::rocprim::half, unsigned short> { }; - -template<> -struct radix_key_codec_base<::rocprim::bfloat16> : radix_key_codec_floating<::rocprim::bfloat16, unsigned short> { }; - -template<> -struct radix_key_codec_base : radix_key_codec_floating { }; - -template<> -struct radix_key_codec_base : radix_key_codec_floating { }; - -template -class radix_key_codec : protected radix_key_codec_base -{ - using base_type = radix_key_codec_base; - -public: - using bit_key_type = typename base_type::bit_key_type; - - ROCPRIM_DEVICE ROCPRIM_INLINE - static bit_key_type encode(Key key) - { - bit_key_type bit_key = base_type::encode(key); - return (Descending ? ~bit_key : bit_key); - } - - ROCPRIM_DEVICE ROCPRIM_INLINE - static Key decode(bit_key_type bit_key) - { - bit_key = (Descending ? 
~bit_key : bit_key); - return base_type::decode(bit_key); - } - - ROCPRIM_DEVICE ROCPRIM_INLINE - static unsigned int extract_digit(bit_key_type bit_key, unsigned int start, unsigned int radix_bits) - { - return base_type::template extract_digit(bit_key, start, radix_bits); - } -}; - -} // end namespace detail -END_ROCPRIM_NAMESPACE +ROCPRIM_PRAGMA_MESSAGE("Functionality from rocprim/detail/radix_sort.hpp has been moved to " + "rocprim/thread/radix_key_codec.hpp.") +#include "../thread/radix_key_codec.hpp" #endif // ROCPRIM_DETAIL_RADIX_SORT_HPP_ diff --git a/rocprim/include/rocprim/detail/various.hpp b/rocprim/include/rocprim/detail/various.hpp index 1def58e3d..70b9e7af0 100644 --- a/rocprim/include/rocprim/detail/various.hpp +++ b/rocprim/include/rocprim/detail/various.hpp @@ -394,6 +394,25 @@ ROCPRIM_HOST_DEVICE ROCPRIM_INLINE DstPtr cast_align_down(Src* pointer) #endif } +template +ROCPRIM_HOST_DEVICE auto bit_cast(const Source& source) + -> std::enable_if_t::value + && std::is_trivially_copyable::value, + Destination> +{ +#if defined(__has_builtin) && __has_builtin(__builtin_bit_cast) + return __builtin_bit_cast(Destination, source); +#else + static_assert( + std::is_trivially_constructible::value, + "Fallback implementation of bit_cast requires Destination to be trivially constructible"); + Destination dest; + memcpy(&dest, &source, sizeof(Destination)); + return dest; +#endif +} + } // end namespace detail END_ROCPRIM_NAMESPACE diff --git a/rocprim/include/rocprim/device/config_types.hpp b/rocprim/include/rocprim/device/config_types.hpp index 0b8c75cd8..32bdb5cd6 100644 --- a/rocprim/include/rocprim/device/config_types.hpp +++ b/rocprim/include/rocprim/device/config_types.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2018-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal diff --git a/rocprim/include/rocprim/device/detail/config/device_adjacent_difference.hpp b/rocprim/include/rocprim/device/detail/config/device_adjacent_difference.hpp index 139ce7e8b..e1e2cc2c9 100644 --- a/rocprim/include/rocprim/device/detail/config/device_adjacent_difference.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_adjacent_difference.hpp @@ -43,74 +43,6 @@ struct default_adjacent_difference_config : default_adjacent_difference_config_base::type {}; -// Based on value_type = double -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<512, 4> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 2> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<128, 4> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<512, 4> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1030), - value_type, - 
std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 2> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<256, 4> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<32, 8> -{}; - // Based on value_type = double template struct default_adjacent_difference_config< @@ -118,7 +50,7 @@ struct default_adjacent_difference_config< value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 8> + : adjacent_difference_config<1024, 1> {}; // Based on value_type = float @@ -137,7 +69,7 @@ struct default_adjacent_difference_config< static_cast(target_arch::gfx1102), value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<32, 4> + && (sizeof(value_type) <= 2))>> : adjacent_difference_config<512, 11> {}; // Based on value_type = int64_t @@ -147,7 +79,7 @@ struct default_adjacent_difference_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<1024, 2> + : adjacent_difference_config<1024, 1> {}; // Based on value_type = int @@ -167,7 +99,7 @@ struct default_adjacent_difference_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : 
adjacent_difference_config<32, 4> + : adjacent_difference_config<256, 11> {}; // Based on value_type = int8_t @@ -176,75 +108,75 @@ struct default_adjacent_difference_config< static_cast(target_arch::gfx1102), value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<256, 32> + && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 19> {}; // Based on value_type = double template struct default_adjacent_difference_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx908), value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<512, 2> + : adjacent_difference_config<128, 1> {}; // Based on value_type = float template struct default_adjacent_difference_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx908), value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<512, 2> + : adjacent_difference_config<128, 3> {}; // Based on value_type = rocprim::half template struct default_adjacent_difference_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx908), value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<128, 64> + && (sizeof(value_type) <= 2))>> : adjacent_difference_config<64, 7> {}; // Based on value_type = int64_t template struct default_adjacent_difference_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx908), value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<512, 2> + : adjacent_difference_config<128, 1> {}; // Based on value_type = int template struct default_adjacent_difference_config< - 
static_cast(target_arch::gfx900), + static_cast(target_arch::gfx908), value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<512, 2> + : adjacent_difference_config<128, 2> {}; // Based on value_type = short template struct default_adjacent_difference_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx908), value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<128, 64> + : adjacent_difference_config<64, 7> {}; // Based on value_type = int8_t template struct default_adjacent_difference_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx908), value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 16> + && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 19> {}; // Based on value_type = double @@ -254,7 +186,7 @@ struct default_adjacent_difference_config< value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 2> + : adjacent_difference_config<1024, 2> {}; // Based on value_type = float @@ -264,7 +196,7 @@ struct default_adjacent_difference_config< value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<512, 4> + : adjacent_difference_config<1024, 5> {}; // Based on value_type = rocprim::half @@ -273,7 +205,7 @@ struct default_adjacent_difference_config< static_cast(target_arch::gfx906), value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<64, 16> + && (sizeof(value_type) <= 2))>> : adjacent_difference_config<64, 17> {}; // Based 
on value_type = int64_t @@ -293,7 +225,7 @@ struct default_adjacent_difference_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<512, 4> + : adjacent_difference_config<1024, 5> {}; // Based on value_type = short @@ -303,7 +235,7 @@ struct default_adjacent_difference_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<64, 16> + : adjacent_difference_config<64, 7> {}; // Based on value_type = int8_t @@ -312,75 +244,75 @@ struct default_adjacent_difference_config< static_cast(target_arch::gfx906), value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 16> + && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 19> {}; // Based on value_type = double template struct default_adjacent_difference_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx1030), value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 2> + : adjacent_difference_config<1024, 1> {}; // Based on value_type = float template struct default_adjacent_difference_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx1030), value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<512, 2> + : adjacent_difference_config<1024, 1> {}; // Based on value_type = rocprim::half template struct default_adjacent_difference_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx1030), value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<64, 8> + 
&& (sizeof(value_type) <= 2))>> : adjacent_difference_config<1024, 5> {}; // Based on value_type = int64_t template struct default_adjacent_difference_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx1030), value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 2> + : adjacent_difference_config<1024, 1> {}; // Based on value_type = int template struct default_adjacent_difference_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx1030), value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<128, 2> + : adjacent_difference_config<1024, 1> {}; // Based on value_type = short template struct default_adjacent_difference_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx1030), value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<128, 8> + : adjacent_difference_config<1024, 5> {}; // Based on value_type = int8_t template struct default_adjacent_difference_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx1030), value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 16> + && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 7> {}; // Based on value_type = double @@ -390,7 +322,7 @@ struct default_adjacent_difference_config< value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 2> + : adjacent_difference_config<128, 1> {}; // Based on value_type = float @@ -400,7 +332,7 @@ struct default_adjacent_difference_config< value_type, 
std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<512, 2> + : adjacent_difference_config<128, 3> {}; // Based on value_type = rocprim::half @@ -409,7 +341,7 @@ struct default_adjacent_difference_config< static_cast(target_arch::unknown), value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<64, 8> + && (sizeof(value_type) <= 2))>> : adjacent_difference_config<64, 7> {}; // Based on value_type = int64_t @@ -419,7 +351,7 @@ struct default_adjacent_difference_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 2> + : adjacent_difference_config<128, 1> {}; // Based on value_type = int @@ -439,7 +371,7 @@ struct default_adjacent_difference_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<128, 8> + : adjacent_difference_config<64, 7> {}; // Based on value_type = int8_t @@ -448,7 +380,7 @@ struct default_adjacent_difference_config< static_cast(target_arch::unknown), value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 16> + && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 19> {}; // Based on value_type = double @@ -458,7 +390,7 @@ struct default_adjacent_difference_config< value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 2> + : adjacent_difference_config<128, 1> {}; // Based on value_type = float @@ -468,7 +400,7 @@ struct default_adjacent_difference_config< value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && 
(sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<512, 2> + : adjacent_difference_config<128, 3> {}; // Based on value_type = rocprim::half @@ -477,7 +409,7 @@ struct default_adjacent_difference_config< static_cast(target_arch::gfx90a), value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<64, 8> + && (sizeof(value_type) <= 2))>> : adjacent_difference_config<64, 7> {}; // Based on value_type = int64_t @@ -487,7 +419,7 @@ struct default_adjacent_difference_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 2> + : adjacent_difference_config<128, 1> {}; // Based on value_type = int @@ -507,7 +439,7 @@ struct default_adjacent_difference_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<128, 8> + : adjacent_difference_config<64, 7> {}; // Based on value_type = int8_t @@ -516,7 +448,7 @@ struct default_adjacent_difference_config< static_cast(target_arch::gfx90a), value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 16> + && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 19> {}; } // end namespace detail diff --git a/rocprim/include/rocprim/device/detail/config/device_adjacent_difference_inplace.hpp b/rocprim/include/rocprim/device/detail/config/device_adjacent_difference_inplace.hpp index ab0df2f81..b4f4473b0 100644 --- a/rocprim/include/rocprim/device/detail/config/device_adjacent_difference_inplace.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_adjacent_difference_inplace.hpp @@ -50,7 +50,7 @@ struct default_adjacent_difference_inplace_config< value_type, 
std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 16> + : adjacent_difference_config<32, 17> {}; // Based on value_type = float @@ -60,7 +60,7 @@ struct default_adjacent_difference_inplace_config< value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<256, 16> + : adjacent_difference_config<64, 17> {}; // Based on value_type = rocprim::half @@ -69,7 +69,7 @@ struct default_adjacent_difference_inplace_config< static_cast(target_arch::gfx1102), value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<512, 16> + && (sizeof(value_type) <= 2))>> : adjacent_difference_config<128, 17> {}; // Based on value_type = int64_t @@ -79,7 +79,7 @@ struct default_adjacent_difference_inplace_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 16> + : adjacent_difference_config<32, 17> {}; // Based on value_type = int @@ -89,7 +89,7 @@ struct default_adjacent_difference_inplace_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<256, 16> + : adjacent_difference_config<64, 17> {}; // Based on value_type = short @@ -99,7 +99,7 @@ struct default_adjacent_difference_inplace_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<256, 32> + : adjacent_difference_config<128, 17> {}; // Based on value_type = int8_t @@ -108,143 +108,75 @@ struct default_adjacent_difference_inplace_config< static_cast(target_arch::gfx1102), value_type, 
std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<512, 32> + && (sizeof(value_type) <= 1))>> : adjacent_difference_config<256, 17> {}; // Based on value_type = double template struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<512, 4> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 4> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<1024, 8> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<512, 4> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 4> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && 
(sizeof(value_type) > 1))>> - : adjacent_difference_config<1024, 8> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<32, 64> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx908), value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<256, 16> + : adjacent_difference_config<512, 2> {}; // Based on value_type = float template struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx908), value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<128, 64> + : adjacent_difference_config<1024, 5> {}; // Based on value_type = rocprim::half template struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx908), value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<256, 64> + && (sizeof(value_type) <= 2))>> : adjacent_difference_config<32, 23> {}; // Based on value_type = int64_t template struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx908), value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<256, 16> + : adjacent_difference_config<512, 2> {}; // Based on value_type = int template struct default_adjacent_difference_inplace_config< - 
static_cast(target_arch::gfx900), + static_cast(target_arch::gfx908), value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<128, 64> + : adjacent_difference_config<1024, 5> {}; // Based on value_type = short template struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx908), value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<256, 64> + : adjacent_difference_config<32, 23> {}; // Based on value_type = int8_t template struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx908), value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<512, 16> + && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 19> {}; // Based on value_type = double @@ -254,7 +186,7 @@ struct default_adjacent_difference_inplace_config< value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<1024, 4> + : adjacent_difference_config<128, 11> {}; // Based on value_type = float @@ -264,7 +196,7 @@ struct default_adjacent_difference_inplace_config< value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 8> + : adjacent_difference_config<256, 11> {}; // Based on value_type = rocprim::half @@ -273,7 +205,7 @@ struct default_adjacent_difference_inplace_config< static_cast(target_arch::gfx906), value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<256, 16> + && (sizeof(value_type) <= 2))>> : 
adjacent_difference_config<64, 17> {}; // Based on value_type = int64_t @@ -283,7 +215,7 @@ struct default_adjacent_difference_inplace_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<1024, 4> + : adjacent_difference_config<128, 11> {}; // Based on value_type = int @@ -293,7 +225,7 @@ struct default_adjacent_difference_inplace_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<512, 16> + : adjacent_difference_config<256, 11> {}; // Based on value_type = short @@ -303,7 +235,7 @@ struct default_adjacent_difference_inplace_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<256, 16> + : adjacent_difference_config<64, 17> {}; // Based on value_type = int8_t @@ -312,75 +244,75 @@ struct default_adjacent_difference_inplace_config< static_cast(target_arch::gfx906), value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 16> + && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 19> {}; // Based on value_type = double template struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx1030), value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<512, 2> + : adjacent_difference_config<32, 17> {}; // Based on value_type = float template struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx1030), value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) && 
(sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 4> + : adjacent_difference_config<32, 17> {}; // Based on value_type = rocprim::half template struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx1030), value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<512, 8> + && (sizeof(value_type) <= 2))>> : adjacent_difference_config<512, 11> {}; // Based on value_type = int64_t template struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx1030), value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<512, 2> + : adjacent_difference_config<32, 17> {}; // Based on value_type = int template struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx1030), value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 4> + : adjacent_difference_config<32, 19> {}; // Based on value_type = short template struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx1030), value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<64, 32> + : adjacent_difference_config<512, 11> {}; // Based on value_type = int8_t template struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx1030), value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 16> + && (sizeof(value_type) <= 1))>> : 
adjacent_difference_config<32, 23> {}; // Based on value_type = double @@ -400,7 +332,7 @@ struct default_adjacent_difference_inplace_config< value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 4> + : adjacent_difference_config<1024, 5> {}; // Based on value_type = rocprim::half @@ -409,7 +341,7 @@ struct default_adjacent_difference_inplace_config< static_cast(target_arch::unknown), value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<512, 8> + && (sizeof(value_type) <= 2))>> : adjacent_difference_config<32, 23> {}; // Based on value_type = int64_t @@ -429,7 +361,7 @@ struct default_adjacent_difference_inplace_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 4> + : adjacent_difference_config<1024, 5> {}; // Based on value_type = short @@ -439,7 +371,7 @@ struct default_adjacent_difference_inplace_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<64, 32> + : adjacent_difference_config<32, 23> {}; // Based on value_type = int8_t @@ -448,7 +380,7 @@ struct default_adjacent_difference_inplace_config< static_cast(target_arch::unknown), value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 16> + && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 19> {}; // Based on value_type = double @@ -468,7 +400,7 @@ struct default_adjacent_difference_inplace_config< value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 4> + : 
adjacent_difference_config<1024, 5> {}; // Based on value_type = rocprim::half @@ -477,7 +409,7 @@ struct default_adjacent_difference_inplace_config< static_cast(target_arch::gfx90a), value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<512, 8> + && (sizeof(value_type) <= 2))>> : adjacent_difference_config<32, 23> {}; // Based on value_type = int64_t @@ -497,7 +429,7 @@ struct default_adjacent_difference_inplace_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 4> + : adjacent_difference_config<1024, 5> {}; // Based on value_type = short @@ -507,7 +439,7 @@ struct default_adjacent_difference_inplace_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<64, 32> + : adjacent_difference_config<32, 23> {}; // Based on value_type = int8_t @@ -516,7 +448,7 @@ struct default_adjacent_difference_inplace_config< static_cast(target_arch::gfx90a), value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 16> + && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 19> {}; } // end namespace detail diff --git a/rocprim/include/rocprim/device/detail/device_adjacent_difference.hpp b/rocprim/include/rocprim/device/detail/device_adjacent_difference.hpp index 4a7592b5d..de69dfe21 100644 --- a/rocprim/include/rocprim/device/detail/device_adjacent_difference.hpp +++ b/rocprim/include/rocprim/device/detail/device_adjacent_difference.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -28,6 +28,7 @@ #include "../../detail/various.hpp" #include "../../config.hpp" +#include "../../type_traits.hpp" #include "device_config_helper.hpp" #include @@ -179,7 +180,7 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void adjacent_difference_kernel_impl( const std::size_t starting_block) { using input_type = typename std::iterator_traits::value_type; - using output_type = typename std::iterator_traits::value_type; + using output_type = rocprim::invoke_result_binary_op_t; static constexpr adjacent_difference_config_params params = device_params(); @@ -195,11 +196,7 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void adjacent_difference_kernel_impl( using adjacent_helper = adjacent_diff_helper; -#if defined(__gfx1102__) or defined(__gfx1030__) - ROCPRIM_SHARED_MEMORY struct -#else ROCPRIM_SHARED_MEMORY union -#endif { typename block_load_type::storage_type load; typename adjacent_helper::storage_type adjacent_diff; diff --git a/rocprim/include/rocprim/device/detail/device_batch_memcpy.hpp b/rocprim/include/rocprim/device/detail/device_batch_memcpy.hpp index 9c73c9c12..f4aac0023 100644 --- a/rocprim/include/rocprim/device/detail/device_batch_memcpy.hpp +++ b/rocprim/include/rocprim/device/detail/device_batch_memcpy.hpp @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. - * Modifications Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * Modifications Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: @@ -32,6 +32,7 @@ #include "rocprim/device/config_types.hpp" #include "rocprim/device/detail/device_scan_common.hpp" #include "rocprim/device/detail/lookback_scan_state.hpp" +#include "rocprim/device/device_memcpy_config.hpp" #include "rocprim/device/device_scan.hpp" #include "rocprim/block/block_exchange.hpp" @@ -104,22 +105,51 @@ struct counter } }; -template -ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE static uint8_t read_byte(void* buffer_src, Offset offset) +template::type = 0> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE static Alias read_item(InputIt buffer_src, Offset offset) { return rocprim::thread_load( - reinterpret_cast(buffer_src) + offset); + reinterpret_cast(buffer_src) + offset); } -template +template::type = 0> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE static Alias read_item(InputIt buffer_src, Offset offset) +{ + return rocprim::thread_load(buffer_src + offset); +} + +template::type = 0> ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE static void - write_byte(void* buffer_dst, Offset offset, uint8_t value) + write_item(InputIt buffer_dst, Offset offset, Alias value) { rocprim::thread_store( - reinterpret_cast(buffer_dst) + offset, + reinterpret_cast(buffer_dst) + offset, value); } +template::type = 0> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE static void + write_item(InputIt buffer_dst, Offset offset, Alias value) +{ + rocprim::thread_store(buffer_dst + offset, value); +} + template struct aligned_ranges { @@ -267,9 +297,42 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE static void vectorized_copy_bytes(const void in_ptr += warp_size; } } + +template::type = 0> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE static void + copy_items(InputIt input_buffer, OutputIt output_buffer, Offset num_items, Offset offset = 0) +{ + vectorized_copy_bytes(input_buffer, output_buffer, num_items, offset); +} + +template::type = 0> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE 
static void + copy_items(InputIt input_buffer, OutputIt output_buffer, Offset num_items, Offset offset = 0) +{ + constexpr auto warp_size = rocprim::device_warp_size(); + output_buffer += offset; + input_buffer += offset; + for(Offset i = threadIdx.x % warp_size; i < num_items; i += warp_size) + { + *(output_buffer + i) = *(input_buffer + i); + } +} + } // namespace batch_memcpy -template +template struct batch_memcpy_impl { using input_buffer_type = typename std::iterator_traits::value_type; @@ -278,6 +341,12 @@ struct batch_memcpy_impl using input_type = typename std::iterator_traits::value_type; + using Alias = + typename std::conditional::value_type>::value_type>::type; + // top level policy static constexpr uint32_t block_size = Config::non_blev_block_size; static constexpr uint32_t buffers_per_thread = Config::non_blev_buffers_per_thread; @@ -549,9 +618,9 @@ struct batch_memcpy_impl { const auto buffer_id = buffers_by_size_class[buffer_offset].buffer_id; - batch_memcpy::vectorized_copy_bytes(tile_buffers.srcs[buffer_id], - tile_buffers.dsts[buffer_id], - tile_buffers.sizes[buffer_id]); + batch_memcpy::copy_items(tile_buffers.srcs[buffer_id], + tile_buffers.dsts[buffer_id], + tile_buffers.sizes[buffer_id]); } } @@ -638,12 +707,12 @@ struct batch_memcpy_impl if(is_full_window) { - uint8_t src_byte[tlev_bytes_per_thread]; + Alias src_byte[tlev_bytes_per_thread]; ROCPRIM_UNROLL for(uint32_t i = 0; i < tlev_bytes_per_thread; ++i) { - src_byte[i] = batch_memcpy::read_byte( + src_byte[i] = batch_memcpy::read_item( tile_buffers.srcs[zipped_byte_assignment[i].tile_buffer_id], zipped_byte_assignment[i].buffer_byte_offset); } @@ -651,7 +720,7 @@ struct batch_memcpy_impl ROCPRIM_UNROLL for(uint32_t i = 0; i < tlev_bytes_per_thread; ++i) { - batch_memcpy::write_byte( + batch_memcpy::write_item( tile_buffers.dsts[zipped_byte_assignment[i].tile_buffer_id], zipped_byte_assignment[i].buffer_byte_offset, src_byte[i]); @@ -669,12 +738,12 @@ struct batch_memcpy_impl const auto 
buffer_id = zipped_byte_assignment[i].tile_buffer_id; const auto buffer_offset = zipped_byte_assignment[i].buffer_byte_offset; - const auto src_byte - = batch_memcpy::read_byte(tile_buffers.srcs[buffer_id], - buffer_offset); - batch_memcpy::write_byte(tile_buffers.dsts[buffer_id], - buffer_offset, - src_byte); + const auto src_byte = batch_memcpy::read_item( + tile_buffers.srcs[buffer_id], + buffer_offset); + batch_memcpy::write_item(tile_buffers.dsts[buffer_id], + buffer_offset, + src_byte); } absolute_tlev_byte_offset += block_size; } @@ -916,9 +985,12 @@ struct batch_memcpy_impl { if(thread_offset < blev_buffers.sizes[buffer_id]) { - uint8_t item - = batch_memcpy::read_byte(blev_buffers.srcs[buffer_id], thread_offset); - batch_memcpy::write_byte(blev_buffers.dsts[buffer_id], thread_offset, item); + Alias item + = batch_memcpy::read_item(blev_buffers.srcs[buffer_id], + thread_offset); + batch_memcpy::write_item(blev_buffers.dsts[buffer_id], + thread_offset, + item); } } tile_id += flat_grid_size; @@ -930,16 +1002,244 @@ struct batch_memcpy_impl - tile_offset_within_buffer), static_cast(blev_tile_size)); - batch_memcpy::vectorized_copy_bytes(blev_buffers.srcs[buffer_id], - blev_buffers.dsts[buffer_id], - items_to_copy, - tile_offset_within_buffer); + batch_memcpy::copy_items(blev_buffers.srcs[buffer_id], + blev_buffers.dsts[buffer_id], + items_to_copy, + tile_offset_within_buffer); tile_id += flat_grid_size; } } }; +template +ROCPRIM_INLINE static hipError_t batch_memcpy_func(void* temporary_storage, + size_t& storage_size, + InputBufferItType sources, + OutputBufferItType destinations, + BufferSizeItType sizes, + uint32_t num_copies, + hipStream_t stream = hipStreamDefault, + bool debug_synchronous = false) +{ + using Config = detail::default_or_custom_config>; + + static_assert(Config::wlev_size_threshold < Config::blev_size_threshold, + "wlev_size_threshold should be smaller than blev_size_threshold"); + + using BufferOffsetType = unsigned int; + using 
BlockOffsetType = unsigned int; + + hipError_t error = hipSuccess; + + using batch_memcpy_impl_type = detail::batch_memcpy_impl; + + static constexpr uint32_t non_blev_block_size = Config::non_blev_block_size; + static constexpr uint32_t non_blev_buffers_per_thread = Config::non_blev_buffers_per_thread; + static constexpr uint32_t blev_block_size = Config::blev_block_size; + + constexpr uint32_t buffers_per_block = non_blev_block_size * non_blev_buffers_per_thread; + const uint32_t num_blocks = rocprim::detail::ceiling_div(num_copies, buffers_per_block); + + using scan_state_buffer_type = rocprim::detail::lookback_scan_state; + using scan_state_block_type = rocprim::detail::lookback_scan_state; + + // Pack buffers + typename batch_memcpy_impl_type::copyable_buffers const buffers{ + sources, + destinations, + sizes, + }; + + detail::temp_storage::layout scan_state_buffer_layout{}; + error = scan_state_buffer_type::get_temp_storage_layout(num_blocks, + stream, + scan_state_buffer_layout); + if(error != hipSuccess) + { + return error; + } + + detail::temp_storage::layout blev_block_scan_state_layout{}; + error = scan_state_block_type::get_temp_storage_layout(num_blocks, + stream, + blev_block_scan_state_layout); + if(error != hipSuccess) + { + return error; + } + + uint8_t* blev_buffer_scan_data; + uint8_t* blev_block_scan_state_data; + + // The non-blev kernel will prepare blev copy. Communication between the two + // kernels is done via `blev_buffers`. + typename batch_memcpy_impl_type::copyable_blev_buffers blev_buffers{}; + + // Partition `d_temp_storage`. + // If `d_temp_storage` is null, calculate the allocation size instead. 
+ error = detail::temp_storage::partition( + temporary_storage, + storage_size, + detail::temp_storage::make_linear_partition( + detail::temp_storage::ptr_aligned_array(&blev_buffers.srcs, num_copies), + detail::temp_storage::ptr_aligned_array(&blev_buffers.dsts, num_copies), + detail::temp_storage::ptr_aligned_array(&blev_buffers.sizes, num_copies), + detail::temp_storage::ptr_aligned_array(&blev_buffers.offsets, num_copies), + detail::temp_storage::make_partition(&blev_buffer_scan_data, scan_state_buffer_layout), + detail::temp_storage::make_partition(&blev_block_scan_state_data, + blev_block_scan_state_layout))); + + // If allocation failed, return error. + if(error != hipSuccess) + { + return error; + } + + // Return the storage size. + if(temporary_storage == nullptr) + { + return hipSuccess; + } + + // Compute launch parameters. + + int device_id = hipGetStreamDeviceId(stream); + + // Get the number of multiprocessors + int multiprocessor_count{}; + error = hipDeviceGetAttribute(&multiprocessor_count, + hipDeviceAttributeMultiprocessorCount, + device_id); + if(error != hipSuccess) + { + return error; + } + + // `hipOccupancyMaxActiveBlocksPerMultiprocessor` uses the default device. + // We need to perserve the current default device id while we change it temporarily + // to get the max occupancy on this stream. 
+ int previous_device; + error = hipGetDevice(&previous_device); + if(error != hipSuccess) + { + return error; + } + + error = hipSetDevice(device_id); + if(error != hipSuccess) + { + return error; + } + + int blev_occupancy{}; + error = hipOccupancyMaxActiveBlocksPerMultiprocessor(&blev_occupancy, + batch_memcpy_impl_type::blev_memcpy_kernel, + blev_block_size, + 0 /* dynSharedMemPerBlk */); + if(error != hipSuccess) + { + return error; + } + + // Restore the default device id to initial state + error = hipSetDevice(previous_device); + if(error != hipSuccess) + { + return error; + } + + constexpr BlockOffsetType init_kernel_threads = 128; + const BlockOffsetType init_kernel_grid_size + = rocprim::detail::ceiling_div(num_blocks, init_kernel_threads); + + auto batch_memcpy_blev_grid_size + = multiprocessor_count * blev_occupancy * 1 /* subscription factor */; + + BlockOffsetType batch_memcpy_grid_size = num_blocks; + + // Prepare init_scan_states_kernel. + scan_state_buffer_type scan_state_buffer{}; + error = scan_state_buffer_type::create(scan_state_buffer, + blev_buffer_scan_data, + num_blocks, + stream); + if(error != hipSuccess) + { + return error; + } + + scan_state_block_type scan_state_block{}; + error = scan_state_block_type::create(scan_state_block, + blev_block_scan_state_data, + num_blocks, + stream); + if(error != hipSuccess) + { + return error; + } + + // Launch init_scan_states_kernel. + batch_memcpy_impl_type:: + init_tile_state_kernel<<>>( + scan_state_buffer, + scan_state_block, + num_blocks); + error = hipGetLastError(); + if(error != hipSuccess) + { + return error; + } + if(debug_synchronous) + { + hipStreamSynchronize(stream); + } + + // Launch batch_memcpy_non_blev_kernel. 
+ batch_memcpy_impl_type:: + non_blev_memcpy_kernel<<>>( + buffers, + num_copies, + blev_buffers, + scan_state_buffer, + scan_state_block); + error = hipGetLastError(); + if(error != hipSuccess) + { + return error; + } + if(debug_synchronous) + { + hipStreamSynchronize(stream); + } + + // Launch batch_memcpy_blev_kernel. + batch_memcpy_impl_type:: + blev_memcpy_kernel<<>>( + blev_buffers, + scan_state_buffer, + batch_memcpy_grid_size - 1); + error = hipGetLastError(); + if(error != hipSuccess) + { + return error; + } + if(debug_synchronous) + { + hipStreamSynchronize(stream); + } + + return hipSuccess; +} + } // namespace detail END_ROCPRIM_NAMESPACE diff --git a/rocprim/include/rocprim/device/detail/device_radix_sort.hpp b/rocprim/include/rocprim/device/detail/device_radix_sort.hpp index a80989940..c89f351e6 100644 --- a/rocprim/include/rocprim/device/detail/device_radix_sort.hpp +++ b/rocprim/include/rocprim/device/detail/device_radix_sort.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -26,10 +26,9 @@ #include "../../config.hpp" #include "../../detail/various.hpp" -#include "../../detail/radix_sort.hpp" -#include "../../intrinsics.hpp" #include "../../functional.hpp" +#include "../../intrinsics.hpp" #include "../../types.hpp" #include "../../block/block_discontinuity.hpp" @@ -40,6 +39,7 @@ #include "../../block/block_radix_sort.hpp" #include "../../block/block_scan.hpp" #include "../../block/block_store_func.hpp" +#include "../../thread/radix_key_codec.hpp" BEGIN_ROCPRIM_NAMESPACE @@ -48,42 +48,51 @@ namespace detail // Wrapping functions that allow one to call proper methods (with or without values) // (a variant with values is enabled only when Value is not empty_type) -template -ROCPRIM_DEVICE ROCPRIM_INLINE -void sort_block(SortType sorter, - SortKey (&keys)[ItemsPerThread], - SortValue (&values)[ItemsPerThread], - typename SortType::storage_type& storage, - unsigned int begin_bit, - unsigned int end_bit) +template +ROCPRIM_DEVICE ROCPRIM_INLINE void sort_block(SortType sorter, + SortKey (&keys)[ItemsPerThread], + SortValue (&values)[ItemsPerThread], + typename SortType::storage_type& storage, + Decomposer decomposer, + unsigned int begin_bit, + unsigned int end_bit) { if(Descending) { - sorter.sort_desc(keys, values, storage, begin_bit, end_bit); + sorter.sort_desc(keys, values, storage, begin_bit, end_bit, decomposer); } else { - sorter.sort(keys, values, storage, begin_bit, end_bit); + sorter.sort(keys, values, storage, begin_bit, end_bit, decomposer); } } -template -ROCPRIM_DEVICE ROCPRIM_INLINE -void sort_block(SortType sorter, - SortKey (&keys)[ItemsPerThread], - ::rocprim::empty_type (&values)[ItemsPerThread], - typename SortType::storage_type& storage, - unsigned int begin_bit, - unsigned int end_bit) +template +ROCPRIM_DEVICE ROCPRIM_INLINE void sort_block(SortType 
sorter, + SortKey (&keys)[ItemsPerThread], + ::rocprim::empty_type (&values)[ItemsPerThread], + typename SortType::storage_type& storage, + Decomposer decomposer, + unsigned int begin_bit, + unsigned int end_bit) { (void) values; if(Descending) { - sorter.sort_desc(keys, storage, begin_bit, end_bit); + sorter.sort_desc(keys, storage, begin_bit, end_bit, decomposer); } else { - sorter.sort(keys, storage, begin_bit, end_bit); + sorter.sort(keys, storage, begin_bit, end_bit, decomposer); } } @@ -126,7 +135,7 @@ struct radix_digit_count_helper using key_type = typename std::iterator_traits::value_type; - using key_codec = radix_key_codec; + using key_codec = ::rocprim::radix_key_codec; using bit_key_type = typename key_codec::bit_key_type; const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); @@ -203,8 +212,6 @@ struct radix_sort_single_helper using key_type = Key; using value_type = Value; - using key_codec = radix_key_codec; - using bit_key_type = typename key_codec::bit_key_type; using sort_type = ::rocprim::block_radix_sort; static constexpr bool with_values = !std::is_same::value; @@ -214,21 +221,20 @@ struct radix_sort_single_helper typename sort_type::storage_type sort; }; - template< - class KeysInputIterator, - class KeysOutputIterator, - class ValuesInputIterator, - class ValuesOutputIterator - > - ROCPRIM_DEVICE ROCPRIM_INLINE - void sort_single(KeysInputIterator keys_input, - KeysOutputIterator keys_output, - ValuesInputIterator values_input, - ValuesOutputIterator values_output, - unsigned int size, - unsigned int bit, - unsigned int current_radix_bits, - storage_type& storage) + template + ROCPRIM_DEVICE ROCPRIM_INLINE void sort_single(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + Decomposer decomposer, + unsigned int bit, + unsigned int current_radix_bits, + storage_type& storage) { const unsigned int flat_id = 
::rocprim::detail::block_thread_id<0>(); const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); @@ -239,7 +245,6 @@ struct radix_sort_single_helper using key_type = typename std::iterator_traits::value_type; using key_codec = radix_key_codec; - using bit_key_type = typename key_codec::bit_key_type; key_type keys[ItemsPerThread]; value_type values[ItemsPerThread]; @@ -253,7 +258,7 @@ struct radix_sort_single_helper } else { - const key_type out_of_bounds = key_codec::decode(bit_key_type(-1)); + const key_type out_of_bounds = key_codec::get_out_of_bounds_key(decomposer); block_load_direct_blocked(flat_id, keys_input + block_offset, keys, @@ -268,7 +273,13 @@ struct radix_sort_single_helper } } - sort_block(sort_type(), keys, values, storage.sort, bit, bit + current_radix_bits); + sort_block(sort_type(), + keys, + values, + storage.sort, + decomposer, + bit, + bit + current_radix_bits); // Store keys and values if(!is_incomplete_block) @@ -313,7 +324,7 @@ struct radix_sort_and_scatter_helper using key_type = Key; using value_type = Value; - using key_codec = radix_key_codec; + using key_codec = ::rocprim::radix_key_codec; using bit_key_type = typename key_codec::bit_key_type; using keys_load_type = ::rocprim::block_load< key_type, BlockSize, ItemsPerThread, @@ -407,7 +418,13 @@ struct radix_sort_and_scatter_helper } ::rocprim::syncthreads(); - sort_block(sort_type(), keys, values, storage.sort, bit, bit + current_radix_bits); + sort_block(sort_type(), + keys, + values, + storage.sort, + identity_decomposer{}, + bit, + bit + current_radix_bits); bit_key_type bit_keys[ItemsPerThread]; unsigned int digits[ItemsPerThread]; @@ -481,23 +498,22 @@ struct radix_sort_and_scatter_helper } }; -template< - unsigned int BlockSize, - unsigned int ItemsPerThread, - bool Descending, - class KeysInputIterator, - class KeysOutputIterator, - class ValuesInputIterator, - class ValuesOutputIterator -> -ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE -void sort_single(KeysInputIterator 
keys_input, - KeysOutputIterator keys_output, - ValuesInputIterator values_input, - ValuesOutputIterator values_output, - unsigned int size, - unsigned int bit, - unsigned int current_radix_bits) +template +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void sort_single(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + Decomposer decomposer, + unsigned int bit, + unsigned int current_radix_bits) { using key_type = typename std::iterator_traits::value_type; using value_type = typename std::iterator_traits::value_type; @@ -509,11 +525,15 @@ void sort_single(KeysInputIterator keys_input, ROCPRIM_SHARED_MEMORY typename sort_single_helper::storage_type storage; - sort_single_helper().template sort_single( - keys_input, keys_output, values_input, values_output, - size, bit, current_radix_bits, - storage - ); + sort_single_helper().template sort_single(keys_input, + keys_output, + values_input, + values_output, + size, + decomposer, + bit, + current_radix_bits, + storage); } template @@ -528,8 +548,8 @@ auto compare_nan_sensitive(const T& a, const T& b) using bit_key_type = typename float_bit_mask::bit_type; static constexpr auto sign_bit = float_bit_mask::sign_bit; - auto a_bits = __builtin_bit_cast(bit_key_type, a); - auto b_bits = __builtin_bit_cast(bit_key_type, b); + auto a_bits = ::rocprim::detail::bit_cast(a); + auto b_bits = ::rocprim::detail::bit_cast(b); // convert -0.0 to +0.0 a_bits = a_bits == sign_bit ? 
0 : a_bits; @@ -543,106 +563,119 @@ auto compare_nan_sensitive(const T& a, const T& b) } template -ROCPRIM_DEVICE ROCPRIM_INLINE -auto compare_nan_sensitive(const T& a, const T& b) - -> typename std::enable_if::value, bool>::type +ROCPRIM_DEVICE auto compare_nan_sensitive(const T& a, const T& b) -> + typename std::enable_if::value, bool>::type { return a > b; } -template< - bool Descending, - bool UseRadixMask, - class T, - class Enable = void -> +template struct radix_merge_compare; template -struct radix_merge_compare +struct radix_merge_compare { - ROCPRIM_DEVICE ROCPRIM_INLINE - bool operator()(const T& a, const T& b) const + ROCPRIM_DEVICE bool operator()(const T& a, const T& b) const { return compare_nan_sensitive(b, a); } }; template -struct radix_merge_compare +struct radix_merge_compare { - ROCPRIM_DEVICE ROCPRIM_INLINE - bool operator()(const T& a, const T& b) const + ROCPRIM_DEVICE bool operator()(const T& a, const T& b) const { return compare_nan_sensitive(a, b); } }; -template -struct radix_merge_compare::value>::type> +template +struct radix_merge_compare { T radix_mask; - ROCPRIM_HOST_DEVICE ROCPRIM_INLINE - radix_merge_compare(const unsigned int start_bit, const unsigned int current_radix_bits) + ROCPRIM_HOST_DEVICE radix_merge_compare(const unsigned int start_bit, + const unsigned int current_radix_bits, + identity_decomposer = {}) { T radix_mask_upper = (T(1) << (current_radix_bits + start_bit)) - 1; T radix_mask_bottom = (T(1) << start_bit) - 1; radix_mask = radix_mask_upper ^ radix_mask_bottom; } - ROCPRIM_DEVICE ROCPRIM_INLINE - bool operator()(const T& a, const T& b) const + ROCPRIM_DEVICE bool operator()(const T& a, const T& b) const { const T masked_key_a = a & radix_mask; const T masked_key_b = b & radix_mask; - return masked_key_b > masked_key_a; + return Descending ? 
masked_key_a > masked_key_b : masked_key_b > masked_key_a; } }; -template -struct radix_merge_compare::value>::type> +template +struct radix_merge_compare { - T radix_mask; + Decomposer decomposer_; + unsigned int start_bit_; + unsigned int radix_bits_; + + ROCPRIM_HOST_DEVICE radix_merge_compare(const unsigned int start_bit, + const unsigned int current_radix_bits, + Decomposer decomposer) + : decomposer_(decomposer), start_bit_(start_bit), radix_bits_(current_radix_bits) + {} - ROCPRIM_HOST_DEVICE ROCPRIM_INLINE - radix_merge_compare(const unsigned int start_bit, const unsigned int current_radix_bits) + ROCPRIM_HOST_DEVICE bool operator()(T lhs, T rhs) const { - T radix_mask_upper = (T(1) << (current_radix_bits + start_bit)) - 1; - T radix_mask_bottom = (T(1) << start_bit) - 1; - radix_mask = (radix_mask_upper ^ radix_mask_bottom); - } + using codec_t = radix_key_codec; - ROCPRIM_DEVICE ROCPRIM_INLINE - bool operator()(const T& a, const T& b) const - { - const T masked_key_a = a & radix_mask; - const T masked_key_b = b & radix_mask; - return masked_key_a > masked_key_b; - } -}; + // Encoding the values considers the ascending / descending nature of the sort + codec_t::encode_inplace(lhs, decomposer_); + codec_t::encode_inplace(rhs, decomposer_); -template -struct radix_merge_compare::value>::type> -{ - // radix_merge_compare supports masks only for integrals. - // even though masks are never used for floating point-types, - // it needs to be able to compile. 
- ROCPRIM_HOST_DEVICE ROCPRIM_INLINE - radix_merge_compare(const unsigned int, const unsigned int){} + // Digits can be extracted in 32 bit batches, but radix_bits_ can be larger than that + static constexpr int digit_batch_size = 32; - ROCPRIM_DEVICE ROCPRIM_INLINE - bool operator()(const T&, const T&) const { return false; } + // Moving from MSB to LSB + int current_start_bit + = rocprim::max(0, static_cast(start_bit_ + radix_bits_) - digit_batch_size); + unsigned int remaining_radix_bits = radix_bits_; + for(; remaining_radix_bits > 0;) + { + const unsigned int current_radix_bits + = rocprim::min(remaining_radix_bits, static_cast(digit_batch_size)); + remaining_radix_bits -= current_radix_bits; + + const unsigned int lhs_digits + = codec_t::extract_digit(lhs, + static_cast(current_start_bit), + current_radix_bits, + decomposer_); + const unsigned int rhs_digits + = codec_t::extract_digit(rhs, + static_cast(current_start_bit), + current_radix_bits, + decomposer_); + + // Since we are moving from MSB to LSB, the earlier iteration implies the relation (if digits are not equal) + if(lhs_digits != rhs_digits) + { + return rhs_digits > lhs_digits; + } + current_start_bit + = rocprim::max(current_start_bit - static_cast(current_radix_bits), + static_cast(start_bit_)); + } + return false; + } }; template + bool Descending, + class Decomposer> struct onesweep_histograms_helper { static constexpr unsigned int radix_size = 1u << RadixBits; @@ -659,7 +692,6 @@ struct onesweep_histograms_helper using counter_type = uint32_t; using key_codec = radix_key_codec; - using bit_key_type = typename key_codec::bit_key_type; struct storage_type { @@ -686,8 +718,9 @@ struct onesweep_histograms_helper template ROCPRIM_DEVICE void count_digits_at_place(const unsigned int flat_id, const unsigned int stripe, - const bit_key_type (&bit_keys)[ItemsPerThread], + const KeyType (&keys)[ItemsPerThread], const unsigned int place, + Decomposer decomposer, const unsigned int start_bit, const 
unsigned int current_radix_bits, const unsigned int valid_count, @@ -700,7 +733,7 @@ struct onesweep_histograms_helper if(IsFull || pos < valid_count) { const unsigned int digit - = key_codec::extract_digit(bit_keys[i], start_bit, current_radix_bits); + = key_codec::extract_digit(keys[i], start_bit, current_radix_bits, decomposer); ::rocprim::detail::atomic_add(&get_counter(stripe, place, digit, storage), 1); } } @@ -710,6 +743,7 @@ struct onesweep_histograms_helper ROCPRIM_DEVICE void count_digits(KeysInputIterator keys_input, Offset* global_digit_counts, const unsigned int valid_count, + Decomposer decomposer, const unsigned int begin_bit, const unsigned int end_bit, storage_type& storage) @@ -734,19 +768,19 @@ struct onesweep_histograms_helper ::rocprim::syncthreads(); // Compute a shared histogram for each digit and each place. - bit_key_type bit_keys[ItemsPerThread]; ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; ++i) { - bit_keys[i] = key_codec::encode(keys[i]); + key_codec::encode_inplace(keys[i], decomposer); } for(unsigned int bit = begin_bit, place = 0; bit < end_bit; bit += RadixBits, ++place) { count_digits_at_place(flat_id, stripe, - bit_keys, + keys, place, + decomposer, bit, min(RadixBits, end_bit - bit), valid_count, @@ -783,17 +817,23 @@ template + class Offset, + class Decomposer> ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void onesweep_histograms(KeysInputIterator keys_input, Offset* global_digit_counts, const Offset size, const Offset full_blocks, + Decomposer decomposer, const unsigned int begin_bit, const unsigned int end_bit) { using key_type = typename std::iterator_traits::value_type; - using count_helper_type - = onesweep_histograms_helper; + using count_helper_type = onesweep_histograms_helper; constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; @@ -807,6 +847,7 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void onesweep_histograms(KeysInputIterator count_helper_type{}.template count_digits(keys_input + block_offset, 
global_digit_counts, items_per_block, + decomposer, begin_bit, end_bit, storage); @@ -817,6 +858,7 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void onesweep_histograms(KeysInputIterator count_helper_type{}.template count_digits(keys_input + block_offset, global_digit_counts, valid_in_last_block, + decomposer, begin_bit, end_bit, storage); @@ -898,7 +940,8 @@ template + block_radix_rank_algorithm RadixRankAlgorithm, + class Decomposer> struct onesweep_iteration_helper { static constexpr unsigned int radix_size = 1u << RadixBits; @@ -906,7 +949,6 @@ struct onesweep_iteration_helper static constexpr bool with_values = !std::is_same::value; using key_codec = radix_key_codec; - using bit_key_type = typename key_codec::bit_key_type; using radix_rank_type = ::rocprim::block_radix_rank; static constexpr bool load_warp_striped @@ -922,7 +964,7 @@ struct onesweep_iteration_helper Offset global_digit_offsets[radix_size]; union { - bit_key_type ordered_block_keys[items_per_block]; + Key ordered_block_keys[items_per_block]; Value ordered_block_values[items_per_block]; }; }; @@ -942,6 +984,7 @@ struct onesweep_iteration_helper Offset* global_digit_offsets_in, Offset* global_digit_offsets_out, onesweep_lookback_state* lookback_states, + Decomposer decomposer, const unsigned int bit, const unsigned int current_radix_bits, const unsigned int valid_items, @@ -972,7 +1015,7 @@ struct onesweep_iteration_helper // Note that this will lead to an incorrect digit count. Since this is the very last digit, // it does not matter. It does cause the final digit offset to be increased past its end, // but again this does not matter since this is the last iteration in which it will be used anyway. 
- const Key out_of_bounds = key_codec::decode(bit_key_type(-1)); + const Key out_of_bounds = key_codec::get_out_of_bounds_key(decomposer); if ROCPRIM_IF_CONSTEXPR(load_warp_striped) { block_load_direct_warp_striped(flat_id, @@ -991,11 +1034,10 @@ struct onesweep_iteration_helper } } - bit_key_type bit_keys[ItemsPerThread]; ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; ++i) { - bit_keys[i] = key_codec::encode(keys[i]); + key_codec::encode_inplace(keys[i], decomposer); } // Compute the block-based key ranks, the digit counts, and the prefix sum of the digit counts. @@ -1005,11 +1047,11 @@ struct onesweep_iteration_helper // Tile-wide digit count unsigned int digit_counts[digits_per_thread]; radix_rank_type{}.rank_keys( - bit_keys, + keys, ranks, storage.rank, - [bit, current_radix_bits](const bit_key_type& key) - { return key_codec::extract_digit(key, bit, current_radix_bits); }, + [bit, current_radix_bits, decomposer](const Key& key) + { return key_codec::extract_digit(key, bit, current_radix_bits, decomposer); }, exclusive_digit_prefix, digit_counts); @@ -1019,7 +1061,7 @@ struct onesweep_iteration_helper ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; ++i) { - storage.ordered_block_keys[ranks[i]] = bit_keys[i]; + storage.ordered_block_keys[ranks[i]] = keys[i]; } ::rocprim::syncthreads(); @@ -1082,11 +1124,12 @@ struct onesweep_iteration_helper const unsigned int rank = i * BlockSize + flat_id; if(IsFull || rank < valid_items) { - const bit_key_type bit_key = storage.ordered_block_keys[rank]; + Key key = storage.ordered_block_keys[rank]; const unsigned int digit - = key_codec::extract_digit(bit_key, bit, current_radix_bits); + = key_codec::extract_digit(key, bit, current_radix_bits, decomposer); + key_codec::decode_inplace(key, decomposer); const Offset global_offset = storage.global_digit_offsets[digit]; - keys_output[rank + global_offset] = key_codec::decode(bit_key); + keys_output[rank + global_offset] = key; } } @@ -1132,8 +1175,8 @@ 
struct onesweep_iteration_helper const unsigned int rank = i * BlockSize + flat_id; if(IsFull || rank < valid_items) { - const bit_key_type bit_key = storage.ordered_block_keys[rank]; - digits[i] = key_codec::extract_digit(bit_key, bit, current_radix_bits); + const Key key = storage.ordered_block_keys[rank]; + digits[i] = key_codec::extract_digit(key, bit, current_radix_bits, decomposer); } } @@ -1189,7 +1232,8 @@ template + class Offset, + class Decomposer> ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void onesweep_iteration(KeysInputIterator keys_input, KeysOutputIterator keys_output, @@ -1199,6 +1243,7 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void Offset* global_digit_offsets_in, Offset* global_digit_offsets_out, onesweep_lookback_state* lookback_states, + Decomposer decomposer, const unsigned int bit, const unsigned int current_radix_bits, const unsigned int full_blocks) @@ -1213,7 +1258,8 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void ItemsPerThread, RadixBits, Descending, - RadixRankAlgorithm>; + RadixRankAlgorithm, + Decomposer>; constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; const unsigned int block_id = ::rocprim::detail::block_id<0>(); @@ -1229,6 +1275,7 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void global_digit_offsets_in, global_digit_offsets_out, lookback_states, + decomposer, bit, current_radix_bits, items_per_block, @@ -1244,6 +1291,7 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void global_digit_offsets_in, global_digit_offsets_out, lookback_states, + decomposer, bit, current_radix_bits, valid_in_last_block, diff --git a/rocprim/include/rocprim/device/detail/device_segmented_radix_sort.hpp b/rocprim/include/rocprim/device/detail/device_segmented_radix_sort.hpp index 52fd28086..4da8d974f 100644 --- a/rocprim/include/rocprim/device/detail/device_segmented_radix_sort.hpp +++ b/rocprim/include/rocprim/device/detail/device_segmented_radix_sort.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -27,8 +27,9 @@ #include "../../config.hpp" #include "../../detail/various.hpp" -#include "../../intrinsics.hpp" #include "../../functional.hpp" +#include "../../intrinsics.hpp" +#include "../../type_traits.hpp" #include "../../types.hpp" #include "../../block/block_load.hpp" @@ -39,6 +40,8 @@ #include "../../warp/warp_sort.hpp" #include "../../warp/warp_store.hpp" +#include "../../thread/radix_key_codec.hpp" + #include "../device_segmented_radix_sort_config.hpp" #include "device_radix_sort.hpp" @@ -422,7 +425,13 @@ class segmented_radix_sort_single_block_helper } ::rocprim::syncthreads(); - sort_block(sort_type(), keys, values, storage.sort, begin_bit, end_bit); + sort_block(sort_type(), + keys, + values, + storage.sort, + identity_decomposer{}, + begin_bit, + end_bit); ::rocprim::syncthreads(); keys_store_type().store(keys_output + begin_offset, keys, valid_count, storage.keys_store); @@ -527,7 +536,7 @@ class segmented_warp_sort_helper< using key_type = Key; using value_type = Value; - using key_codec = ::rocprim::detail::radix_key_codec; + using key_codec = ::rocprim::radix_key_codec; using bit_key_type = typename key_codec::bit_key_type; using keys_load_type = ::rocprim::warp_load; @@ -565,6 +574,46 @@ class segmented_warp_sort_helper< typename sort_type::storage_type sort; }; +private: + template + ROCPRIM_DEVICE auto invoke_warp_sort(stable_key_type (&stable_keys)[items_per_thread], + value_type (&values)[items_per_thread], + storage_type& storage, + unsigned int begin_bit, + unsigned int end_bit) + -> std::enable_if_t::value> + { + (void)begin_bit; + (void)end_bit; + sort_type().sort(stable_keys, + values, + storage.sort, + make_stable_comparator(radix_comparator_type{})); + } + + template + ROCPRIM_DEVICE auto 
invoke_warp_sort(stable_key_type (&stable_keys)[items_per_thread], + value_type (&values)[items_per_thread], + storage_type& storage, + unsigned int begin_bit, + unsigned int end_bit) + -> std::enable_if_t::value> + { + if(begin_bit == 0 && end_bit == 8 * sizeof(key_type)) + { + sort_type().sort(stable_keys, + values, + storage.sort, + make_stable_comparator(radix_comparator_type{})); + } + else + { + radix_comparator_type comparator(begin_bit, end_bit - begin_bit); + sort_type().sort(stable_keys, values, storage.sort, make_stable_comparator(comparator)); + } + } + +public: template< class KeysInputIterator, class KeysOutputIterator, @@ -605,18 +654,7 @@ class segmented_warp_sort_helper< } ::rocprim::wave_barrier(); - if(begin_bit == 0 && end_bit == 8 * sizeof(key_type)) - { - sort_type().sort(stable_keys, - values, - storage.sort, - make_stable_comparator(radix_comparator_type{})); - } - else - { - radix_comparator_type comparator(begin_bit, end_bit - begin_bit); - sort_type().sort(stable_keys, values, storage.sort, make_stable_comparator(comparator)); - } + invoke_warp_sort(stable_keys, values, storage, begin_bit, end_bit); ROCPRIM_UNROLL for(unsigned int i = 0; i < items_per_thread; i++) diff --git a/rocprim/include/rocprim/device/device_adjacent_difference.hpp b/rocprim/include/rocprim/device/device_adjacent_difference.hpp index 13476ce43..c8df93869 100644 --- a/rocprim/include/rocprim/device/device_adjacent_difference.hpp +++ b/rocprim/include/rocprim/device/device_adjacent_difference.hpp @@ -41,6 +41,7 @@ #include #include #include +#include #include @@ -110,9 +111,12 @@ hipError_t adjacent_difference_impl(void* const temporary_storage, const hipStream_t stream, const bool debug_synchronous) { - using value_type = typename std::iterator_traits::value_type; + using value_type = typename std::iterator_traits::value_type; + using output_type = rocprim::invoke_result_binary_op_t; + using larger_type + = std::conditional_t<(sizeof(value_type) >= 
sizeof(output_type)), value_type, output_type>; - using config = wrapped_adjacent_difference_config; + using config = wrapped_adjacent_difference_config; detail::target_arch target_arch; hipError_t result = detail::host_target_arch(stream, target_arch); diff --git a/rocprim/include/rocprim/device/device_copy.hpp b/rocprim/include/rocprim/device/device_copy.hpp new file mode 100644 index 000000000..0d27e44c6 --- /dev/null +++ b/rocprim/include/rocprim/device/device_copy.hpp @@ -0,0 +1,150 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DEVICE_COPY_HPP_ +#define ROCPRIM_DEVICE_DEVICE_COPY_HPP_ + +#include "../config.hpp" +#include "../functional.hpp" + +#include "config_types.hpp" + +#include "detail/device_batch_memcpy.hpp" +#include "device_copy_config.hpp" +#include "rocprim/device/detail/device_config_helper.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief Copy `sizes[i]` elements from `sources[i]` to `destinations[i]` for all `i` in the range [0, `num_copies`). +/// +/// \tparam Config [optional] configuration of the primitive. It has to be \p batch_copy_config . +/// \tparam InputBufferItType type of iterator to source pointers. +/// \tparam OutputBufferItType type of iterator to destination pointers. +/// \tparam BufferSizeItType type of iterator to sizes. +/// +/// \param [in] temporary_storage pointer to device-accessible temporary storage. +/// When a null pointer is passed, the required allocation size in bytes is written to +/// `storage_size` and the function returns without performing the copy. +/// \param [in, out] storage_size reference to the size in bytes of `temporary_storage`. +/// \param [in] sources iterator of source pointers. +/// \param [in] destinations iterator of destination pointers. +/// \param [in] sizes iterator of range sizes to copy. +/// \param [in] num_copies number of ranges to copy. +/// \param [in] stream [optional] HIP stream object to enqueue the copy on. Default is `hipStreamDefault`. +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. The default value is `false`. +/// +/// Performs multiple device to device copies as a single batched operation. 
+/// Roughly equivalent to +/// \code{.cpp} +/// for (auto i = 0; i < num_copies; ++i) { +/// auto* src = sources[i]; +/// auto* dst = destinations[i]; +/// auto size = sizes[i]; +/// for (auto j = 0; j < size; ++j) +/// { +/// dst[j] = src[j]; +/// } +/// } +/// \endcode +/// except executed on the device in parallel. +/// Note that sources and destinations do not have to be part of the same array. I.e. you can copy +/// from both array A and B to array C and D with a single call to this function. +/// Source ranges are allowed to overlap, +/// however, destinations overlapping with either other destinations or with sources is not allowed, +/// and will result in undefined behaviour. +/// +/// \par Example +/// \parblock +/// In this example multiple sections of data are copied from \p a to \p b . +/// +/// \code{.cpp} +/// #include +ROCPRIM_INLINE static hipError_t batch_copy(void* temporary_storage, + size_t& storage_size, + InputBufferItType sources, + OutputBufferItType destinations, + BufferSizeItType sizes, + uint32_t num_copies, + hipStream_t stream = hipStreamDefault, + bool debug_synchronous = false) +{ + return detail:: + batch_memcpy_func( + temporary_storage, + storage_size, + sources, + destinations, + sizes, + num_copies, + stream, + debug_synchronous); +} + +END_ROCPRIM_NAMESPACE + +#endif diff --git a/rocprim/include/rocprim/device/device_copy_config.hpp b/rocprim/include/rocprim/device/device_copy_config.hpp new file mode 100644 index 000000000..275956e87 --- /dev/null +++ b/rocprim/include/rocprim/device/device_copy_config.hpp @@ -0,0 +1,67 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DEVICE_COPY_CONFIG_HPP_ +#define ROCPRIM_DEVICE_DEVICE_COPY_CONFIG_HPP_ + +#include "config_types.hpp" +#include "detail/device_config_helper.hpp" +#include "device_memcpy_config.hpp" + +/// \addtogroup primitivesmodule_deviceconfigs +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief Configuration of the device-level batch copy operation. +/// +/// \tparam NonBlevBlockSize - number of threads per block for thread- and warp-level copy. +/// \tparam NonBlevBuffersPerThreaed - number of buffers processed per thread. +/// \tparam TlevBytesPerThread - number of bytes per thread for thread-level copy. +/// \tparam BlevBlockSize - number of threads per block for block-level copy. +/// \tparam BlevBytesPerThread - number of bytes per thread for block-level copy. +/// \tparam WlevSizeThreshold - minimum size to use warp-level copy instead of thread-level. 
+/// \tparam BlevSizeThreshold - minimum size to use block-level copy instead of warp-level. +template +struct batch_copy_config + : public batch_memcpy_config +{ +#ifndef DOXYGEN_SHOULD_SKIP_THIS +#endif +}; + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group primitivesmodule_deviceconfigs + +#endif diff --git a/rocprim/include/rocprim/device/device_memcpy.hpp b/rocprim/include/rocprim/device/device_memcpy.hpp index e76a11e54..32d8df455 100644 --- a/rocprim/include/rocprim/device/device_memcpy.hpp +++ b/rocprim/include/rocprim/device/device_memcpy.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -130,215 +130,16 @@ ROCPRIM_INLINE static hipError_t batch_memcpy(void* temporary_stora hipStream_t stream = hipStreamDefault, bool debug_synchronous = false) { - using Config = detail::default_or_custom_config>; - - static_assert(Config::wlev_size_threshold < Config::blev_size_threshold, - "wlev_size_threshold should be smaller than blev_size_threshold"); - - using BufferOffsetType = unsigned int; - using BlockOffsetType = unsigned int; - - hipError_t error = hipSuccess; - - using batch_memcpy_impl_type = detail:: - batch_memcpy_impl; - - static constexpr uint32_t non_blev_block_size = Config::non_blev_block_size; - static constexpr uint32_t non_blev_buffers_per_thread = Config::non_blev_buffers_per_thread; - static constexpr uint32_t blev_block_size = Config::blev_block_size; - - constexpr uint32_t buffers_per_block = non_blev_block_size * non_blev_buffers_per_thread; - const uint32_t num_blocks = rocprim::detail::ceiling_div(num_copies, buffers_per_block); - - using scan_state_buffer_type = rocprim::detail::lookback_scan_state; - using scan_state_block_type = 
rocprim::detail::lookback_scan_state; - - // Pack buffers - typename batch_memcpy_impl_type::copyable_buffers const buffers{ - sources, - destinations, - sizes, - }; - - detail::temp_storage::layout scan_state_buffer_layout{}; - error = scan_state_buffer_type::get_temp_storage_layout(num_blocks, - stream, - scan_state_buffer_layout); - if(error != hipSuccess) - { - return error; - } - - detail::temp_storage::layout blev_block_scan_state_layout{}; - error = scan_state_block_type::get_temp_storage_layout(num_blocks, - stream, - blev_block_scan_state_layout); - if(error != hipSuccess) - { - return error; - } - - uint8_t* blev_buffer_scan_data; - uint8_t* blev_block_scan_state_data; - - // The non-blev kernel will prepare blev copy. Communication between the two - // kernels is done via `blev_buffers`. - typename batch_memcpy_impl_type::copyable_blev_buffers blev_buffers{}; - - // Partition `d_temp_storage`. - // If `d_temp_storage` is null, calculate the allocation size instead. - error = detail::temp_storage::partition( - temporary_storage, - storage_size, - detail::temp_storage::make_linear_partition( - detail::temp_storage::ptr_aligned_array(&blev_buffers.srcs, num_copies), - detail::temp_storage::ptr_aligned_array(&blev_buffers.dsts, num_copies), - detail::temp_storage::ptr_aligned_array(&blev_buffers.sizes, num_copies), - detail::temp_storage::ptr_aligned_array(&blev_buffers.offsets, num_copies), - detail::temp_storage::make_partition(&blev_buffer_scan_data, scan_state_buffer_layout), - detail::temp_storage::make_partition(&blev_block_scan_state_data, - blev_block_scan_state_layout))); - - // If allocation failed, return error. - if(error != hipSuccess) - { - return error; - } - - // Return the storage size. - if(temporary_storage == nullptr) - { - return hipSuccess; - } - - // Compute launch parameters. 
- - int device_id = hipGetStreamDeviceId(stream); - - // Get the number of multiprocessors - int multiprocessor_count{}; - error = hipDeviceGetAttribute(&multiprocessor_count, - hipDeviceAttributeMultiprocessorCount, - device_id); - if(error != hipSuccess) - { - return error; - } - - // `hipOccupancyMaxActiveBlocksPerMultiprocessor` uses the default device. - // We need to perserve the current default device id while we change it temporarily - // to get the max occupancy on this stream. - int previous_device; - error = hipGetDevice(&previous_device); - if(error != hipSuccess) - { - return error; - } - - error = hipSetDevice(device_id); - if(error != hipSuccess) - { - return error; - } - - int blev_occupancy{}; - error = hipOccupancyMaxActiveBlocksPerMultiprocessor(&blev_occupancy, - batch_memcpy_impl_type::blev_memcpy_kernel, - blev_block_size, - 0 /* dynSharedMemPerBlk */); - if(error != hipSuccess) - { - return error; - } - - // Restore the default device id to initial state - error = hipSetDevice(previous_device); - if(error != hipSuccess) - { - return error; - } - - constexpr BlockOffsetType init_kernel_threads = 128; - const BlockOffsetType init_kernel_grid_size - = rocprim::detail::ceiling_div(num_blocks, init_kernel_threads); - - auto batch_memcpy_blev_grid_size - = multiprocessor_count * blev_occupancy * 1 /* subscription factor */; - - BlockOffsetType batch_memcpy_grid_size = num_blocks; - - // Prepare init_scan_states_kernel. - scan_state_buffer_type scan_state_buffer{}; - error = scan_state_buffer_type::create(scan_state_buffer, - blev_buffer_scan_data, - num_blocks, - stream); - if(error != hipSuccess) - { - return error; - } - - scan_state_block_type scan_state_block{}; - error = scan_state_block_type::create(scan_state_block, - blev_block_scan_state_data, - num_blocks, - stream); - if(error != hipSuccess) - { - return error; - } - - // Launch init_scan_states_kernel. 
- batch_memcpy_impl_type:: - init_tile_state_kernel<<>>( - scan_state_buffer, - scan_state_block, - num_blocks); - error = hipGetLastError(); - if(error != hipSuccess) - { - return error; - } - if(debug_synchronous) - { - hipStreamSynchronize(stream); - } - - // Launch batch_memcpy_non_blev_kernel. - batch_memcpy_impl_type:: - non_blev_memcpy_kernel<<>>( - buffers, + return detail:: + batch_memcpy_func( + temporary_storage, + storage_size, + sources, + destinations, + sizes, num_copies, - blev_buffers, - scan_state_buffer, - scan_state_block); - error = hipGetLastError(); - if(error != hipSuccess) - { - return error; - } - if(debug_synchronous) - { - hipStreamSynchronize(stream); - } - - // Launch batch_memcpy_blev_kernel. - batch_memcpy_impl_type:: - blev_memcpy_kernel<<>>( - blev_buffers, - scan_state_buffer, - batch_memcpy_grid_size - 1); - error = hipGetLastError(); - if(error != hipSuccess) - { - return error; - } - if(debug_synchronous) - { - hipStreamSynchronize(stream); - } - - return hipSuccess; + stream, + debug_synchronous); } END_ROCPRIM_NAMESPACE diff --git a/rocprim/include/rocprim/device/device_radix_sort.hpp b/rocprim/include/rocprim/device/device_radix_sort.hpp index c4676b264..525465fb5 100644 --- a/rocprim/include/rocprim/device/device_radix_sort.hpp +++ b/rocprim/include/rocprim/device/device_radix_sort.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -27,7 +27,6 @@ #include #include "../config.hpp" -#include "../detail/radix_sort.hpp" #include "../detail/temp_storage.hpp" #include "../detail/various.hpp" @@ -35,6 +34,7 @@ #include "../functional.hpp" #include "../types.hpp" +#include "../type_traits.hpp" #include "detail/config/device_radix_sort_onesweep.hpp" #include "detail/device_radix_sort.hpp" #include "device_transform.hpp" @@ -68,6 +68,30 @@ namespace detail #endif +template +constexpr auto tuple_bit_size_impl() + -> std::enable_if_t::value, size_t> +{ + return 0; +} + +template +constexpr auto tuple_bit_size_impl() + -> std::enable_if_t::value, size_t> +{ + using element_t = std::decay_t<::rocprim::tuple_element_t>; + return 8 * sizeof(element_t) + tuple_bit_size_impl(); +} + +template +struct decomposer_max_bits + : public std::integral_constant< + unsigned int, + tuple_bit_size_impl()( + std::declval>()))>, + 0>()> +{}; + template using offset_type_t = std::conditional_t< sizeof(Size) <= 4, @@ -75,13 +99,14 @@ using offset_type_t = std::conditional_t< size_t >; -template +template ROCPRIM_KERNEL __launch_bounds__(device_params().histogram.block_size) void onesweep_histograms_kernel( KeysInputIterator keys_input, Offset* global_digit_counts, const Offset size, const Offset full_blocks, + Decomposer decomposer, const unsigned int begin_bit, const unsigned int end_bit) { @@ -93,6 +118,7 @@ ROCPRIM_KERNEL global_digit_counts, size, full_blocks, + decomposer, begin_bit, end_bit); } @@ -111,16 +137,18 @@ template -inline hipError_t radix_sort_onesweep_global_offsets(KeysInputIterator keys_input, - ValuesInputIterator, - Offset* global_digit_offsets, - const Offset size, - const unsigned int digit_places, - const unsigned begin_bit, - const unsigned end_bit, - const hipStream_t stream, - const bool debug_synchronous) + class Offset, + class 
Decomposer> +hipError_t radix_sort_onesweep_global_offsets(KeysInputIterator keys_input, + ValuesInputIterator, + Offset* global_digit_offsets, + const Offset size, + const unsigned int digit_places, + Decomposer decomposer, + const unsigned begin_bit, + const unsigned end_bit, + const hipStream_t stream, + const bool debug_synchronous) { using key_type = typename std::iterator_traits::value_type; using value_type = typename std::iterator_traits::value_type; @@ -168,6 +196,7 @@ inline hipError_t radix_sort_onesweep_global_offsets(KeysInputIterator keys_inpu global_digit_offsets, size, full_blocks, + decomposer, begin_bit, end_bit); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("compute_global_digit_histograms", size, start); @@ -195,7 +224,8 @@ template + class Offset, + class Decomposer> ROCPRIM_KERNEL __launch_bounds__(device_params().sort.block_size) void onesweep_iteration_kernel( KeysInputIterator keys_input, @@ -206,6 +236,7 @@ ROCPRIM_KERNEL Offset* global_digit_offsets_in, Offset* global_digit_offsets_out, onesweep_lookback_state* lookback_states, + Decomposer decomposer, const unsigned int bit, const unsigned int current_radix_bits, const unsigned int full_blocks) @@ -223,6 +254,7 @@ ROCPRIM_KERNEL global_digit_offsets_in, global_digit_offsets_out, lookback_states, + decomposer, bit, current_radix_bits, full_blocks); @@ -234,8 +266,9 @@ template -inline hipError_t radix_sort_onesweep_iteration( + class Offset, + class Decomposer> +hipError_t radix_sort_onesweep_iteration( KeysInputIterator keys_input, typename std::iterator_traits::value_type* keys_tmp, KeysOutputIterator keys_output, @@ -248,6 +281,7 @@ inline hipError_t radix_sort_onesweep_iteration( onesweep_lookback_state* lookback_states, const bool from_input, const bool to_output, + Decomposer decomposer, const unsigned int bit, const unsigned int end_bit, const hipStream_t stream, @@ -331,6 +365,7 @@ inline hipError_t radix_sort_onesweep_iteration( global_digit_offsets_in, global_digit_offsets_out, 
lookback_states, + decomposer, bit, current_radix_bits, full_blocks); @@ -350,6 +385,7 @@ inline hipError_t radix_sort_onesweep_iteration( global_digit_offsets_in, global_digit_offsets_out, lookback_states, + decomposer, bit, current_radix_bits, full_blocks); @@ -369,6 +405,7 @@ inline hipError_t radix_sort_onesweep_iteration( global_digit_offsets_in, global_digit_offsets_out, lookback_states, + decomposer, bit, current_radix_bits, full_blocks); @@ -388,6 +425,7 @@ inline hipError_t radix_sort_onesweep_iteration( global_digit_offsets_in, global_digit_offsets_out, lookback_states, + decomposer, bit, current_radix_bits, full_blocks); @@ -406,8 +444,9 @@ template -inline hipError_t radix_sort_onesweep_impl( + class Size, + class Decomposer> +hipError_t radix_sort_onesweep_impl( void* temporary_storage, size_t& storage_size, KeysInputIterator keys_input, @@ -418,6 +457,7 @@ inline hipError_t radix_sort_onesweep_impl( ValuesOutputIterator values_output, const Size size, bool& is_result_in_output, + Decomposer decomposer, const unsigned int begin_bit, const unsigned int end_bit, const hipStream_t stream, @@ -500,6 +540,7 @@ inline hipError_t radix_sort_onesweep_impl( global_digit_offsets, static_cast(size), places, + decomposer, begin_bit, end_bit, stream, @@ -568,6 +609,7 @@ inline hipError_t radix_sort_onesweep_impl( lookback_states, from_input, to_output, + decomposer, bit, end_bit, stream, @@ -589,8 +631,9 @@ template -inline hipError_t + class Size, + class Decomposer> +hipError_t radix_sort_impl(void* temporary_storage, size_t& storage_size, KeysInputIterator keys_input, @@ -601,6 +644,7 @@ inline hipError_t ValuesOutputIterator values_output, Size size, bool& is_result_in_output, + Decomposer decomposer, unsigned int begin_bit, unsigned int end_bit, hipStream_t stream, @@ -642,6 +686,11 @@ inline hipError_t default_block_sort_config, typename Config::single_sort_config>::type; + if(::rocprim::is_floating_point::value + && ((begin_bit != 0) || (end_bit != 
sizeof(key_type) * 8))) + { + return hipErrorInvalidValue; + } unsigned int single_sort_items_per_block = block_sort_config::block_size * block_sort_config::items_per_thread; if(size <= single_sort_items_per_block) @@ -665,6 +714,7 @@ inline hipError_t values_output, static_cast(size), single_sort_items_per_block, + decomposer, begin_bit, end_bit, stream, @@ -686,6 +736,7 @@ inline hipError_t values_tmp, values_output, static_cast(size), + decomposer, begin_bit, end_bit, stream, @@ -705,6 +756,7 @@ inline hipError_t values_output, size, is_result_in_output, + decomposer, begin_bit, end_bit, stream, @@ -724,7 +776,7 @@ inline hipError_t /// \par Overview /// * The contents of the inputs are not altered by the sorting function. /// * Returns the required size of \p temporary_storage in \p storage_size -/// if \p temporary_storage in a null pointer. +/// if \p temporary_storage is a null pointer. /// * \p Key type (a \p value_type of \p KeysInputIterator and \p KeysOutputIterator) must be /// an arithmetic type (that is, an integral type or a floating-point type). /// * Ranges specified by \p keys_input and \p keys_output must have at least \p size elements. @@ -732,28 +784,28 @@ inline hipError_t /// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range /// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. /// -/// \tparam Config - [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. -/// \tparam KeysInputIterator - random-access iterator type of the input range. Must meet the +/// \tparam Config [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. +/// \tparam KeysInputIterator random-access iterator type of the input range. Must meet the /// requirements of a C++ InputIterator concept. It can be a simple pointer type. -/// \tparam KeysOutputIterator - random-access iterator type of the output range. 
Must meet the +/// \tparam KeysOutputIterator random-access iterator type of the output range. Must meet the /// requirements of a C++ OutputIterator concept. It can be a simple pointer type. -/// \tparam Size - integral type that represents the problem size. +/// \tparam Size integral type that represents the problem size. /// -/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to /// \p storage_size and function returns without performing the sort operation. -/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. -/// \param [in] keys_input - pointer to the first element in the range to sort. -/// \param [out] keys_output - pointer to the first element in the output range. -/// \param [in] size - number of element in the input range. -/// \param [in] begin_bit - [optional] index of the first (least significant) bit used in +/// \param [in,out] storage_size reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input pointer to the first element in the range to sort. +/// \param [out] keys_output pointer to the first element in the output range. +/// \param [in] size number of element in the input range. +/// \param [in] begin_bit [optional] index of the first (least significant) bit used in /// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. /// Non-default value not supported for floating-point key-types. -/// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in +/// \param [in] end_bit [optional] past-the-end index (most significant) bit used in /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default /// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. 
-/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). -/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// \param [in] stream [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel /// launch is forced in order to check for errors. Default value is \p false. /// /// \returns \p hipSuccess (\p 0) after successful sort; otherwise a HIP runtime error of @@ -791,75 +843,81 @@ inline hipError_t /// // keys_output: [0.08, 0.2, 0.3, 0.4, 0.6, 0.65, 0.7, 1] /// \endcode /// \endparblock -template< - class Config = default_config, - class KeysInputIterator, - class KeysOutputIterator, - class Size, - class Key = typename std::iterator_traits::value_type -> -inline -hipError_t radix_sort_keys(void * temporary_storage, - size_t& storage_size, - KeysInputIterator keys_input, +template::value_type> +hipError_t radix_sort_keys(void* temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, KeysOutputIterator keys_output, - Size size, - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key), - hipStream_t stream = 0, - bool debug_synchronous = false) + Size size, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + hipStream_t stream = 0, + bool debug_synchronous = false) { static_assert(std::is_integral::value, "Size must be an integral type."); empty_type * values = nullptr; bool ignored; - return detail::radix_sort_impl( - temporary_storage, storage_size, - keys_input, nullptr, keys_output, - values, nullptr, values, - size, ignored, - begin_bit, end_bit, - stream, debug_synchronous - ); + return detail::radix_sort_impl(temporary_storage, + storage_size, + keys_input, + nullptr, + keys_output, + values, + nullptr, + values, + size, + ignored, + identity_decomposer{}, + begin_bit, + end_bit, + stream, + debug_synchronous); } -/// \brief Parallel 
descending radix sort primitive for device level. +/// \brief Parallel ascending radix sort primitive for device level. /// -/// \p radix_sort_keys_desc function performs a device-wide radix sort -/// of keys. Function sorts input keys in descending order. +/// \p radix_sort_keys function performs a device-wide radix sort +/// of keys. Function sorts input keys in ascending order. /// /// \par Overview -/// * The contents of the inputs are not altered by the sorting function. +/// * The contents of both buffers of \p keys may be altered by the sorting function. +/// * \p current() of \p keys is used as the input. +/// * The function will update \p current() of \p keys to point to the buffer +/// that contains the output range. /// * Returns the required size of \p temporary_storage in \p storage_size -/// if \p temporary_storage in a null pointer. -/// * \p Key type (a \p value_type of \p KeysInputIterator and \p KeysOutputIterator) must be -/// an arithmetic type (that is, an integral type or a floating-point type). -/// * Ranges specified by \p keys_input and \p keys_output must have at least \p size elements. +/// if \p temporary_storage is a null pointer. +/// * The function requires small \p temporary_storage as it does not need +/// a temporary buffer of \p size elements. +/// * \p Key type must be an arithmetic type (that is, an integral type or a floating-point +/// type). +/// * Buffers of \p keys must have at least \p size elements. /// * If \p Key is an integer type and the range of keys is known in advance, the performance /// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range /// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. /// -/// \tparam Config - [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. -/// \tparam KeysInputIterator - random-access iterator type of the input range. 
Must meet the -/// requirements of a C++ InputIterator concept. It can be a simple pointer type. -/// \tparam KeysOutputIterator - random-access iterator type of the output range. Must meet the -/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. -/// \tparam Size - integral type that represents the problem size. +/// \tparam Config [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. +/// \tparam Key key type. Must be an integral type or a floating-point type. +/// \tparam Size integral type that represents the problem size. /// -/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to /// \p storage_size and function returns without performing the sort operation. -/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. -/// \param [in] keys_input - pointer to the first element in the range to sort. -/// \param [out] keys_output - pointer to the first element in the output range. -/// \param [in] size - number of element in the input range. -/// \param [in] begin_bit - [optional] index of the first (least significant) bit used in +/// \param [in,out] storage_size reference to a size (in bytes) of \p temporary_storage. +/// \param [in,out] keys reference to the double-buffer of keys, its \p current() +/// contains the input range and will be updated to point to the output range. +/// \param [in] size number of element in the input range. +/// \param [in] begin_bit [optional] index of the first (least significant) bit used in /// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. /// Non-default value not supported for floating-point key-types. 
-/// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in +/// \param [in] end_bit [optional] past-the-end index (most significant) bit used in /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default /// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. -/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). -/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// \param [in] stream [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel /// launch is forced in order to check for errors. Default value is \p false. /// /// \returns \p hipSuccess (\p 0) after successful sort; otherwise a HIP runtime error of @@ -867,112 +925,119 @@ hipError_t radix_sort_keys(void * temporary_storage, /// /// \par Example /// \parblock -/// In this example a device-level descending radix sort is performed on an array of -/// integer values. +/// In this example a device-level ascending radix sort is performed on an array of +/// \p float values. /// /// \code{.cpp} /// #include /// -/// // Prepare input and output (declare pointers, allocate device memory etc.) -/// size_t input_size; // e.g., 8 -/// int * input; // e.g., [6, 3, 5, 4, 2, 8, 1, 7] -/// int * output; // empty array of 8 elements +/// // Prepare input and tmp (declare pointers, allocate device memory etc.) 
+/// size_t input_size; // e.g., 8 +/// float * input; // e.g., [0.6, 0.3, 0.65, 0.4, 0.2, 0.08, 1, 0.7] +/// float * tmp; // empty array of 8 elements +/// // Create double-buffer +/// rocprim::double_buffer keys(input, tmp); /// /// size_t temporary_storage_size_bytes; /// void * temporary_storage_ptr = nullptr; /// // Get required size of the temporary storage -/// rocprim::radix_sort_keys_desc( +/// rocprim::radix_sort_keys( /// temporary_storage_ptr, temporary_storage_size_bytes, -/// input, output, input_size +/// keys, input_size /// ); /// /// // allocate temporary storage /// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); /// /// // perform sort -/// rocprim::radix_sort_keys_desc( +/// rocprim::radix_sort_keys( /// temporary_storage_ptr, temporary_storage_size_bytes, -/// input, output, input_size +/// keys, input_size /// ); -/// // keys_output: [8, 7, 6, 5, 4, 3, 2, 1] +/// // keys.current(): [0.08, 0.2, 0.3, 0.4, 0.6, 0.65, 0.7, 1] /// \endcode /// \endparblock -template< - class Config = default_config, - class KeysInputIterator, - class KeysOutputIterator, - class Size, - class Key = typename std::iterator_traits::value_type -> -inline -hipError_t radix_sort_keys_desc(void * temporary_storage, - size_t& storage_size, - KeysInputIterator keys_input, - KeysOutputIterator keys_output, - Size size, - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key), - hipStream_t stream = 0, - bool debug_synchronous = false) +template +hipError_t radix_sort_keys(void* temporary_storage, + size_t& storage_size, + double_buffer& keys, + Size size, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + hipStream_t stream = 0, + bool debug_synchronous = false) { static_assert(std::is_integral::value, "Size must be an integral type."); empty_type * values = nullptr; - bool ignored; - return detail::radix_sort_impl( - temporary_storage, storage_size, - keys_input, nullptr, keys_output, - values, nullptr, values, - size, 
ignored, - begin_bit, end_bit, - stream, debug_synchronous - ); + bool is_result_in_output; + hipError_t error = detail::radix_sort_impl(temporary_storage, + storage_size, + keys.current(), + keys.current(), + keys.alternate(), + values, + values, + values, + size, + is_result_in_output, + identity_decomposer{}, + begin_bit, + end_bit, + stream, + debug_synchronous); + if(temporary_storage != nullptr && error == hipSuccess && is_result_in_output) + { + keys.swap(); + } + return error; } -/// \brief Parallel ascending radix sort-by-key primitive for device level. +/// \brief Parallel ascending radix sort primitive for device level. /// -/// \p radix_sort_pairs_desc function performs a device-wide radix sort -/// of (key, value) pairs. Function sorts input pairs in ascending order of keys. +/// \p radix_sort_keys function performs a device-wide radix sort +/// of keys. Function sorts input keys in ascending order. /// /// \par Overview /// * The contents of the inputs are not altered by the sorting function. /// * Returns the required size of \p temporary_storage in \p storage_size -/// if \p temporary_storage in a null pointer. -/// * \p Key type (a \p value_type of \p KeysInputIterator and \p KeysOutputIterator) must be -/// an arithmetic type (that is, an integral type or a floating-point type). -/// * Ranges specified by \p keys_input, \p keys_output, \p values_input and \p values_output must -/// have at least \p size elements. -/// * If \p Key is an integer type and the range of keys is known in advance, the performance -/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range -/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. -/// -/// \tparam Config - [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. -/// \tparam KeysInputIterator - random-access iterator type of the input range. 
Must meet the -/// requirements of a C++ InputIterator concept. It can be a simple pointer type. -/// \tparam KeysOutputIterator - random-access iterator type of the output range. Must meet the -/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. -/// \tparam ValuesInputIterator - random-access iterator type of the input range. Must meet the +/// if \p temporary_storage is a null pointer. +/// * \p Key type (a \p value_type of \p KeysInputIterator and \p KeysOutputIterator) can be any +/// trivially copyable type. +/// * \p decomposer must be a functor that implements `operator()(Key&) const`. This operator +/// must return a \p rocprim::tuple that contains one or more reference to value(s) of arithmetic types. +/// These references must point to member variables of `Key`, however not every member variable has to be +/// exposed this way. +/// * Ranges specified by \p keys_input and \p keys_output must have at least \p size elements. +/// * \p begin_bit and \p end_bit can be provided to control the radix range that is considered in the +/// decomposed tuple. For example, if the decomposer returns `rocprim::tuple`, +/// `begin_bit==6` and `end_bit==12`, then the 2 MSBs of the `uint8_t` value and the 4 LSBs of the +/// `int16_t` value are considered for sorting. The range specified by \p begin_bit and \p end_bit +/// must be valid with regards to the sizes of the return tuple's elements. +/// +/// \tparam Config [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. +/// \tparam KeysInputIterator Random-access iterator type of the input range. Must meet the /// requirements of a C++ InputIterator concept. It can be a simple pointer type. -/// \tparam ValuesOutputIterator - random-access iterator type of the output range. Must meet the +/// \tparam KeysOutputIterator Random-access iterator type of the output range. Must meet the /// requirements of a C++ OutputIterator concept. 
It can be a simple pointer type. -/// \tparam Size - integral type that represents the problem size. +/// \tparam Size Integral type that represents the problem size. +/// \tparam Key The value type of the input and output iterators. +/// \tparam Decomposer The type of the decomposer functor. /// -/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to /// \p storage_size and function returns without performing the sort operation. -/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. -/// \param [in] keys_input - pointer to the first element in the range to sort. -/// \param [out] keys_output - pointer to the first element in the output range. -/// \param [in] values_input - pointer to the first element in the range to sort. -/// \param [out] values_output - pointer to the first element in the output range. -/// \param [in] size - number of element in the input range. -/// \param [in] begin_bit - [optional] index of the first (least significant) bit used in -/// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. -/// Non-default value not supported for floating-point key-types. -/// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in -/// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default -/// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. -/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). -/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// \param [in,out] storage_size reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input pointer to the first element in the range to sort. 
+/// \param [out] keys_output pointer to the first element in the output range. +/// \param [in] size number of element in the input range. +/// \param [in] decomposer decomposer functor that produces a tuple of references from the +/// input key type. +/// \param [in] begin_bit index of the first (least significant) bit used in +/// key comparison. +/// \param [in] end_bit past-the-end index (most significant) bit used in +/// key comparison. +/// \param [in] stream [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel /// launch is forced in order to check for errors. Default value is \p false. /// /// \returns \p hipSuccess (\p 0) after successful sort; otherwise a HIP runtime error of @@ -980,124 +1045,127 @@ hipError_t radix_sort_keys_desc(void * temporary_storage, /// /// \par Example /// \parblock -/// In this example a device-level ascending radix sort is performed where input keys are -/// represented by an array of unsigned integers and input values by an array of doubles. +/// In this example a device-level ascending radix sort is performed on an array of +/// values of a custom type, using a custom decomposer. /// /// \code{.cpp} /// #include /// -/// // Prepare input and output (declare pointers, allocate device memory etc.) -/// size_t input_size; // e.g., 8 -/// unsigned int * keys_input; // e.g., [ 6, 3, 5, 4, 1, 8, 1, 7] -/// double * values_input; // e.g., [-5, 2, -4, 3, -1, -8, -2, 7] -/// unsigned int * keys_output; // empty array of 8 elements -/// double * values_output; // empty array of 8 elements +/// struct custom_type +/// { +/// int i; +/// double d; +/// }; /// -/// // Keys are in range [0; 8], so we can limit compared bit to bits on indexes -/// // 0, 1, 2, 3, and 4. In order to do this begin_bit is set to 0 and end_bit -/// // is set to 5. 
+/// struct custom_type_decomposer +/// { +/// rocprim::tuple operator()(custom_type& key) const +/// { +/// return rocprim::tuple(key.i, key.d); +/// } +/// }; +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// custom_type * input; // e.g., [{2, 0.6}, {-3, 0.3}, {2, 0.65}, {0, 0.4}, {0, 0.2}, {11, 0.08}, {11, 1}, {-1, 0.7}] +/// custom_type * output; // empty array of 8 elements /// +/// constexpr unsigned int begin_bit = 0; +/// constexpr unsigned int end_bit = 96; /// size_t temporary_storage_size_bytes; /// void * temporary_storage_ptr = nullptr; /// // Get required size of the temporary storage -/// rocprim::radix_sort_pairs( +/// rocprim::radix_sort_keys( /// temporary_storage_ptr, temporary_storage_size_bytes, -/// keys_input, keys_output, values_input, values_output, -/// input_size, 0, 5 +/// input, output, input_size, custom_type_decomposer{}, begin_bit, end_bit /// ); /// /// // allocate temporary storage /// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); /// /// // perform sort -/// rocprim::radix_sort_pairs( +/// rocprim::radix_sort_keys( /// temporary_storage_ptr, temporary_storage_size_bytes, -/// keys_input, keys_output, values_input, values_output, -/// input_size, 0, 5 +/// input, output, input_size, custom_type_decomposer{}, begin_bit, end_bit /// ); -/// // keys_output: [ 1, 1, 3, 4, 5, 6, 7, 8] -/// // values_output: [-1, -2, 2, 3, -4, -5, 7, -8] +/// // keys_output: [{-3, 0.3}, {-1, 0.7}, {0, 0.2}, {0, 0.4}, {2, 0.6}, {2, 0.65}, {11, 0.08}, {11, 1.0}] /// \endcode /// \endparblock -template< - class Config = default_config, - class KeysInputIterator, - class KeysOutputIterator, - class ValuesInputIterator, - class ValuesOutputIterator, - class Size, - class Key = typename std::iterator_traits::value_type -> -inline -hipError_t radix_sort_pairs(void * temporary_storage, - size_t& storage_size, - KeysInputIterator keys_input, - KeysOutputIterator 
keys_output, - ValuesInputIterator values_input, - ValuesOutputIterator values_output, - Size size, - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key), - hipStream_t stream = 0, - bool debug_synchronous = false) +template::value_type, + class Decomposer> +auto radix_sort_keys(void* temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + Size size, + Decomposer decomposer, + unsigned int begin_bit, + unsigned int end_bit, + hipStream_t stream = 0, + bool debug_synchronous = false) + -> std::enable_if_t::value, hipError_t> { static_assert(std::is_integral::value, "Size must be an integral type."); - bool ignored; - return detail::radix_sort_impl( - temporary_storage, storage_size, - keys_input, nullptr, keys_output, - values_input, nullptr, values_output, - size, ignored, - begin_bit, end_bit, - stream, debug_synchronous - ); + empty_type* values = nullptr; + bool ignored; + return detail::radix_sort_impl(temporary_storage, + storage_size, + keys_input, + nullptr, + keys_output, + values, + nullptr, + values, + size, + ignored, + decomposer, + begin_bit, + end_bit, + stream, + debug_synchronous); } -/// \brief Parallel descending radix sort-by-key primitive for device level. +/// \brief Parallel ascending radix sort primitive for device level. /// -/// \p radix_sort_pairs_desc function performs a device-wide radix sort -/// of (key, value) pairs. Function sorts input pairs in descending order of keys. +/// \p radix_sort_keys function performs a device-wide radix sort +/// of keys. Function sorts input keys in ascending order. /// /// \par Overview /// * The contents of the inputs are not altered by the sorting function. /// * Returns the required size of \p temporary_storage in \p storage_size -/// if \p temporary_storage in a null pointer. 
-/// * \p Key type (a \p value_type of \p KeysInputIterator and \p KeysOutputIterator) must be -/// an arithmetic type (that is, an integral type or a floating-point type). -/// * Ranges specified by \p keys_input, \p keys_output, \p values_input and \p values_output must -/// have at least \p size elements. -/// * If \p Key is an integer type and the range of keys is known in advance, the performance -/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range -/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. +/// if \p temporary_storage is a null pointer. +/// * \p Key type (a \p value_type of \p KeysInputIterator and \p KeysOutputIterator) can be any +/// trivially copyable type. +/// * \p decomposer must be a functor that implements `operator()(Key&) const`. This operator +/// must return a \p rocprim::tuple that contains one or more reference to value(s) of arithmetic types. +/// These references must point to member variables of `Key`, however not every member variable has to be +/// exposed this way. +/// * Ranges specified by \p keys_input and \p keys_output must have at least \p size elements. /// -/// \tparam Config - [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. -/// \tparam KeysInputIterator - random-access iterator type of the input range. Must meet the -/// requirements of a C++ InputIterator concept. It can be a simple pointer type. -/// \tparam KeysOutputIterator - random-access iterator type of the output range. Must meet the -/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. -/// \tparam ValuesInputIterator - random-access iterator type of the input range. Must meet the +/// \tparam Config [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. +/// \tparam KeysInputIterator Random-access iterator type of the input range. 
Must meet the /// requirements of a C++ InputIterator concept. It can be a simple pointer type. -/// \tparam ValuesOutputIterator - random-access iterator type of the output range. Must meet the +/// \tparam KeysOutputIterator Random-access iterator type of the output range. Must meet the /// requirements of a C++ OutputIterator concept. It can be a simple pointer type. -/// \tparam Size - integral type that represents the problem size. +/// \tparam Size Integral type that represents the problem size. +/// \tparam Key The value type of the input and output iterators. +/// \tparam Decomposer The type of the decomposer functor. /// -/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to /// \p storage_size and function returns without performing the sort operation. -/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. -/// \param [in] keys_input - pointer to the first element in the range to sort. -/// \param [out] keys_output - pointer to the first element in the output range. -/// \param [in] values_input - pointer to the first element in the range to sort. -/// \param [out] values_output - pointer to the first element in the output range. -/// \param [in] size - number of element in the input range. -/// \param [in] begin_bit - [optional] index of the first (least significant) bit used in -/// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. -/// Non-default value not supported for floating-point key-types. -/// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in -/// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default -/// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. 
-/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). -/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// \param [in,out] storage_size reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input pointer to the first element in the range to sort. +/// \param [out] keys_output pointer to the first element in the output range. +/// \param [in] size number of element in the input range. +/// \param [in] decomposer decomposer functor that produces a tuple of references from the +/// input key type. +/// \param [in] stream [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel /// launch is forced in order to check for errors. Default value is \p false. /// /// \returns \p hipSuccess (\p 0) after successful sort; otherwise a HIP runtime error of @@ -1105,73 +1173,85 @@ hipError_t radix_sort_pairs(void * temporary_storage, /// /// \par Example /// \parblock -/// In this example a device-level descending radix sort is performed where input keys are -/// represented by an array of integers and input values by an array of doubles. +/// In this example a device-level ascending radix sort is performed on an array of +/// values of a custom type, using a custom decomposer. /// /// \code{.cpp} /// #include /// +/// struct custom_type +/// { +/// int i; +/// double d; +/// }; +/// +/// struct custom_type_decomposer +/// { +/// rocprim::tuple operator()(custom_type& key) const +/// { +/// return rocprim::tuple(key.i, key.d); +/// } +/// }; +/// /// // Prepare input and output (declare pointers, allocate device memory etc.) 
-/// size_t input_size; // e.g., 8 -/// int * keys_input; // e.g., [ 6, 3, 5, 4, 1, 8, 1, 7] -/// double * values_input; // e.g., [-5, 2, -4, 3, -1, -8, -2, 7] -/// int * keys_output; // empty array of 8 elements -/// double * values_output; // empty array of 8 elements +/// size_t input_size; // e.g., 8 +/// custom_type * input; // e.g., [{2, 0.6}, {-3, 0.3}, {2, 0.65}, {0, 0.4}, {0, 0.2}, {11, 0.08}, {11, 1}, {-1, 0.7}] +/// custom_type * output; // empty array of 8 elements /// /// size_t temporary_storage_size_bytes; /// void * temporary_storage_ptr = nullptr; /// // Get required size of the temporary storage -/// rocprim::radix_sort_pairs_desc( +/// rocprim::radix_sort_keys( /// temporary_storage_ptr, temporary_storage_size_bytes, -/// keys_input, keys_output, values_input, values_output, -/// input_size +/// input, output, input_size, custom_type_decomposer{} /// ); /// /// // allocate temporary storage /// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); /// /// // perform sort -/// rocprim::radix_sort_pairs_desc( +/// rocprim::radix_sort_keys( /// temporary_storage_ptr, temporary_storage_size_bytes, -/// keys_input, keys_output, values_input, values_output, -/// input_size +/// input, output, input_size, custom_type_decomposer{} /// ); -/// // keys_output: [ 8, 7, 6, 5, 4, 3, 1, 1] -/// // values_output: [-8, 7, -5, -4, 3, 2, -1, -2] +/// // keys_output: [{-3, 0.3}, {-1, 0.7}, {0, 0.2}, {0, 0.4}, {2, 0.6}, {2, 0.65}, {11, 0.08}, {11, 1.0}] /// \endcode /// \endparblock -template< - class Config = default_config, - class KeysInputIterator, - class KeysOutputIterator, - class ValuesInputIterator, - class ValuesOutputIterator, - class Size, - class Key = typename std::iterator_traits::value_type -> -inline -hipError_t radix_sort_pairs_desc(void * temporary_storage, - size_t& storage_size, - KeysInputIterator keys_input, - KeysOutputIterator keys_output, - ValuesInputIterator values_input, - ValuesOutputIterator values_output, - Size size, - 
unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key), - hipStream_t stream = 0, - bool debug_synchronous = false) +template::value_type, + class Decomposer> +auto radix_sort_keys(void* temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + Size size, + Decomposer decomposer, + hipStream_t stream = 0, + bool debug_synchronous = false) + -> std::enable_if_t::value, hipError_t> { static_assert(std::is_integral::value, "Size must be an integral type."); - bool ignored; - return detail::radix_sort_impl( - temporary_storage, storage_size, - keys_input, nullptr, keys_output, - values_input, nullptr, values_output, - size, ignored, - begin_bit, end_bit, - stream, debug_synchronous - ); + empty_type* values = nullptr; + bool ignored; + return detail::radix_sort_impl( + temporary_storage, + storage_size, + keys_input, + nullptr, + keys_output, + values, + nullptr, + values, + size, + ignored, + decomposer, + 0, + detail::decomposer_max_bits::value, + stream, + debug_synchronous); } /// \brief Parallel ascending radix sort primitive for device level. @@ -1185,35 +1265,42 @@ hipError_t radix_sort_pairs_desc(void * temporary_storage, /// * The function will update \p current() of \p keys to point to the buffer /// that contains the output range. /// * Returns the required size of \p temporary_storage in \p storage_size -/// if \p temporary_storage in a null pointer. +/// if \p temporary_storage is a null pointer. /// * The function requires small \p temporary_storage as it does not need /// a temporary buffer of \p size elements. -/// * \p Key type must be an arithmetic type (that is, an integral type or a floating-point -/// type). -/// * Buffers of \p keys must have at least \p size elements. 
-/// * If \p Key is an integer type and the range of keys is known in advance, the performance -/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range -/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. -/// -/// \tparam Config - [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. -/// \tparam Key - key type. Must be an integral type or a floating-point type. -/// \tparam Size - integral type that represents the problem size. -/// -/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// * \p Key type (the element type of the \p keys double-buffer) can be any +/// trivially copyable type. +/// * \p decomposer must be a functor that implements `operator()(Key&) const`. This operator +/// must return a \p rocprim::tuple that contains one or more references to values of arithmetic types. +/// These references must point to member variables of `Key`; however, not every member variable has to be +/// exposed this way. +/// * Buffers of \p keys must have at least \p size elements. +/// * \p begin_bit and \p end_bit can be provided to control the radix range that is considered in the +/// decomposed tuple. For example, if the decomposer returns `rocprim::tuple<int16_t&, uint8_t&>`, +/// `begin_bit==6` and `end_bit==12`, then the 2 MSBs of the `uint8_t` value and the 4 LSBs of the +/// `int16_t` value are considered for sorting. The range specified by \p begin_bit and \p end_bit +/// must be valid with regard to the sizes of the return tuple's elements. +/// +/// \tparam Config [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. +/// \tparam Key key type. Can be any trivially copyable type. +/// \tparam Size integral type that represents the problem size. +/// \tparam Decomposer The type of the decomposer functor.
+/// +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to /// \p storage_size and function returns without performing the sort operation. -/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. -/// \param [in,out] keys - reference to the double-buffer of keys, its \p current() +/// \param [in,out] storage_size reference to a size (in bytes) of \p temporary_storage. +/// \param [in,out] keys reference to the double-buffer of keys, its \p current() /// contains the input range and will be updated to point to the output range. -/// \param [in] size - number of element in the input range. -/// \param [in] begin_bit - [optional] index of the first (least significant) bit used in -/// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. -/// Non-default value not supported for floating-point key-types. -/// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in -/// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default -/// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. -/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). -/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// \param [in] size number of element in the input range. +/// \param [in] decomposer decomposer functor that produces a tuple of references from the +/// input key type. +/// \param [in] begin_bit index of the first (least significant) bit used in +/// key comparison. +/// \param [in] end_bit past-the-end index (most significant) bit used in +/// key comparison. +/// \param [in] stream [optional] HIP stream object. Default is \p 0 (default stream). 
+/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel /// launch is forced in order to check for errors. Default value is \p false. /// /// \returns \p hipSuccess (\p 0) after successful sort; otherwise a HIP runtime error of @@ -1222,24 +1309,40 @@ hipError_t radix_sort_pairs_desc(void * temporary_storage, /// \par Example /// \parblock /// In this example a device-level ascending radix sort is performed on an array of -/// \p float values. +/// values of a custom type, using a custom decomposer. /// /// \code{.cpp} /// #include /// +/// struct custom_type +/// { +/// int i; +/// double d; +/// }; +/// +/// struct custom_type_decomposer +/// { +/// rocprim::tuple operator()(custom_type& key) const +/// { +/// return rocprim::tuple(key.i, key.d); +/// } +/// }; +/// /// // Prepare input and tmp (declare pointers, allocate device memory etc.) -/// size_t input_size; // e.g., 8 -/// float * input; // e.g., [0.6, 0.3, 0.65, 0.4, 0.2, 0.08, 1, 0.7] -/// float * tmp; // empty array of 8 elements +/// constexpr unsigned int begin_bit = 0; +/// constexpr unsigned int end_bit = 96; +/// size_t input_size; // e.g., 8 +/// custom_type * input; // e.g., [{2, 0.6}, {-3, 0.3}, {2, 0.65}, {0, 0.4}, {0, 0.2}, {11, 0.08}, {11, 1}, {-1, 0.7}] +/// custom_type * tmp; // empty array of 8 elements /// // Create double-buffer -/// rocprim::double_buffer keys(input, tmp); +/// rocprim::double_buffer keys(input, tmp); /// /// size_t temporary_storage_size_bytes; /// void * temporary_storage_ptr = nullptr; /// // Get required size of the temporary storage /// rocprim::radix_sort_keys( /// temporary_storage_ptr, temporary_storage_size_bytes, -/// keys, input_size +/// keys, input_size, custom_type_decomposer{}, begin_bit, end_bit /// ); /// /// // allocate temporary storage @@ -1248,37 +1351,41 @@ hipError_t radix_sort_pairs_desc(void * temporary_storage, /// // perform sort /// rocprim::radix_sort_keys( /// temporary_storage_ptr, 
temporary_storage_size_bytes, -/// keys, input_size +/// keys, input_size, custom_type_decomposer{}, begin_bit, end_bit /// ); -/// // keys.current(): [0.08, 0.2, 0.3, 0.4, 0.6, 0.65, 0.7, 1] +/// // keys.current(): [{-3, 0.3}, {-1, 0.7}, {0, 0.2}, {0, 0.4}, {2, 0.6}, {2, 0.65}, {11, 0.08}, {11, 1.0}] /// \endcode /// \endparblock -template< - class Config = default_config, - class Key, - class Size -> -inline -hipError_t radix_sort_keys(void * temporary_storage, - size_t& storage_size, - double_buffer& keys, - Size size, - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key), - hipStream_t stream = 0, - bool debug_synchronous = false) +template +auto radix_sort_keys(void* temporary_storage, + size_t& storage_size, + double_buffer& keys, + Size size, + Decomposer decomposer, + unsigned int begin_bit, + unsigned int end_bit, + hipStream_t stream = 0, + bool debug_synchronous = false) + -> std::enable_if_t::value, hipError_t> { static_assert(std::is_integral::value, "Size must be an integral type."); - empty_type * values = nullptr; - bool is_result_in_output; - hipError_t error = detail::radix_sort_impl( - temporary_storage, storage_size, - keys.current(), keys.current(), keys.alternate(), - values, values, values, - size, is_result_in_output, - begin_bit, end_bit, - stream, debug_synchronous - ); + empty_type* values = nullptr; + bool is_result_in_output; + hipError_t error = detail::radix_sort_impl(temporary_storage, + storage_size, + keys.current(), + keys.current(), + keys.alternate(), + values, + values, + values, + size, + is_result_in_output, + decomposer, + begin_bit, + end_bit, + stream, + debug_synchronous); if(temporary_storage != nullptr && error == hipSuccess && is_result_in_output) { keys.swap(); @@ -1286,10 +1393,10 @@ hipError_t radix_sort_keys(void * temporary_storage, return error; } -/// \brief Parallel descending radix sort primitive for device level. +/// \brief Parallel ascending radix sort primitive for device level. 
/// -/// \p radix_sort_keys_desc function performs a device-wide radix sort -/// of keys. Function sorts input keys in descending order. +/// \p radix_sort_keys function performs a device-wide radix sort +/// of keys. Function sorts input keys in ascending order. /// /// \par Overview /// * The contents of both buffers of \p keys may be altered by the sorting function. @@ -1297,35 +1404,33 @@ hipError_t radix_sort_keys(void * temporary_storage, /// * The function will update \p current() of \p keys to point to the buffer /// that contains the output range. /// * Returns the required size of \p temporary_storage in \p storage_size -/// if \p temporary_storage in a null pointer. +/// if \p temporary_storage is a null pointer. /// * The function requires small \p temporary_storage as it does not need /// a temporary buffer of \p size elements. -/// * \p Key type must be an arithmetic type (that is, an integral type or a floating-point -/// type). -/// * Buffers of \p keys must have at least \p size elements. -/// * If \p Key is an integer type and the range of keys is known in advance, the performance -/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range -/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. +/// * \p Key type (the element type of the \p keys double-buffer) can be any +/// trivially copyable type. +/// * \p decomposer must be a functor that implements `operator()(Key&) const`. This operator +/// must return a \p rocprim::tuple that contains one or more references to values of arithmetic types. +/// These references must point to member variables of `Key`; however, not every member variable has to be +/// exposed this way. +/// * Buffers of \p keys must have at least \p size elements. /// -/// \tparam Config - [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it.
-/// \tparam Key - key type. Must be an integral type or a floating-point type. -/// \tparam Size - integral type that represents the problem size. +/// \tparam Config [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. +/// \tparam Key key type. Can be any trivially copyable type. +/// \tparam Size integral type that represents the problem size. +/// \tparam Decomposer The type of the decomposer functor. /// -/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to /// \p storage_size and function returns without performing the sort operation. -/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. -/// \param [in,out] keys - reference to the double-buffer of keys, its \p current() +/// \param [in,out] storage_size reference to a size (in bytes) of \p temporary_storage. +/// \param [in,out] keys reference to the double-buffer of keys, its \p current() /// contains the input range and will be updated to point to the output range. -/// \param [in] size - number of element in the input range. -/// \param [in] begin_bit - [optional] index of the first (least significant) bit used in -/// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. -/// Non-default value not supported for floating-point key-types. -/// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in -/// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default -/// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. -/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream).
-/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// \param [in] size number of element in the input range. +/// \param [in] decomposer decomposer functor that produces a tuple of references from the +/// input key type. +/// \param [in] stream [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel /// launch is forced in order to check for errors. Default value is \p false. /// /// \returns \p hipSuccess (\p 0) after successful sort; otherwise a HIP runtime error of @@ -1333,64 +1438,81 @@ hipError_t radix_sort_keys(void * temporary_storage, /// /// \par Example /// \parblock -/// In this example a device-level descending radix sort is performed on an array of -/// integer values. +/// In this example a device-level ascending radix sort is performed on an array of +/// values of a custom type, using a custom decomposer. /// /// \code{.cpp} /// #include /// +/// struct custom_type +/// { +/// int i; +/// double d; +/// }; +/// +/// struct custom_type_decomposer +/// { +/// rocprim::tuple operator()(custom_type& key) const +/// { +/// return rocprim::tuple(key.i, key.d); +/// } +/// }; +/// /// // Prepare input and tmp (declare pointers, allocate device memory etc.) 
-/// size_t input_size; // e.g., 8 -/// int * input; // e.g., [6, 3, 5, 4, 2, 8, 1, 7] -/// int * tmp; // empty array of 8 elements +/// size_t input_size; // e.g., 8 +/// custom_type * input; // e.g., [{2, 0.6}, {-3, 0.3}, {2, 0.65}, {0, 0.4}, {0, 0.2}, {11, 0.08}, {11, 1}, {-1, 0.7}] +/// custom_type * tmp; // empty array of 8 elements /// // Create double-buffer -/// rocprim::double_buffer keys(input, tmp); +/// rocprim::double_buffer keys(input, tmp); /// /// size_t temporary_storage_size_bytes; /// void * temporary_storage_ptr = nullptr; /// // Get required size of the temporary storage -/// rocprim::radix_sort_keys_desc( +/// rocprim::radix_sort_keys( /// temporary_storage_ptr, temporary_storage_size_bytes, -/// keys, input_size +/// keys, input_size, custom_type_decomposer{} /// ); /// /// // allocate temporary storage /// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); /// /// // perform sort -/// rocprim::radix_sort_keys_desc( +/// rocprim::radix_sort_keys( /// temporary_storage_ptr, temporary_storage_size_bytes, -/// keys, input_size +/// keys, input_size, custom_type_decomposer{} /// ); -/// // keys.current(): [8, 7, 6, 5, 4, 3, 2, 1] +/// // keys.current(): [{-3, 0.3}, {-1, 0.7}, {0, 0.2}, {0, 0.4}, {2, 0.6}, {2, 0.65}, {11, 0.08}, {11, 1.0}] /// \endcode /// \endparblock -template< - class Config = default_config, - class Key, - class Size -> -inline -hipError_t radix_sort_keys_desc(void * temporary_storage, - size_t& storage_size, - double_buffer& keys, - Size size, - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key), - hipStream_t stream = 0, - bool debug_synchronous = false) +template +auto radix_sort_keys(void* temporary_storage, + size_t& storage_size, + double_buffer& keys, + Size size, + Decomposer decomposer, + hipStream_t stream = 0, + bool debug_synchronous = false) + -> std::enable_if_t::value, hipError_t> { static_assert(std::is_integral::value, "Size must be an integral type."); - empty_type * values = 
nullptr; - bool is_result_in_output; - hipError_t error = detail::radix_sort_impl( - temporary_storage, storage_size, - keys.current(), keys.current(), keys.alternate(), - values, values, values, - size, is_result_in_output, - begin_bit, end_bit, - stream, debug_synchronous - ); + empty_type* values = nullptr; + bool is_result_in_output; + hipError_t error = detail::radix_sort_impl( + temporary_storage, + storage_size, + keys.current(), + keys.current(), + keys.alternate(), + values, + values, + values, + size, + is_result_in_output, + decomposer, + 0, + detail::decomposer_max_bits::value, + stream, + debug_synchronous); if(temporary_storage != nullptr && error == hipSuccess && is_result_in_output) { keys.swap(); @@ -1398,49 +1520,1360 @@ hipError_t radix_sort_keys_desc(void * temporary_storage, return error; } -/// \brief Parallel ascending radix sort-by-key primitive for device level. +/// \brief Parallel descending radix sort primitive for device level. /// -/// \p radix_sort_pairs_desc function performs a device-wide radix sort -/// of (key, value) pairs. Function sorts input pairs in ascending order of keys. +/// \p radix_sort_keys_desc function performs a device-wide radix sort +/// of keys. Function sorts input keys in descending order. /// /// \par Overview -/// * The contents of both buffers of \p keys and \p values may be altered by the sorting function. -/// * \p current() of \p keys and \p values are used as the input. -/// * The function will update \p current() of \p keys and \p values to point to buffers -/// that contains the output range. +/// * The contents of the inputs are not altered by the sorting function. /// * Returns the required size of \p temporary_storage in \p storage_size -/// if \p temporary_storage in a null pointer. -/// * The function requires small \p temporary_storage as it does not need -/// a temporary buffer of \p size elements. 
-/// * \p Key type must be an arithmetic type (that is, an integral type or a floating-point -/// type). -/// * Buffers of \p keys must have at least \p size elements. +/// if \p temporary_storage is a null pointer. +/// * \p Key type (a \p value_type of \p KeysInputIterator and \p KeysOutputIterator) must be +/// an arithmetic type (that is, an integral type or a floating-point type). +/// * Ranges specified by \p keys_input and \p keys_output must have at least \p size elements. /// * If \p Key is an integer type and the range of keys is known in advance, the performance /// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range /// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. /// -/// \tparam Config - [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. -/// \tparam Key - key type. Must be an integral type or a floating-point type. -/// \tparam Value - value type. -/// \tparam Size - integral type that represents the problem size. +/// \tparam Config [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. +/// \tparam KeysInputIterator random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam KeysOutputIterator random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam Size integral type that represents the problem size. /// -/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to /// \p storage_size and function returns without performing the sort operation. 
-/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. -/// \param [in,out] keys - reference to the double-buffer of keys, its \p current() -/// contains the input range and will be updated to point to the output range. -/// \param [in,out] values - reference to the double-buffer of values, its \p current() -/// contains the input range and will be updated to point to the output range. -/// \param [in] size - number of element in the input range. -/// \param [in] begin_bit - [optional] index of the first (least significant) bit used in +/// \param [in,out] storage_size reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input pointer to the first element in the range to sort. +/// \param [out] keys_output pointer to the first element in the output range. +/// \param [in] size number of element in the input range. +/// \param [in] begin_bit [optional] index of the first (least significant) bit used in /// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. /// Non-default value not supported for floating-point key-types. -/// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in +/// \param [in] end_bit [optional] past-the-end index (most significant) bit used in /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default /// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. -/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). -/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// \param [in] stream [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. 
+/// +/// \returns \p hipSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p hipError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level descending radix sort is performed on an array of +/// integer values. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// int * input; // e.g., [6, 3, 5, 4, 2, 8, 1, 7] +/// int * output; // empty array of 8 elements +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::radix_sort_keys_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, input_size +/// ); +/// +/// // allocate temporary storage +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::radix_sort_keys_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, input_size +/// ); +/// // keys_output: [8, 7, 6, 5, 4, 3, 2, 1] +/// \endcode +/// \endparblock +template::value_type> + +hipError_t radix_sort_keys_desc(void* temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + Size size, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + hipStream_t stream = 0, + bool debug_synchronous = false) +{ + static_assert(std::is_integral::value, "Size must be an integral type."); + empty_type* values = nullptr; + bool ignored; + return detail::radix_sort_impl(temporary_storage, + storage_size, + keys_input, + nullptr, + keys_output, + values, + nullptr, + values, + size, + ignored, + identity_decomposer{}, + begin_bit, + end_bit, + stream, + debug_synchronous); +} + +/// \brief Parallel descending radix sort primitive for device level. +/// +/// \p radix_sort_keys_desc function performs a device-wide radix sort +/// of keys. 
Function sorts input keys in descending order. +/// +/// \par Overview +/// * The contents of both buffers of \p keys may be altered by the sorting function. +/// * \p current() of \p keys is used as the input. +/// * The function will update \p current() of \p keys to point to the buffer +/// that contains the output range. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage is a null pointer. +/// * The function requires small \p temporary_storage as it does not need +/// a temporary buffer of \p size elements. +/// * \p Key type must be an arithmetic type (that is, an integral type or a floating-point +/// type). +/// * Buffers of \p keys must have at least \p size elements. +/// * If \p Key is an integer type and the range of keys is known in advance, the performance +/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range +/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. +/// +/// \tparam Config [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. +/// \tparam Key key type. Must be an integral type or a floating-point type. +/// \tparam Size integral type that represents the problem size. +/// +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size reference to a size (in bytes) of \p temporary_storage. +/// \param [in,out] keys reference to the double-buffer of keys, its \p current() +/// contains the input range and will be updated to point to the output range. +/// \param [in] size number of element in the input range. +/// \param [in] begin_bit [optional] index of the first (least significant) bit used in +/// key comparison. 
Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. +/// Non-default value not supported for floating-point key-types. +/// \param [in] end_bit [optional] past-the-end index (most significant) bit used in +/// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default +/// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. +/// \param [in] stream [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p hipSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p hipError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level descending radix sort is performed on an array of +/// integer values. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and tmp (declare pointers, allocate device memory etc.) 
+/// size_t input_size; // e.g., 8 +/// int * input; // e.g., [6, 3, 5, 4, 2, 8, 1, 7] +/// int * tmp; // empty array of 8 elements +/// // Create double-buffer +/// rocprim::double_buffer<int> keys(input, tmp); +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::radix_sort_keys_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, input_size +/// ); +/// +/// // allocate temporary storage +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::radix_sort_keys_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, input_size +/// ); +/// // keys.current(): [8, 7, 6, 5, 4, 3, 2, 1] +/// \endcode +/// \endparblock +template +hipError_t radix_sort_keys_desc(void* temporary_storage, + size_t& storage_size, + double_buffer& keys, + Size size, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + hipStream_t stream = 0, + bool debug_synchronous = false) +{ + static_assert(std::is_integral::value, "Size must be an integral type."); + empty_type* values = nullptr; + bool is_result_in_output; + hipError_t error = detail::radix_sort_impl(temporary_storage, + storage_size, + keys.current(), + keys.current(), + keys.alternate(), + values, + values, + values, + size, + is_result_in_output, + identity_decomposer{}, + begin_bit, + end_bit, + stream, + debug_synchronous); + if(temporary_storage != nullptr && error == hipSuccess && is_result_in_output) + { + keys.swap(); + } + return error; +} + +/// \brief Parallel descending radix sort primitive for device level. +/// +/// \p radix_sort_keys_desc function performs a device-wide radix sort +/// of keys. Function sorts input keys in descending order. +/// +/// \par Overview +/// * The contents of the inputs are not altered by the sorting function. 
+/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage is a null pointer. +/// * \p Key type (a \p value_type of \p KeysInputIterator and \p KeysOutputIterator) can be any +/// trivially copyable type. +/// * \p decomposer must be a functor that implements `operator()(Key&) const`. This operator +/// must return a \p rocprim::tuple that contains one or more reference to value(s) of arithmetic types. +/// These references must point to member variables of `Key`, however not every member variable has to be +/// exposed this way. +/// * Ranges specified by \p keys_input and \p keys_output must have at least \p size elements. +/// * \p begin_bit and \p end_bit can be provided to control the radix range that is considered in the +/// decomposed tuple. For example, if the decomposer returns `rocprim::tuple`, +/// `begin_bit==6` and `end_bit==12`, then the 2 MSBs of the `uint8_t` value and the 4 LSBs of the +/// `int16_t` value are considered for sorting. The range specified by \p begin_bit and \p end_bit +/// must be valid with regards to the sizes of the return tuple's elements. +/// +/// \tparam Config [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. +/// \tparam KeysInputIterator random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam KeysOutputIterator random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam Size integral type that represents the problem size. +/// \tparam Key The value type of the input and output iterators. +/// \tparam Decomposer The type of the decomposer functor. +/// +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. 
When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input pointer to the first element in the range to sort. +/// \param [out] keys_output pointer to the first element in the output range. +/// \param [in] size number of element in the input range. +/// \param [in] decomposer decomposer functor that produces a tuple of references from the +/// input key type. +/// \param [in] begin_bit index of the first (least significant) bit used in +/// key comparison. +/// \param [in] end_bit past-the-end index (most significant) bit used in +/// key comparison. +/// \param [in] stream [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p hipSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p hipError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level descending radix sort is performed on an array of +/// values of a custom type, using a custom decomposer. +/// +/// \code{.cpp} +/// #include +/// +/// struct custom_type +/// { +/// int i; +/// double d; +/// }; +/// +/// struct custom_type_decomposer +/// { +/// rocprim::tuple operator()(custom_type& key) const +/// { +/// return rocprim::tuple(key.i, key.d); +/// } +/// }; +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) 
+/// constexpr unsigned int begin_bit = 0; +/// constexpr unsigned int end_bit = 96; // all bits of the tuple, assuming a 32-bit int and a 64-bit double +/// size_t input_size; // e.g., 8 +/// custom_type * input; // e.g., [{2, 0.6}, {-3, 0.3}, {2, 0.65}, {0, 0.4}, {0, 0.2}, {11, 0.08}, {11, 1}, {-1, 0.7}] +/// custom_type * output; // empty array of 8 elements +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::radix_sort_keys_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, input_size, custom_type_decomposer{}, begin_bit, end_bit +/// ); +/// +/// // allocate temporary storage +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::radix_sort_keys_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, input_size, custom_type_decomposer{}, begin_bit, end_bit +/// ); +/// // output: [{11, 1.0}, {11, 0.08}, {2, 0.65}, {2, 0.6}, {0, 0.4}, {0, 0.2}, {-1, 0.7}, {-3, 0.3}] +/// \endcode +/// \endparblock +template::value_type, + class Decomposer> +auto radix_sort_keys_desc(void* temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + Size size, + Decomposer decomposer, + unsigned int begin_bit, + unsigned int end_bit, + hipStream_t stream = 0, + bool debug_synchronous = false) + -> std::enable_if_t::value, hipError_t> +{ + static_assert(std::is_integral::value, "Size must be an integral type."); + empty_type* values = nullptr; + bool ignored; + return detail::radix_sort_impl(temporary_storage, + storage_size, + keys_input, + nullptr, + keys_output, + values, + nullptr, + values, + size, + ignored, + decomposer, + begin_bit, + end_bit, + stream, + debug_synchronous); +} + +/// \brief Parallel descending radix sort primitive for device level. 
+/// +/// \p radix_sort_keys_desc function performs a device-wide radix sort +/// of keys. 
Function sorts input keys in descending order. +/// +/// \par Overview +/// * The contents of the inputs are not altered by the sorting function. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage is a null pointer. +/// * \p Key type (a \p value_type of \p KeysInputIterator and \p KeysOutputIterator) can be any +/// trivially copyable type. +/// * \p decomposer must be a functor that implements `operator()(Key&) const`. This operator +/// must return a \p rocprim::tuple that contains one or more reference to value(s) of arithmetic types. +/// These references must point to member variables of `Key`, however not every member variable has to be +/// exposed this way. +/// * Ranges specified by \p keys_input and \p keys_output must have at least \p size elements. +/// +/// \tparam Config [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. +/// \tparam KeysInputIterator random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam KeysOutputIterator random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam Size integral type that represents the problem size. +/// \tparam Key The value type of the input and output iterators. +/// \tparam Decomposer The type of the decomposer functor. +/// +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input pointer to the first element in the range to sort. 
+/// \param [out] keys_output pointer to the first element in the output range. +/// \param [in] size number of element in the input range. +/// \param [in] decomposer decomposer functor that produces a tuple of references from the +/// input key type. +/// \param [in] stream [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p hipSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p hipError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level descending radix sort is performed on an array of +/// values of a custom type, using a custom decomposer. +/// +/// \code{.cpp} +/// #include +/// +/// struct custom_type +/// { +/// int i; +/// double d; +/// }; +/// +/// struct custom_type_decomposer +/// { +/// rocprim::tuple operator()(custom_type& key) const +/// { +/// return rocprim::tuple(key.i, key.d); +/// } +/// }; +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) 
+/// size_t input_size; // e.g., 8 +/// custom_type * input; // e.g., [{2, 0.6}, {-3, 0.3}, {2, 0.65}, {0, 0.4}, {0, 0.2}, {11, 0.08}, {11, 1}, {-1, 0.7}] +/// custom_type * output; // empty array of 8 elements +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::radix_sort_keys_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, input_size, custom_type_decomposer{} +/// ); +/// +/// // allocate temporary storage +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::radix_sort_keys_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, input_size, custom_type_decomposer{} +/// ); +/// // output: [{11, 1.0}, {11, 0.08}, {2, 0.65}, {2, 0.6}, {0, 0.4}, {0, 0.2}, {-1, 0.7}, {-3, 0.3}] +/// \endcode +/// \endparblock +template::value_type, + class Decomposer> +auto radix_sort_keys_desc(void* temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + Size size, + Decomposer decomposer, + hipStream_t stream = 0, + bool debug_synchronous = false) + -> std::enable_if_t::value, hipError_t> +{ + static_assert(std::is_integral::value, "Size must be an integral type."); + empty_type* values = nullptr; + bool ignored; + return detail::radix_sort_impl( + temporary_storage, + storage_size, + keys_input, + nullptr, + keys_output, + values, + nullptr, + values, + size, + ignored, + decomposer, + 0, + detail::decomposer_max_bits::value, + stream, + debug_synchronous); +} + +/// \brief Parallel descending radix sort primitive for device level. +/// +/// \p radix_sort_keys_desc function performs a device-wide radix sort +/// of keys. Function sorts input keys in descending order. +/// +/// \par Overview +/// * The contents of both buffers of \p keys may be altered by the sorting function. 
+/// * \p current() of \p keys is used as the input. +/// * The function will update \p current() of \p keys to point to the buffer +/// that contains the output range. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage is a null pointer. +/// * The function requires small \p temporary_storage as it does not need +/// a temporary buffer of \p size elements. +/// * \p Key type (the element type of the \p keys double-buffer) can be any +/// trivially copyable type. +/// * \p decomposer must be a functor that implements `operator()(Key&) const`. This operator +/// must return a \p rocprim::tuple that contains one or more references to values of arithmetic types. +/// These references must point to member variables of `Key`, however not every member variable has to be +/// exposed this way. +/// * Buffers of \p keys must have at least \p size elements. +/// * \p begin_bit and \p end_bit can be provided to control the radix range that is considered in the +/// decomposed tuple. For example, if the decomposer returns `rocprim::tuple<int16_t&, uint8_t&>`, +/// `begin_bit==6` and `end_bit==12`, then the 2 MSBs of the `uint8_t` value and the 4 LSBs of the +/// `int16_t` value are considered for sorting. The range specified by \p begin_bit and \p end_bit +/// must be valid with regard to the sizes of the return tuple's elements. +/// +/// \tparam Config [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. +/// \tparam Key key type. Can be any trivially copyable type that \p Decomposer can decompose. +/// \tparam Size integral type that represents the problem size. +/// \tparam Decomposer The type of the decomposer functor. +/// +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. 
When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and the function returns without performing the sort operation. +/// \param [in,out] storage_size reference to a size (in bytes) of \p temporary_storage. +/// \param [in,out] keys reference to the double-buffer of keys, its \p current() +/// contains the input range and will be updated to point to the output range. +/// \param [in] size number of elements in the input range. +/// \param [in] decomposer decomposer functor that produces a tuple of references from the +/// input key type. +/// \param [in] begin_bit index of the first (least significant) bit used in +/// key comparison. Has no default value in this overload and must be passed explicitly. +/// \param [in] end_bit past-the-end index (most significant) bit used in +/// key comparison. Has no default value in this overload and must be passed explicitly. +/// \param [in] stream [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p hipSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p hipError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level descending radix sort is performed on an array of +/// values of a custom type, using a custom decomposer. +/// +/// \code{.cpp} +/// #include <rocprim/rocprim.hpp> +/// +/// struct custom_type +/// { +/// int i; +/// double d; +/// }; +/// +/// struct custom_type_decomposer +/// { +/// rocprim::tuple<int&, double&> operator()(custom_type& key) const +/// { +/// return rocprim::tuple<int&, double&>(key.i, key.d); +/// } +/// }; +/// +/// // Prepare input and tmp (declare pointers, allocate device memory etc.) 
+/// constexpr unsigned int begin_bit = 0; +/// constexpr unsigned int end_bit = 96; // all bits of the tuple, assuming a 32-bit int and a 64-bit double +/// size_t input_size; // e.g., 8 +/// custom_type * input; // e.g., [{2, 0.6}, {-3, 0.3}, {2, 0.65}, {0, 0.4}, {0, 0.2}, {11, 0.08}, {11, 1}, {-1, 0.7}] +/// custom_type * tmp; // empty array of 8 elements +/// // Create double-buffer +/// rocprim::double_buffer<custom_type> keys(input, tmp); +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::radix_sort_keys_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, input_size, custom_type_decomposer{}, begin_bit, end_bit +/// ); +/// +/// // allocate temporary storage +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::radix_sort_keys_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, input_size, custom_type_decomposer{}, begin_bit, end_bit +/// ); +/// // keys.current(): [{11, 1.0}, {11, 0.08}, {2, 0.65}, {2, 0.6}, {0, 0.4}, {0, 0.2}, {-1, 0.7}, {-3, 0.3}] +/// \endcode +/// \endparblock +template +auto radix_sort_keys_desc(void* temporary_storage, + size_t& storage_size, + double_buffer& keys, + Size size, + Decomposer decomposer, + unsigned int begin_bit, + unsigned int end_bit, + hipStream_t stream = 0, + bool debug_synchronous = false) + -> std::enable_if_t::value, hipError_t> +{ + static_assert(std::is_integral::value, "Size must be an integral type."); + empty_type* values = nullptr; + bool is_result_in_output; + hipError_t error = detail::radix_sort_impl(temporary_storage, + storage_size, + keys.current(), + keys.current(), + keys.alternate(), + values, + values, + values, + size, + is_result_in_output, + decomposer, + begin_bit, + end_bit, + stream, + debug_synchronous); + if(temporary_storage != nullptr && error == hipSuccess && is_result_in_output) + { + 
keys.swap(); + } + return error; +} + +/// \brief Parallel descending radix sort primitive for 
device level. +/// +/// \p radix_sort_keys_desc function performs a device-wide radix sort +/// of keys. Function sorts input keys in descending order. +/// +/// \par Overview +/// * The contents of both buffers of \p keys may be altered by the sorting function. +/// * \p current() of \p keys is used as the input. +/// * The function will update \p current() of \p keys to point to the buffer +/// that contains the output range. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage is a null pointer. +/// * The function requires small \p temporary_storage as it does not need +/// a temporary buffer of \p size elements. +/// * \p Key type (the element type of the \p keys double-buffer) can be any +/// trivially copyable type. +/// * \p decomposer must be a functor that implements `operator()(Key&) const`. This operator +/// must return a \p rocprim::tuple that contains one or more references to values of arithmetic types. +/// These references must point to member variables of `Key`, however not every member variable has to be +/// exposed this way. +/// * Buffers of \p keys must have at least \p size elements. +/// +/// \tparam Config [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. +/// \tparam Key key type. Can be any trivially copyable type that \p Decomposer can decompose. +/// \tparam Size integral type that represents the problem size. +/// \tparam Decomposer The type of the decomposer functor. +/// +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and the function returns without performing the sort operation. +/// \param [in,out] storage_size reference to a size (in bytes) of \p temporary_storage. 
+/// \param [in,out] keys reference to the double-buffer of keys, its \p current() +/// contains the input range and will be updated to point to the output range. +/// \param [in] size number of elements in the input range. +/// \param [in] decomposer decomposer functor that produces a tuple of references from the +/// input key type. +/// \param [in] stream [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p hipSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p hipError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level descending radix sort is performed on an array of +/// values of a custom type, using a custom decomposer. +/// +/// \code{.cpp} +/// #include <rocprim/rocprim.hpp> +/// +/// struct custom_type +/// { +/// int i; +/// double d; +/// }; +/// +/// struct custom_type_decomposer +/// { +/// rocprim::tuple<int&, double&> operator()(custom_type& key) const +/// { +/// return rocprim::tuple<int&, double&>(key.i, key.d); +/// } +/// }; +/// +/// // Prepare input and tmp (declare pointers, allocate device memory etc.) 
+/// size_t input_size; // e.g., 8 +/// custom_type * input; // e.g., [{2, 0.6}, {-3, 0.3}, {2, 0.65}, {0, 0.4}, {0, 0.2}, {11, 0.08}, {11, 1}, {-1, 0.7}] +/// custom_type * tmp; // empty array of 8 elements +/// // Create double-buffer +/// rocprim::double_buffer<custom_type> keys(input, tmp); +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::radix_sort_keys_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, input_size, custom_type_decomposer{} +/// ); +/// +/// // allocate temporary storage +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::radix_sort_keys_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, input_size, custom_type_decomposer{} +/// ); +/// // keys.current(): [{11, 1.0}, {11, 0.08}, {2, 0.65}, {2, 0.6}, {0, 0.4}, {0, 0.2}, {-1, 0.7}, {-3, 0.3}] +/// \endcode +/// \endparblock +template +auto radix_sort_keys_desc(void* temporary_storage, + size_t& storage_size, + double_buffer& keys, + Size size, + Decomposer decomposer, + hipStream_t stream = 0, + bool debug_synchronous = false) + -> std::enable_if_t::value, hipError_t> +{ + static_assert(std::is_integral::value, "Size must be an integral type."); + empty_type* values = nullptr; + bool is_result_in_output; + hipError_t error + = detail::radix_sort_impl(temporary_storage, + storage_size, + keys.current(), + keys.current(), + keys.alternate(), + values, + values, + values, + size, + is_result_in_output, + decomposer, + 0, + detail::decomposer_max_bits::value, + stream, + debug_synchronous); + if(temporary_storage != nullptr && error == hipSuccess && is_result_in_output) + { + keys.swap(); + } + return error; +} + +/// \brief Parallel ascending radix sort-by-key primitive for device level. +/// +/// \p radix_sort_pairs function performs a device-wide radix sort +/// of (key, value) pairs. 
Function sorts input pairs in ascending order of keys. +/// +/// \par Overview +/// * The contents of the inputs are not altered by the sorting function. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage is a null pointer. +/// * \p Key type (a \p value_type of \p KeysInputIterator and \p KeysOutputIterator) must be +/// an arithmetic type (that is, an integral type or a floating-point type). +/// * Ranges specified by \p keys_input, \p keys_output, \p values_input and \p values_output must +/// have at least \p size elements. +/// * If \p Key is an integer type and the range of keys is known in advance, the performance +/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range +/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. +/// +/// \tparam Config [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. +/// \tparam KeysInputIterator random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam KeysOutputIterator random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam ValuesInputIterator random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam ValuesOutputIterator random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam Size integral type that represents the problem size. +/// +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. 
When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input pointer to the first element in the range to sort. +/// \param [out] keys_output pointer to the first element in the output range. +/// \param [in] values_input pointer to the first element in the range to sort. +/// \param [out] values_output pointer to the first element in the output range. +/// \param [in] size number of element in the input range. +/// \param [in] begin_bit [optional] index of the first (least significant) bit used in +/// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. +/// Non-default value not supported for floating-point key-types. +/// \param [in] end_bit [optional] past-the-end index (most significant) bit used in +/// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default +/// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. +/// \param [in] stream [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p hipSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p hipError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level ascending radix sort is performed where input keys are +/// represented by an array of unsigned integers and input values by an array of doubles. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) 
+/// size_t input_size; // e.g., 8 +/// unsigned int * keys_input; // e.g., [ 6, 3, 5, 4, 1, 8, 1, 7] +/// double * values_input; // e.g., [-5, 2, -4, 3, -1, -8, -2, 7] +/// unsigned int * keys_output; // empty array of 8 elements +/// double * values_output; // empty array of 8 elements +/// +/// // Keys are in range [0; 8], so the comparison can be limited to the bits at indexes +/// // 0, 1, 2, 3, and 4. To do this, begin_bit is set to 0 and end_bit +/// // is set to 5. +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::radix_sort_pairs( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, values_input, values_output, +/// input_size, 0, 5 +/// ); +/// +/// // allocate temporary storage +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::radix_sort_pairs( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, values_input, values_output, +/// input_size, 0, 5 +/// ); +/// // keys_output: [ 1, 1, 3, 4, 5, 6, 7, 8] +/// // values_output: [-1, -2, 2, 3, -4, -5, 7, -8] +/// \endcode +/// \endparblock +template::value_type> +hipError_t radix_sort_pairs(void* temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + Size size, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + hipStream_t stream = 0, + bool debug_synchronous = false) +{ + static_assert(std::is_integral::value, "Size must be an integral type."); + bool ignored; + return detail::radix_sort_impl(temporary_storage, + storage_size, + keys_input, + nullptr, + keys_output, + values_input, + nullptr, + values_output, + size, + ignored, + identity_decomposer{}, + begin_bit, + end_bit, + stream, + debug_synchronous); +} + +/// \brief 
Parallel ascending radix sort-by-key primitive for device level. +/// +/// \p radix_sort_pairs function performs a device-wide radix sort +/// of (key, value) pairs. Function sorts input pairs in ascending order of keys. +/// +/// \par Overview +/// * The contents of both buffers of \p keys and \p values may be altered by the sorting function. +/// * \p current() of \p keys and \p values are used as the input. +/// * The function will update \p current() of \p keys and \p values to point to the buffers +/// that contain the output ranges. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage is a null pointer. +/// * The function requires small \p temporary_storage as it does not need +/// a temporary buffer of \p size elements. +/// * \p Key type must be an arithmetic type (that is, an integral type or a floating-point +/// type). +/// * Buffers of \p keys and \p values must have at least \p size elements. +/// * If \p Key is an integer type and the range of keys is known in advance, the performance +/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range +/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. +/// +/// \tparam Config [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. +/// \tparam Key key type. Must be an integral type or a floating-point type. +/// \tparam Value value type. +/// \tparam Size integral type that represents the problem size. +/// +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and the function returns without performing the sort operation. +/// \param [in,out] storage_size reference to a size (in bytes) of \p temporary_storage. 
+/// \param [in,out] keys reference to the double-buffer of keys, its \p current() +/// contains the input range and will be updated to point to the output range. +/// \param [in,out] values reference to the double-buffer of values, its \p current() +/// contains the input range and will be updated to point to the output range. +/// \param [in] size number of elements in the input range. +/// \param [in] begin_bit [optional] index of the first (least significant) bit used in +/// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. +/// Non-default value not supported for floating-point key-types. +/// \param [in] end_bit [optional] past-the-end index of the last (most significant) bit used in +/// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default +/// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. +/// \param [in] stream [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p hipSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p hipError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level ascending radix sort is performed where input keys are +/// represented by an array of unsigned integers and input values by an array of doubles. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and tmp (declare pointers, allocate device memory etc.)
+/// size_t input_size; // e.g., 8 +/// unsigned int * keys_input; // e.g., [ 6, 3, 5, 4, 1, 8, 1, 7] +/// double * values_input; // e.g., [-5, 2, -4, 3, -1, -8, -2, 7] +/// unsigned int * keys_tmp; // empty array of 8 elements +/// double* values_tmp; // empty array of 8 elements +/// // Create double-buffers +/// rocprim::double_buffer keys(keys_input, keys_tmp); +/// rocprim::double_buffer values(values_input, values_tmp); +/// +/// // Keys are in range [0; 8], so we can limit compared bit to bits on indexes +/// // 0, 1, 2, 3, and 4. In order to do this begin_bit is set to 0 and end_bit +/// // is set to 5. +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::radix_sort_pairs( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, values, input_size, +/// 0, 5 +/// ); +/// +/// // allocate temporary storage +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::radix_sort_pairs( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, values, input_size, +/// 0, 5 +/// ); +/// // keys.current(): [ 1, 1, 3, 4, 5, 6, 7, 8] +/// // values.current(): [-1, -2, 2, 3, -4, -5, 7, -8] +/// \endcode +/// \endparblock +template +hipError_t radix_sort_pairs(void* temporary_storage, + size_t& storage_size, + double_buffer& keys, + double_buffer& values, + Size size, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + hipStream_t stream = 0, + bool debug_synchronous = false) +{ + static_assert(std::is_integral::value, "Size must be an integral type."); + bool is_result_in_output; + hipError_t error = detail::radix_sort_impl(temporary_storage, + storage_size, + keys.current(), + keys.current(), + keys.alternate(), + values.current(), + values.current(), + values.alternate(), + size, + is_result_in_output, + identity_decomposer{}, + begin_bit, + end_bit, + stream, + 
debug_synchronous); + if(temporary_storage != nullptr && error == hipSuccess && is_result_in_output) + { + keys.swap(); + values.swap(); + } + return error; +} + +/// \brief Parallel ascending radix sort-by-key primitive for device level. +/// +/// \p radix_sort_pairs function performs a device-wide radix sort +/// of (key, value) pairs. Function sorts input pairs in ascending order of keys. +/// +/// \par Overview +/// * The contents of the inputs are not altered by the sorting function. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage is a null pointer. +/// * \p Key type (a \p value_type of \p KeysInputIterator and \p KeysOutputIterator) can be any +/// trivially copyable type. +/// * \p decomposer must be a functor that implements `operator()(Key&) const`. This operator +/// must return a \p rocprim::tuple that contains one or more reference to value(s) of arithmetic types. +/// These references must point to member variables of `Key`, however not every member variable has to be +/// exposed this way. +/// * Ranges specified by \p keys_input and \p keys_output must have at least \p size elements. +/// * \p begin_bit and \p end_bit can be provided to control the radix range that is considered in the +/// decomposed tuple. For example, if the decomposer returns `rocprim::tuple`, +/// `begin_bit==6` and `end_bit==12`, then the 2 MSBs of the `uint8_t` value and the 4 LSBs of the +/// `int16_t` value are considered for sorting. The range specified by \p begin_bit and \p end_bit +/// must be valid with regards to the sizes of the return tuple's elements. +/// +/// \tparam Config [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. +/// \tparam KeysInputIterator random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. 
+/// \tparam KeysOutputIterator random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam ValuesInputIterator random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam ValuesOutputIterator random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam Size integral type that represents the problem size. +/// \tparam Key The value type of the input and output iterators. +/// \tparam Decomposer The type of the decomposer functor. +/// +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input pointer to the first element in the range to sort. +/// \param [out] keys_output pointer to the first element in the output range. +/// \param [in] values_input pointer to the first element in the range to sort. +/// \param [out] values_output pointer to the first element in the output range. +/// \param [in] size number of element in the input range. +/// \param [in] decomposer decomposer functor that produces a tuple of references from the +/// input key type. +/// \param [in] begin_bit index of the first (least significant) bit used in +/// key comparison. +/// \param [in] end_bit past-the-end index (most significant) bit used in +/// key comparison. +/// \param [in] stream [optional] HIP stream object. Default is \p 0 (default stream). 
+/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p hipSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p hipError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level ascending radix sort is performed where input keys are +/// represented by an array of a custom type and input values by an array of doubles. +/// +/// \code{.cpp} +/// #include +/// +/// struct custom_type +/// { +/// int i; +/// double d; +/// }; +/// +/// struct custom_type_decomposer +/// { +/// rocprim::tuple operator()(custom_type& key) const +/// { +/// return rocprim::tuple(key.i, key.d); +/// } +/// }; +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// custom_type * keys_input; // e.g., [{2, 0.6}, {0, 0.3}, {2, 0.65}, {0, 0.4}, {0, 0.2}, {11, 0.08}, {11, 1.0}, {5, 0.7}] +/// double * values_input; // e.g., [-5, 2, -4, 3, -1, -8, -2, 7] +/// custom_type * keys_output; // empty array of 8 elements +/// double * values_output; // empty array of 8 elements +/// +/// // The integer field of the keys is in range 0-11, which can be represented on 4 bits, +/// // while for the double member we must specify full bit range [0; 63]. Therefore begin_bit +/// // is set to 0 and end_bit is set to 68. 
+/// constexpr unsigned int begin_bit = 0; +/// constexpr unsigned int end_bit = 68; +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::radix_sort_pairs( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, values_input, values_output, +/// input_size, custom_type_decomposer{}, begin_bit, end_bit +/// ); +/// +/// // allocate temporary storage +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::radix_sort_pairs( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, values_input, values_output, +/// input_size, custom_type_decomposer{}, begin_bit, end_bit +/// ); +/// // keys_output: [{0, 0.2}, {0, 0.3}, {0, 0.4}, {2, 0.6}, {2, 0.65}, {5, 0.7}, {11, 0.08}, {11, 1.0}] +/// // values_output: [-1, 2, 3, -5, -4, 7, -8, -2] +/// \endcode +/// \endparblock +template::value_type, + class Decomposer> +auto radix_sort_pairs(void* temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + Size size, + Decomposer decomposer, + unsigned int begin_bit, + unsigned int end_bit, + hipStream_t stream = 0, + bool debug_synchronous = false) + -> std::enable_if_t::value, hipError_t> +{ + static_assert(std::is_integral::value, "Size must be an integral type."); + bool ignored; + return detail::radix_sort_impl(temporary_storage, + storage_size, + keys_input, + nullptr, + keys_output, + values_input, + nullptr, + values_output, + size, + ignored, + decomposer, + begin_bit, + end_bit, + stream, + debug_synchronous); +} + +/// \brief Parallel ascending radix sort-by-key primitive for device level. +/// +/// \p radix_sort_pairs function performs a device-wide radix sort +/// of (key, value) pairs. 
Function sorts input pairs in ascending order of keys. +/// +/// \par Overview +/// * The contents of the inputs are not altered by the sorting function. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage is a null pointer. +/// * \p Key type (a \p value_type of \p KeysInputIterator and \p KeysOutputIterator) can be any +/// trivially copyable type. +/// * \p decomposer must be a functor that implements `operator()(Key&) const`. This operator +/// must return a \p rocprim::tuple that contains one or more reference to value(s) of arithmetic types. +/// These references must point to member variables of `Key`, however not every member variable has to be +/// exposed this way. +/// * Ranges specified by \p keys_input and \p keys_output must have at least \p size elements. +/// +/// \tparam Config [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. +/// \tparam KeysInputIterator random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam KeysOutputIterator random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam ValuesInputIterator random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam ValuesOutputIterator random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam Size integral type that represents the problem size. +/// \tparam Key The value type of the input and output iterators. +/// \tparam Decomposer The type of the decomposer functor. +/// +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. 
When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input pointer to the first element in the range to sort. +/// \param [out] keys_output pointer to the first element in the output range. +/// \param [in] values_input pointer to the first element in the range to sort. +/// \param [out] values_output pointer to the first element in the output range. +/// \param [in] size number of element in the input range. +/// \param [in] decomposer decomposer functor that produces a tuple of references from the +/// input key type. +/// \param [in] stream [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p hipSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p hipError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level ascending radix sort is performed where input keys are +/// represented by an array of a custom type and input values by an array of doubles. +/// +/// \code{.cpp} +/// #include +/// +/// struct custom_type +/// { +/// int i; +/// double d; +/// }; +/// +/// struct custom_type_decomposer +/// { +/// rocprim::tuple operator()(custom_type& key) const +/// { +/// return rocprim::tuple(key.i, key.d); +/// } +/// }; +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) 
+/// size_t input_size; // e.g., 8 +/// custom_type * keys_input; // e.g., [{2, 0.6}, {0, 0.3}, {2, 0.65}, {0, 0.4}, {0, 0.2}, {11, 0.08}, {11, 1.0}, {5, 0.7}] +/// double * values_input; // e.g., [-5, 2, -4, 3, -1, -8, -2, 7] +/// custom_type * keys_output; // empty array of 8 elements +/// double * values_output; // empty array of 8 elements +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::radix_sort_pairs( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, values_input, values_output, +/// input_size, custom_type_decomposer{} +/// ); +/// +/// // allocate temporary storage +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::radix_sort_pairs( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, values_input, values_output, +/// input_size, custom_type_decomposer{} +/// ); +/// // keys_output: [{0, 0.2}, {0, 0.3}, {0, 0.4}, {2, 0.6}, {2, 0.65}, {5, 0.7}, {11, 0.08}, {11, 1.0}] +/// // values_output: [-1, 2, 3, -5, -4, 7, -8, -2] +/// \endcode +/// \endparblock +template::value_type, + class Decomposer> +auto radix_sort_pairs(void* temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + Size size, + Decomposer decomposer, + hipStream_t stream = 0, + bool debug_synchronous = false) + -> std::enable_if_t::value, hipError_t> +{ + static_assert(std::is_integral::value, "Size must be an integral type."); + bool ignored; + return detail::radix_sort_impl( + temporary_storage, + storage_size, + keys_input, + nullptr, + keys_output, + values_input, + nullptr, + values_output, + size, + ignored, + decomposer, + 0, + detail::decomposer_max_bits::value, + stream, + debug_synchronous); +} + +/// \brief Parallel 
ascending radix sort-by-key primitive for device level. +/// +/// \p radix_sort_pairs function performs a device-wide radix sort +/// of (key, value) pairs. Function sorts input pairs in ascending order of keys. +/// +/// \par Overview +/// * The contents of both buffers of \p keys and \p values may be altered by the sorting function. +/// * \p current() of \p keys and \p values are used as the input. +/// * The function will update \p current() of \p keys and \p values to point to buffers +/// that contain the output range. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage is a null pointer. +/// * The function requires small \p temporary_storage as it does not need +/// a temporary buffer of \p size elements. +/// * \p Key type can be any +/// trivially copyable type. +/// * \p decomposer must be a functor that implements `operator()(Key&) const`. This operator +/// must return a \p rocprim::tuple that contains one or more references to values of arithmetic types. +/// These references must point to member variables of `Key`, however not every member variable has to be +/// exposed this way. +/// * Buffers of \p keys must have at least \p size elements. +/// * \p begin_bit and \p end_bit can be provided to control the radix range that is considered in the +/// decomposed tuple. For example, if the decomposer returns `rocprim::tuple<int16_t&, uint8_t&>`, +/// `begin_bit==6` and `end_bit==12`, then the 2 MSBs of the `uint8_t` value and the 4 LSBs of the +/// `int16_t` value are considered for sorting. The range specified by \p begin_bit and \p end_bit +/// must be valid with regard to the sizes of the return tuple's elements. +/// +/// \tparam Config [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. +/// \tparam Key key type.
Can be any trivially copyable type. +/// \tparam Value value type. +/// \tparam Size integral type that represents the problem size. +/// \tparam Decomposer The type of the decomposer functor. +/// +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size reference to a size (in bytes) of \p temporary_storage. +/// \param [in,out] keys reference to the double-buffer of keys, its \p current() +/// contains the input range and will be updated to point to the output range. +/// \param [in,out] values reference to the double-buffer of values, its \p current() +/// contains the input range and will be updated to point to the output range. +/// \param [in] size number of elements in the input range. +/// \param [in] decomposer decomposer functor that produces a tuple of references from the +/// input key type. +/// \param [in] begin_bit index of the first (least significant) bit used in +/// key comparison. +/// \param [in] end_bit past-the-end index of the last (most significant) bit used in +/// key comparison. Must not exceed the size of the decomposed tuple's bit range. +/// \param [in] stream [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel /// launch is forced in order to check for errors. Default value is \p false. /// /// \returns \p hipSuccess (\p 0) after successful sort; otherwise a HIP runtime error of @@ -1449,74 +2882,918 @@ hipError_t radix_sort_keys_desc(void * temporary_storage, /// \par Example /// \parblock /// In this example a device-level ascending radix sort is performed where input keys are -/// represented by an array of unsigned integers and input values by an array of doubles.
+/// represented by an array of a custom type and input values by an array of doubles. +/// +/// \code{.cpp} +/// #include +/// +/// struct custom_type +/// { +/// int i; +/// double d; +/// }; +/// +/// struct custom_type_decomposer +/// { +/// rocprim::tuple operator()(custom_type& key) const +/// { +/// return rocprim::tuple(key.i, key.d); +/// } +/// }; +/// +/// // Prepare input and tmp (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// custom_type * keys_input; // e.g., [{2, 0.6}, {0, 0.3}, {2, 0.65}, {0, 0.4}, {0, 0.2}, {11, 0.08}, {11, 1.0}, {5, 0.7}] +/// double * values_input; // e.g., [-5, 2, -4, 3, -1, -8, -2, 7] +/// custom_type * keys_tmp; // empty array of 8 elements +/// double* values_tmp; // empty array of 8 elements +/// // Create double-buffers +/// rocprim::double_buffer keys(keys_input, keys_tmp); +/// rocprim::double_buffer values(values_input, values_tmp); +/// +/// // The integer field of the keys is in range 0-11, which can be represented on 4 bits, +/// // while for the double member we must specify full bit range [0; 63]. Therefore begin_bit +/// // is set to 0 and end_bit is set to 68. 
+/// constexpr unsigned int begin_bit = 0; +/// constexpr unsigned int end_bit = 68; +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::radix_sort_pairs( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, values, input_size, custom_type_decomposer{}, begin_bit, end_bit +/// ); +/// +/// // allocate temporary storage +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::radix_sort_pairs( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, values, input_size, custom_type_decomposer{}, begin_bit, end_bit +/// ); +/// // keys.current(): [{0, 0.2}, {0, 0.3}, {0, 0.4}, {2, 0.6}, {2, 0.65}, {5, 0.7}, {11, 0.08}, {11, 1.0}] +/// // values.current(): [-1, 2, 3, -5, -4, 7, -8, -2] +/// \endcode +/// \endparblock +template +auto radix_sort_pairs(void* temporary_storage, + size_t& storage_size, + double_buffer& keys, + double_buffer& values, + Size size, + Decomposer decomposer, + unsigned int begin_bit, + unsigned int end_bit, + hipStream_t stream = 0, + bool debug_synchronous = false) + -> std::enable_if_t::value, hipError_t> +{ + static_assert(std::is_integral::value, "Size must be an integral type."); + bool is_result_in_output; + hipError_t error = detail::radix_sort_impl(temporary_storage, + storage_size, + keys.current(), + keys.current(), + keys.alternate(), + values.current(), + values.current(), + values.alternate(), + size, + is_result_in_output, + decomposer, + begin_bit, + end_bit, + stream, + debug_synchronous); + if(temporary_storage != nullptr && error == hipSuccess && is_result_in_output) + { + keys.swap(); + values.swap(); + } + return error; +} + +/// \brief Parallel ascending radix sort-by-key primitive for device level. +/// +/// \p radix_sort_pairs function performs a device-wide radix sort +/// of (key, value) pairs. 
Function sorts input pairs in ascending order of keys. +/// +/// \par Overview +/// * The contents of both buffers of \p keys and \p values may be altered by the sorting function. +/// * \p current() of \p keys and \p values are used as the input. +/// * The function will update \p current() of \p keys and \p values to point to buffers +/// that contain the output range. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage is a null pointer. +/// * The function requires small \p temporary_storage as it does not need +/// a temporary buffer of \p size elements. +/// * \p Key type can be any +/// trivially copyable type. +/// * \p decomposer must be a functor that implements `operator()(Key&) const`. This operator +/// must return a \p rocprim::tuple that contains one or more references to values of arithmetic types. +/// These references must point to member variables of `Key`, however not every member variable has to be +/// exposed this way. +/// * Buffers of \p keys must have at least \p size elements. +/// +/// \tparam Config [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. +/// \tparam Key key type. Can be any trivially copyable type. +/// \tparam Value value type. +/// \tparam Size integral type that represents the problem size. +/// \tparam Decomposer The type of the decomposer functor. +/// +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size reference to a size (in bytes) of \p temporary_storage.
+/// \param [in,out] keys reference to the double-buffer of keys, its \p current() +/// contains the input range and will be updated to point to the output range. +/// \param [in,out] values reference to the double-buffer of values, its \p current() +/// contains the input range and will be updated to point to the output range. +/// \param [in] size number of elements in the input range. +/// \param [in] decomposer decomposer functor that produces a tuple of references from the +/// input key type. +/// \param [in] stream [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p hipSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p hipError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level ascending radix sort is performed where input keys are +/// represented by an array of a custom type and input values by an array of doubles. +/// +/// \code{.cpp} +/// #include +/// +/// struct custom_type +/// { +/// int i; +/// double d; +/// }; +/// +/// struct custom_type_decomposer +/// { +/// rocprim::tuple<int&, double&> operator()(custom_type& key) const +/// { +/// return rocprim::tuple<int&, double&>(key.i, key.d); +/// } +/// }; +/// +/// // Prepare input and tmp (declare pointers, allocate device memory etc.)
+/// size_t input_size; // e.g., 8 +/// custom_type * keys_input; // e.g., [{2, 0.6}, {0, 0.3}, {2, 0.65}, {0, 0.4}, {0, 0.2}, {11, 0.08}, {11, 1.0}, {5, 0.7}] +/// double * values_input; // e.g., [-5, 2, -4, 3, -1, -8, -2, 7] +/// custom_type * keys_tmp; // empty array of 8 elements +/// double* values_tmp; // empty array of 8 elements +/// // Create double-buffers +/// rocprim::double_buffer keys(keys_input, keys_tmp); +/// rocprim::double_buffer values(values_input, values_tmp); +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::radix_sort_pairs( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, values, input_size, custom_type_decomposer{} +/// ); +/// +/// // allocate temporary storage +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::radix_sort_pairs( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, values, input_size, custom_type_decomposer{} +/// ); +/// // keys.current(): [{0, 0.2}, {0, 0.3}, {0, 0.4}, {2, 0.6}, {2, 0.65}, {5, 0.7}, {11, 0.08}, {11, 1.0}] +/// // values.current(): [-1, 2, 3, -5, -4, 7, -8, -2] +/// \endcode +/// \endparblock +template +auto radix_sort_pairs(void* temporary_storage, + size_t& storage_size, + double_buffer& keys, + double_buffer& values, + Size size, + Decomposer decomposer, + hipStream_t stream = 0, + bool debug_synchronous = false) + -> std::enable_if_t::value, hipError_t> +{ + static_assert(std::is_integral::value, "Size must be an integral type."); + bool is_result_in_output; + hipError_t error = detail::radix_sort_impl( + temporary_storage, + storage_size, + keys.current(), + keys.current(), + keys.alternate(), + values.current(), + values.current(), + values.alternate(), + size, + is_result_in_output, + decomposer, + 0, + detail::decomposer_max_bits::value, + stream, + debug_synchronous); + if(temporary_storage != 
nullptr && error == hipSuccess && is_result_in_output) + { + keys.swap(); + values.swap(); + } + return error; +} + +/// \brief Parallel descending radix sort-by-key primitive for device level. +/// +/// \p radix_sort_pairs_desc function performs a device-wide radix sort +/// of (key, value) pairs. Function sorts input pairs in descending order of keys. +/// +/// \par Overview +/// * The contents of the inputs are not altered by the sorting function. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage is a null pointer. +/// * \p Key type (a \p value_type of \p KeysInputIterator and \p KeysOutputIterator) must be +/// an arithmetic type (that is, an integral type or a floating-point type). +/// * Ranges specified by \p keys_input, \p keys_output, \p values_input and \p values_output must +/// have at least \p size elements. +/// * If \p Key is an integer type and the range of keys is known in advance, the performance +/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range +/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. +/// +/// \tparam Config [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. +/// \tparam KeysInputIterator random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam KeysOutputIterator random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam ValuesInputIterator random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam ValuesOutputIterator random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. 
It can be a simple pointer type. +/// \tparam Size integral type that represents the problem size. +/// +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input pointer to the first element in the range to sort. +/// \param [out] keys_output pointer to the first element in the output range. +/// \param [in] values_input pointer to the first element in the range to sort. +/// \param [out] values_output pointer to the first element in the output range. +/// \param [in] size number of element in the input range. +/// \param [in] begin_bit [optional] index of the first (least significant) bit used in +/// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. +/// Non-default value not supported for floating-point key-types. +/// \param [in] end_bit [optional] past-the-end index (most significant) bit used in +/// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default +/// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. +/// \param [in] stream [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p hipSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p hipError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level descending radix sort is performed where input keys are +/// represented by an array of integers and input values by an array of doubles. 
+/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// int * keys_input; // e.g., [ 6, 3, 5, 4, 1, 8, 1, 7] +/// double * values_input; // e.g., [-5, 2, -4, 3, -1, -8, -2, 7] +/// int * keys_output; // empty array of 8 elements +/// double * values_output; // empty array of 8 elements +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::radix_sort_pairs_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, values_input, values_output, +/// input_size +/// ); +/// +/// // allocate temporary storage +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::radix_sort_pairs_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, values_input, values_output, +/// input_size +/// ); +/// // keys_output: [ 8, 7, 6, 5, 4, 3, 1, 1] +/// // values_output: [-8, 7, -5, -4, 3, 2, -1, -2] +/// \endcode +/// \endparblock +template::value_type> +hipError_t radix_sort_pairs_desc(void* temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + Size size, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + hipStream_t stream = 0, + bool debug_synchronous = false) +{ + static_assert(std::is_integral::value, "Size must be an integral type."); + bool ignored; + return detail::radix_sort_impl(temporary_storage, + storage_size, + keys_input, + nullptr, + keys_output, + values_input, + nullptr, + values_output, + size, + ignored, + identity_decomposer{}, + begin_bit, + end_bit, + stream, + debug_synchronous); +} + +/// \brief Parallel descending radix sort-by-key primitive for device level. 
+/// +/// \p radix_sort_pairs_desc function performs a device-wide radix sort +/// of (key, value) pairs. Function sorts input pairs in descending order of keys. +/// +/// \par Overview +/// * The contents of both buffers of \p keys and \p values may be altered by the sorting function. +/// * \p current() of \p keys and \p values are used as the input. +/// * The function will update \p current() of \p keys and \p values to point to buffers +/// that contains the output range. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage is a null pointer. +/// * The function requires small \p temporary_storage as it does not need +/// a temporary buffer of \p size elements. +/// * \p Key type must be an arithmetic type (that is, an integral type or a floating-point +/// type). +/// * Buffers of \p keys must have at least \p size elements. +/// * If \p Key is an integer type and the range of keys is known in advance, the performance +/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range +/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. +/// +/// \tparam Config [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. +/// \tparam Key key type. Must be an integral type or a floating-point type. +/// \tparam Value value type. +/// \tparam Size integral type that represents the problem size. +/// +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size reference to a size (in bytes) of \p temporary_storage. +/// \param [in,out] keys reference to the double-buffer of keys, its \p current() +/// contains the input range and will be updated to point to the output range. 
+/// \param [in,out] values reference to the double-buffer of values, its \p current() +/// contains the input range and will be updated to point to the output range. +/// \param [in] size number of element in the input range. +/// \param [in] begin_bit [optional] index of the first (least significant) bit used in +/// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. +/// Non-default value not supported for floating-point key-types. +/// \param [in] end_bit [optional] past-the-end index (most significant) bit used in +/// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default +/// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. +/// \param [in] stream [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p hipSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p hipError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level descending radix sort is performed where input keys are +/// represented by an array of integers and input values by an array of doubles. /// /// \code{.cpp} /// #include /// /// // Prepare input and tmp (declare pointers, allocate device memory etc.) 
+/// size_t input_size; // e.g., 8 +/// int * keys_input; // e.g., [ 6, 3, 5, 4, 1, 8, 1, 7] +/// double * values_input; // e.g., [-5, 2, -4, 3, -1, -8, -2, 7] +/// int * keys_tmp; // empty array of 8 elements +/// double * values_tmp; // empty array of 8 elements +/// // Create double-buffers +/// rocprim::double_buffer keys(keys_input, keys_tmp); +/// rocprim::double_buffer values(values_input, values_tmp); +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::radix_sort_pairs_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, values, input_size +/// ); +/// +/// // allocate temporary storage +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::radix_sort_pairs_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, values, input_size +/// ); +/// // keys.current(): [ 8, 7, 6, 5, 4, 3, 1, 1] +/// // values.current(): [-8, 7, -5, -4, 3, 2, -1, -2] +/// \endcode +/// \endparblock +template +hipError_t radix_sort_pairs_desc(void* temporary_storage, + size_t& storage_size, + double_buffer& keys, + double_buffer& values, + Size size, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + hipStream_t stream = 0, + bool debug_synchronous = false) +{ + static_assert(std::is_integral::value, "Size must be an integral type."); + bool is_result_in_output; + hipError_t error = detail::radix_sort_impl(temporary_storage, + storage_size, + keys.current(), + keys.current(), + keys.alternate(), + values.current(), + values.current(), + values.alternate(), + size, + is_result_in_output, + identity_decomposer{}, + begin_bit, + end_bit, + stream, + debug_synchronous); + if(temporary_storage != nullptr && error == hipSuccess && is_result_in_output) + { + keys.swap(); + values.swap(); + } + return error; +} + +/// \brief Parallel descending radix sort-by-key 
primitive for device level. +/// +/// \p radix_sort_pairs_desc function performs a device-wide radix sort +/// of (key, value) pairs. Function sorts input pairs in descending order of keys. +/// +/// \par Overview +/// * The contents of the inputs are not altered by the sorting function. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage is a null pointer. +/// * \p Key type (a \p value_type of \p KeysInputIterator and \p KeysOutputIterator) can be any +/// trivially copyable type. +/// * \p decomposer must be a functor that implements `operator()(Key&) const`. This operator +/// must return a \p rocprim::tuple that contains one or more reference to value(s) of arithmetic types. +/// These references must point to member variables of `Key`, however not every member variable has to be +/// exposed this way. +/// * Ranges specified by \p keys_input and \p keys_output must have at least \p size elements. +/// * \p begin_bit and \p end_bit can be provided to control the radix range that is considered in the +/// decomposed tuple. For example, if the decomposer returns `rocprim::tuple`, +/// `begin_bit==6` and `end_bit==12`, then the 2 MSBs of the `uint8_t` value and the 4 LSBs of the +/// `int16_t` value are considered for sorting. The range specified by \p begin_bit and \p end_bit +/// must be valid with regards to the sizes of the return tuple's elements. +/// +/// \tparam Config [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. +/// \tparam KeysInputIterator random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam KeysOutputIterator random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam ValuesInputIterator random-access iterator type of the input range. 
Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam ValuesOutputIterator random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam Size integral type that represents the problem size. +/// \tparam Key The value type of the input and output iterators. +/// \tparam Decomposer The type of the decomposer functor. +/// +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input pointer to the first element in the range to sort. +/// \param [out] keys_output pointer to the first element in the output range. +/// \param [in] values_input pointer to the first element in the range to sort. +/// \param [out] values_output pointer to the first element in the output range. +/// \param [in] size number of element in the input range. +/// \param [in] decomposer decomposer functor that produces a tuple of references from the +/// input key type. +/// \param [in] begin_bit index of the first (least significant) bit used in +/// key comparison. +/// \param [in] end_bit past-the-end index (most significant) bit used in +/// key comparison. +/// \param [in] stream [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p hipSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p hipError_t. 
+/// +/// \par Example +/// \parblock +/// In this example a device-level descending radix sort is performed where input keys are +/// represented by an array of a custom type and input values by an array of doubles. +/// +/// \code{.cpp} +/// #include +/// +/// struct custom_type +/// { +/// int i; +/// double d; +/// }; +/// +/// struct custom_type_decomposer +/// { +/// rocprim::tuple operator()(custom_type& key) const +/// { +/// return rocprim::tuple(key.i, key.d); +/// } +/// }; +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) /// size_t input_size; // e.g., 8 -/// unsigned int * keys_input; // e.g., [ 6, 3, 5, 4, 1, 8, 1, 7] +/// custom_type * keys_input; // e.g., [{2, 0.6}, {0, 0.3}, {2, 0.65}, {0, 0.4}, {0, 0.2}, {11, 0.08}, {11, 1.0}, {5, 0.7}] /// double * values_input; // e.g., [-5, 2, -4, 3, -1, -8, -2, 7] -/// unsigned int * keys_tmp; // empty array of 8 elements -/// double* values_tmp; // empty array of 8 elements +/// custom_type * keys_output; // empty array of 8 elements +/// double * values_output; // empty array of 8 elements +/// +/// // The integer field of the keys is in range 0-11, which can be represented on 4 bits, +/// // while for the double member we must specify full bit range [0; 63]. Therefore begin_bit +/// // is set to 0 and end_bit is set to 68. 
+/// constexpr unsigned int begin_bit = 0; +/// constexpr unsigned int end_bit = 68; +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::radix_sort_pairs_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, values_input, values_output, +/// input_size, custom_type_decomposer{}, begin_bit, end_bit +/// ); +/// +/// // allocate temporary storage +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::radix_sort_pairs_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, values_input, values_output, +/// input_size, custom_type_decomposer{}, begin_bit, end_bit +/// ); +/// // keys_output: [{11, 1.0}, {11, 0.08}, {5, 0.7}, {2, 0.65}, {2, 0.6}, {0, 0.4}, {0, 0.3}, {0, 0.2}] +/// // values_output: [-2, -1, 2, 3, -4, -5, 7, -8] +/// \endcode +/// \endparblock +template::value_type, + class Decomposer> +auto radix_sort_pairs_desc(void* temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + Size size, + Decomposer decomposer, + unsigned int begin_bit, + unsigned int end_bit, + hipStream_t stream = 0, + bool debug_synchronous = false) + -> std::enable_if_t::value, hipError_t> +{ + static_assert(std::is_integral::value, "Size must be an integral type."); + bool ignored; + return detail::radix_sort_impl(temporary_storage, + storage_size, + keys_input, + nullptr, + keys_output, + values_input, + nullptr, + values_output, + size, + ignored, + decomposer, + begin_bit, + end_bit, + stream, + debug_synchronous); +} + +/// \brief Parallel descending radix sort-by-key primitive for device level. +/// +/// \p radix_sort_pairs_desc function performs a device-wide radix sort +/// of (key, value) pairs. 
Function sorts input pairs in descending order of keys. +/// +/// \par Overview +/// * The contents of the inputs are not altered by the sorting function. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage is a null pointer. +/// * \p Key type (a \p value_type of \p KeysInputIterator and \p KeysOutputIterator) can be any +/// trivially copyable type. +/// * \p decomposer must be a functor that implements `operator()(Key&) const`. This operator +/// must return a \p rocprim::tuple that contains one or more reference to value(s) of arithmetic types. +/// These references must point to member variables of `Key`, however not every member variable has to be +/// exposed this way. +/// * Ranges specified by \p keys_input and \p keys_output must have at least \p size elements. +/// +/// \tparam Config [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. +/// \tparam KeysInputIterator random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam KeysOutputIterator random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam ValuesInputIterator random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam ValuesOutputIterator random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam Size integral type that represents the problem size. +/// \tparam Key The value type of the input and output iterators. +/// \tparam Decomposer The type of the decomposer functor. +/// +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. 
When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input pointer to the first element in the range to sort. +/// \param [out] keys_output pointer to the first element in the output range. +/// \param [in] values_input pointer to the first element in the range to sort. +/// \param [out] values_output pointer to the first element in the output range. +/// \param [in] size number of element in the input range. +/// \param [in] decomposer decomposer functor that produces a tuple of references from the +/// input key type. +/// \param [in] stream [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p hipSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p hipError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level descending radix sort is performed where input keys are +/// represented by an array of a custom type and input values by an array of doubles. +/// +/// \code{.cpp} +/// #include +/// +/// struct custom_type +/// { +/// int i; +/// double d; +/// }; +/// +/// struct custom_type_decomposer +/// { +/// rocprim::tuple operator()(custom_type& key) const +/// { +/// return rocprim::tuple(key.i, key.d); +/// } +/// }; +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) 
+/// size_t input_size; // e.g., 8 +/// custom_type * keys_input; // e.g., [{2, 0.6}, {0, 0.3}, {2, 0.65}, {0, 0.4}, {0, 0.2}, {11, 0.08}, {11, 1.0}, {5, 0.7}] +/// double * values_input; // e.g., [-5, 2, -4, 3, -1, -8, -2, 7] +/// custom_type * keys_output; // empty array of 8 elements +/// double * values_output; // empty array of 8 elements +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::radix_sort_pairs_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, values_input, values_output, +/// input_size, custom_type_decomposer{} +/// ); +/// +/// // allocate temporary storage +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::radix_sort_pairs_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, values_input, values_output, +/// input_size, custom_type_decomposer{} +/// ); +/// // keys_output: [{11, 1.0}, {11, 0.08}, {5, 0.7}, {2, 0.65}, {2, 0.6}, {0, 0.4}, {0, 0.3}, {0, 0.2}] +/// // values_output: [-2, -1, 2, 3, -4, -5, 7, -8] +/// \endcode +/// \endparblock +template::value_type, + class Decomposer> +auto radix_sort_pairs_desc(void* temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + Size size, + Decomposer decomposer, + hipStream_t stream = 0, + bool debug_synchronous = false) + -> std::enable_if_t::value, hipError_t> +{ + static_assert(std::is_integral::value, "Size must be an integral type."); + bool ignored; + return detail::radix_sort_impl( + temporary_storage, + storage_size, + keys_input, + nullptr, + keys_output, + values_input, + nullptr, + values_output, + size, + ignored, + decomposer, + 0, + detail::decomposer_max_bits::value, + stream, + debug_synchronous); +} + +/// \brief 
Parallel descending radix sort-by-key primitive for device level. +/// +/// \p radix_sort_pairs_desc function performs a device-wide radix sort +/// of (key, value) pairs. Function sorts input pairs in descending order of keys. +/// +/// \par Overview +/// * The contents of both buffers of \p keys and \p values may be altered by the sorting function. +/// * \p current() of \p keys and \p values are used as the input. +/// * The function will update \p current() of \p keys and \p values to point to buffers +/// that contains the output range. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage is a null pointer. +/// * The function requires small \p temporary_storage as it does not need +/// a temporary buffer of \p size elements. +/// * \p Key type (a \p value_type of \p KeysInputIterator and \p KeysOutputIterator) can be any +/// trivially copyable type. +/// * \p decomposer must be a functor that implements `operator()(Key&) const`. This operator +/// must return a \p rocprim::tuple that contains one or more reference to value(s) of arithmetic types. +/// These references must point to member variables of `Key`, however not every member variable has to be +/// exposed this way. +/// * Ranges specified by \p keys_input and \p keys_output must have at least \p size elements. +/// * \p begin_bit and \p end_bit can be provided to control the radix range that is considered in the +/// decomposed tuple. For example, if the decomposer returns `rocprim::tuple`, +/// `begin_bit==6` and `end_bit==12`, then the 2 MSBs of the `uint8_t` value and the 4 LSBs of the +/// `int16_t` value are considered for sorting. The range specified by \p begin_bit and \p end_bit +/// must be valid with regards to the sizes of the return tuple's elements. +/// +/// \tparam Config [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. +/// \tparam Key key type. 
Must be an integral type or a floating-point type. +/// \tparam Value value type. +/// \tparam Size integral type that represents the problem size. +/// \tparam Decomposer The type of the decomposer functor. +/// +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size reference to a size (in bytes) of \p temporary_storage. +/// \param [in,out] keys reference to the double-buffer of keys, its \p current() +/// contains the input range and will be updated to point to the output range. +/// \param [in,out] values reference to the double-buffer of values, its \p current() +/// contains the input range and will be updated to point to the output range. +/// \param [in] size number of element in the input range. +/// \param [in] decomposer decomposer functor that produces a tuple of references from the +/// input key type. +/// \param [in] begin_bit [optional] index of the first (least significant) bit used in +/// key comparison. Defaults to `0`. +/// \param [in] end_bit [optional] past-the-end index (most significant) bit used in +/// key comparison. Defaults to the size of the decomposed tuple's bit range. +/// \param [in] stream [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p hipSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p hipError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level descending radix sort is performed where input keys are +/// represented by an array of a custom type and input values by an array of doubles. 
+/// +/// \code{.cpp} +/// #include +/// +/// struct custom_type +/// { +/// int i; +/// double d; +/// }; +/// +/// struct custom_type_decomposer +/// { +/// rocprim::tuple operator()(custom_type& key) const +/// { +/// return rocprim::tuple(key.i, key.d); +/// } +/// }; +/// +/// // Prepare input and tmp (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// custom_type * keys_input; // e.g., [{2, 0.6}, {0, 0.3}, {2, 0.65}, {0, 0.4}, {0, 0.2}, {11, 0.08}, {11, 1.0}, {5, 0.7}] +/// double * values_input; // e.g., [-5, 2, -4, 3, -1, -8, -2, 7] +/// custom_type * keys_tmp; // empty array of 8 elements +/// double * values_tmp; // empty array of 8 elements /// // Create double-buffers -/// rocprim::double_buffer keys(keys_input, keys_tmp); +/// rocprim::double_buffer keys(keys_input, keys_tmp); /// rocprim::double_buffer values(values_input, values_tmp); /// -/// // Keys are in range [0; 8], so we can limit compared bit to bits on indexes -/// // 0, 1, 2, 3, and 4. In order to do this begin_bit is set to 0 and end_bit -/// // is set to 5. +/// // The integer field of the keys is in range 0-11, which can be represented on 4 bits, +/// // while for the double member we must specify full bit range [0; 63]. Therefore begin_bit +/// // is set to 0 and end_bit is set to 68. 
+/// constexpr unsigned int begin_bit = 0; +/// constexpr unsigned int end_bit = 68; /// /// size_t temporary_storage_size_bytes; /// void * temporary_storage_ptr = nullptr; /// // Get required size of the temporary storage -/// rocprim::radix_sort_pairs( +/// rocprim::radix_sort_pairs_desc( /// temporary_storage_ptr, temporary_storage_size_bytes, -/// keys, values, input_size, -/// 0, 5 +/// keys, values, input_size, custom_type_decomposer{}, begin_bit, end_bit /// ); /// /// // allocate temporary storage /// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); /// /// // perform sort -/// rocprim::radix_sort_pairs( +/// rocprim::radix_sort_pairs_desc( /// temporary_storage_ptr, temporary_storage_size_bytes, -/// keys, values, input_size, -/// 0, 5 +/// keys, values, input_size, custom_type_decomposer{}, begin_bit, end_bit /// ); -/// // keys.current(): [ 1, 1, 3, 4, 5, 6, 7, 8] -/// // values.current(): [-1, -2, 2, 3, -4, -5, 7, -8] +/// // keys.current(): [{11, 1.0}, {11, 0.08}, {5, 0.7}, {2, 0.65}, {2, 0.6}, {0, 0.4}, {0, 0.3}, {0, 0.2}] +/// // values.current(): [-2, -1, 2, 3, -4, -5, 7, -8] /// \endcode /// \endparblock -template< - class Config = default_config, - class Key, - class Value, - class Size -> -inline -hipError_t radix_sort_pairs(void * temporary_storage, - size_t& storage_size, - double_buffer& keys, - double_buffer& values, - Size size, - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key), - hipStream_t stream = 0, - bool debug_synchronous = false) +template +auto radix_sort_pairs_desc(void* temporary_storage, + size_t& storage_size, + double_buffer& keys, + double_buffer& values, + Size size, + Decomposer decomposer, + unsigned int begin_bit, + unsigned int end_bit, + hipStream_t stream = 0, + bool debug_synchronous = false) + -> std::enable_if_t::value, hipError_t> { static_assert(std::is_integral::value, "Size must be an integral type."); bool is_result_in_output; - hipError_t error = detail::radix_sort_impl( - 
temporary_storage, storage_size, - keys.current(), keys.current(), keys.alternate(), - values.current(), values.current(), values.alternate(), - size, is_result_in_output, - begin_bit, end_bit, - stream, debug_synchronous - ); + hipError_t error = detail::radix_sort_impl(temporary_storage, + storage_size, + keys.current(), + keys.current(), + keys.alternate(), + values.current(), + values.current(), + values.alternate(), + size, + is_result_in_output, + decomposer, + begin_bit, + end_bit, + stream, + debug_synchronous); if(temporary_storage != nullptr && error == hipSuccess && is_result_in_output) { keys.swap(); @@ -1536,38 +3813,36 @@ hipError_t radix_sort_pairs(void * temporary_storage, /// * The function will update \p current() of \p keys and \p values to point to buffers /// that contains the output range. /// * Returns the required size of \p temporary_storage in \p storage_size -/// if \p temporary_storage in a null pointer. +/// if \p temporary_storage is a null pointer. /// * The function requires small \p temporary_storage as it does not need /// a temporary buffer of \p size elements. -/// * \p Key type must be an arithmetic type (that is, an integral type or a floating-point -/// type). -/// * Buffers of \p keys must have at least \p size elements. -/// * If \p Key is an integer type and the range of keys is known in advance, the performance -/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range -/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. +/// * \p Key type (a \p value_type of \p KeysInputIterator and \p KeysOutputIterator) can be any +/// trivially copyable type. +/// * \p decomposer must be a functor that implements `operator()(Key&) const`. This operator +/// must return a \p rocprim::tuple that contains one or more reference to value(s) of arithmetic types. 
+/// These references must point to member variables of `Key`, however not every member variable has to be +/// exposed this way. +/// * Ranges specified by \p keys_input and \p keys_output must have at least \p size elements. /// -/// \tparam Config - [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. -/// \tparam Key - key type. Must be an integral type or a floating-point type. -/// \tparam Value - value type. -/// \tparam Size - integral type that represents the problem size. +/// \tparam Config [optional] configuration of the primitive. It has to be \p radix_sort_config or a class derived from it. +/// \tparam Key key type. Must be an integral type or a floating-point type. +/// \tparam Value value type. +/// \tparam Size integral type that represents the problem size. +/// \tparam Decomposer The type of the decomposer functor. /// -/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to /// \p storage_size and function returns without performing the sort operation. -/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. -/// \param [in,out] keys - reference to the double-buffer of keys, its \p current() +/// \param [in,out] storage_size reference to a size (in bytes) of \p temporary_storage. +/// \param [in,out] keys reference to the double-buffer of keys, its \p current() /// contains the input range and will be updated to point to the output range. -/// \param [in,out] values - reference to the double-buffer of values, its \p current() +/// \param [in,out] values reference to the double-buffer of values, its \p current() /// contains the input range and will be updated to point to the output range. -/// \param [in] size - number of element in the input range. 
-/// \param [in] begin_bit - [optional] index of the first (least significant) bit used in -/// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. -/// Non-default value not supported for floating-point key-types. -/// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in -/// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default -/// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. -/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). -/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// \param [in] size number of element in the input range. +/// \param [in] decomposer decomposer functor that produces a tuple of references from the +/// input key type. +/// \param [in] stream [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel /// launch is forced in order to check for errors. Default value is \p false. /// /// \returns \p hipSuccess (\p 0) after successful sort; otherwise a HIP runtime error of @@ -1576,19 +3851,33 @@ hipError_t radix_sort_pairs(void * temporary_storage, /// \par Example /// \parblock /// In this example a device-level descending radix sort is performed where input keys are -/// represented by an array of integers and input values by an array of doubles. +/// represented by an array of a custom type and input values by an array of doubles. /// /// \code{.cpp} /// #include /// +/// struct custom_type +/// { +/// int i; +/// double d; +/// }; +/// +/// struct custom_type_decomposer +/// { +/// rocprim::tuple operator()(custom_type& key) const +/// { +/// return rocprim::tuple(key.i, key.d); +/// } +/// }; +/// /// // Prepare input and tmp (declare pointers, allocate device memory etc.) 
-/// size_t input_size; // e.g., 8 -/// int * keys_input; // e.g., [ 6, 3, 5, 4, 1, 8, 1, 7] -/// double * values_input; // e.g., [-5, 2, -4, 3, -1, -8, -2, 7] -/// int * keys_tmp; // empty array of 8 elements -/// double * values_tmp; // empty array of 8 elements +/// size_t input_size; // e.g., 8 +/// custom_type * keys_input; // e.g., [{2, 0.6}, {0, 0.3}, {2, 0.65}, {0, 0.4}, {0, 0.2}, {11, 0.08}, {11, 1.0}, {5, 0.7}] +/// double * values_input; // e.g., [-5, 2, -4, 3, -1, -8, -2, 7] +/// custom_type * keys_tmp; // empty array of 8 elements +/// double * values_tmp; // empty array of 8 elements /// // Create double-buffers -/// rocprim::double_buffer keys(keys_input, keys_tmp); +/// rocprim::double_buffer keys(keys_input, keys_tmp); /// rocprim::double_buffer values(values_input, values_tmp); /// /// size_t temporary_storage_size_bytes; @@ -1596,7 +3885,7 @@ hipError_t radix_sort_pairs(void * temporary_storage, /// // Get required size of the temporary storage /// rocprim::radix_sort_pairs_desc( /// temporary_storage_ptr, temporary_storage_size_bytes, -/// keys, values, input_size +/// keys, values, input_size, custom_type_decomposer{} /// ); /// /// // allocate temporary storage @@ -1605,39 +3894,41 @@ hipError_t radix_sort_pairs(void * temporary_storage, /// // perform sort /// rocprim::radix_sort_pairs_desc( /// temporary_storage_ptr, temporary_storage_size_bytes, -/// keys, values, input_size +/// keys, values, input_size, custom_type_decomposer{} /// ); -/// // keys.current(): [ 8, 7, 6, 5, 4, 3, 1, 1] -/// // values.current(): [-8, 7, -5, -4, 3, 2, -1, -2] +/// // keys.current(): [{11, 1.0}, {11, 0.08}, {5, 0.7}, {2, 0.65}, {2, 0.6}, {0, 0.4}, {0, 0.3}, {0, 0.2}] +/// // values.current(): [-2, -1, 2, 3, -4, -5, 7, -8] /// \endcode /// \endparblock -template< - class Config = default_config, - class Key, - class Value, - class Size -> -inline -hipError_t radix_sort_pairs_desc(void * temporary_storage, - size_t& storage_size, - double_buffer& keys, - 
double_buffer& values, - Size size, - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key), - hipStream_t stream = 0, - bool debug_synchronous = false) +template +auto radix_sort_pairs_desc(void* temporary_storage, + size_t& storage_size, + double_buffer& keys, + double_buffer& values, + Size size, + Decomposer decomposer, + hipStream_t stream = 0, + bool debug_synchronous = false) + -> std::enable_if_t::value, hipError_t> { static_assert(std::is_integral::value, "Size must be an integral type."); bool is_result_in_output; - hipError_t error = detail::radix_sort_impl( - temporary_storage, storage_size, - keys.current(), keys.current(), keys.alternate(), - values.current(), values.current(), values.alternate(), - size, is_result_in_output, - begin_bit, end_bit, - stream, debug_synchronous - ); + hipError_t error + = detail::radix_sort_impl(temporary_storage, + storage_size, + keys.current(), + keys.current(), + keys.alternate(), + values.current(), + values.current(), + values.alternate(), + size, + is_result_in_output, + decomposer, + 0, + detail::decomposer_max_bits::value, + stream, + debug_synchronous); if(temporary_storage != nullptr && error == hipSuccess && is_result_in_output) { keys.swap(); diff --git a/rocprim/include/rocprim/device/device_reduce_by_key.hpp b/rocprim/include/rocprim/device/device_reduce_by_key.hpp index db63d50d5..25fb166c1 100644 --- a/rocprim/include/rocprim/device/device_reduce_by_key.hpp +++ b/rocprim/include/rocprim/device/device_reduce_by_key.hpp @@ -387,8 +387,11 @@ hipError_t reduce_by_key_impl(void* temporary_storage, /// /// \par Overview /// * Supports non-commutative reduction operators. However, a reduction operator should be -/// associative. When used with non-associative functions the results may be non-deterministic -/// and/or vary in precision. +/// associative. +/// * When used with non-associative functions (e.g. 
floating point arithmetic operations): +/// - the results may be non-deterministic and/or vary in precision, +/// - and bit-wise reproducibility is not guaranteed, that is, results from multiple runs +/// using the same input values on the same device may not be bit-wise identical. /// * Returns the required size of \p temporary_storage in \p storage_size /// if \p temporary_storage in a null pointer. /// * Ranges specified by \p keys_input and \p values_input must have at least \p size elements. diff --git a/rocprim/include/rocprim/device/device_scan.hpp b/rocprim/include/rocprim/device/device_scan.hpp index 305bd3416..516fed06f 100644 --- a/rocprim/include/rocprim/device/device_scan.hpp +++ b/rocprim/include/rocprim/device/device_scan.hpp @@ -421,8 +421,11 @@ inline auto scan_impl(void* temporary_storage, /// /// \par Overview /// * Supports non-commutative scan operators. However, a scan operator should be -/// associative. When used with non-associative functions the results may be non-deterministic -/// and/or vary in precision. +/// associative. +/// * When used with non-associative functions (e.g. floating point arithmetic operations): +/// - the results may be non-deterministic and/or vary in precision, +/// - and bit-wise reproducibility is not guaranteed, that is, results from multiple runs +/// using the same input values on the same device may not be bit-wise identical. /// * Returns the required size of \p temporary_storage in \p storage_size /// if \p temporary_storage in a null pointer. /// * Ranges specified by \p input and \p output must have at least \p size elements. @@ -559,8 +562,11 @@ inline hipError_t inclusive_scan(void* temporary_storage, /// /// \par Overview /// * Supports non-commutative scan operators. However, a scan operator should be -/// associative. When used with non-associative functions the results may be non-deterministic -/// and/or vary in precision. +/// associative. +/// * When used with non-associative functions (e.g. 
floating point arithmetic operations): +/// - the results may be non-deterministic and/or vary in precision, +/// - and bit-wise reproducibility is not guaranteed, that is, results from multiple runs +/// using the same input values on the same device may not be bit-wise identical. /// * Returns the required size of \p temporary_storage in \p storage_size /// if \p temporary_storage in a null pointer. /// * Ranges specified by \p input and \p output must have at least \p size elements. diff --git a/rocprim/include/rocprim/device/device_scan_by_key.hpp b/rocprim/include/rocprim/device/device_scan_by_key.hpp index 73d5fcb0b..0b06bec95 100644 --- a/rocprim/include/rocprim/device/device_scan_by_key.hpp +++ b/rocprim/include/rocprim/device/device_scan_by_key.hpp @@ -314,8 +314,11 @@ inline hipError_t scan_by_key_impl(void* const temporary_storage, /// /// \par Overview /// * Supports non-commutative scan operators. However, a scan operator should be -/// associative. When used with non-associative functions the results may be non-deterministic -/// and/or vary in precision. +/// associative. +/// * When used with non-associative functions (e.g. floating point arithmetic operations): +/// - the results may be non-deterministic and/or vary in precision, +/// - and bit-wise reproducibility is not guaranteed, that is, results from multiple runs +/// using the same input values on the same device may not be bit-wise identical. /// * Returns the required size of \p temporary_storage in \p storage_size /// if \p temporary_storage in a null pointer. /// * Ranges specified by \p keys_input, \p values_input, and \p values_output must have @@ -448,8 +451,11 @@ inline hipError_t inclusive_scan_by_key(void* const temporary_sto /// /// \par Overview /// * Supports non-commutative scan operators. However, a scan operator should be -/// associative. When used with non-associative functions the results may be non-deterministic -/// and/or vary in precision. +/// associative. 
+/// * When used with non-associative functions (e.g. floating point arithmetic operations): +/// - the results may be non-deterministic and/or vary in precision, +/// - and bit-wise reproducibility is not guaranteed, that is, results from multiple runs +/// using the same input values on the same device may not be bit-wise identical. /// * Returns the required size of \p temporary_storage in \p storage_size /// if \p temporary_storage in a null pointer. /// * Ranges specified by \p keys_input, \p values_input, and \p values_output must have diff --git a/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp b/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp index 58089b810..681e31059 100644 --- a/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp +++ b/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -28,7 +28,6 @@ #include #include "../config.hpp" -#include "../detail/radix_sort.hpp" #include "../detail/various.hpp" #include "config_types.hpp" @@ -39,6 +38,7 @@ #include "../block/block_load.hpp" #include "../iterator/counting_iterator.hpp" #include "../iterator/reverse_iterator.hpp" +#include "../thread/radix_key_codec.hpp" #include "detail/device_segmented_radix_sort.hpp" #include "device_partition.hpp" #include "device_segmented_radix_sort_config.hpp" diff --git a/rocprim/include/rocprim/device/specialization/device_radix_block_sort.hpp b/rocprim/include/rocprim/device/specialization/device_radix_block_sort.hpp index a63fb9dce..b357bba4a 100644 --- a/rocprim/include/rocprim/device/specialization/device_radix_block_sort.hpp +++ b/rocprim/include/rocprim/device/specialization/device_radix_block_sort.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2021-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -51,7 +51,8 @@ template + class ValuesOutputIterator, + class Decomposer> ROCPRIM_KERNEL __launch_bounds__(device_params().block_size) void radix_sort_block_sort_kernel( KeysInputIterator keys_input, @@ -59,6 +60,7 @@ ROCPRIM_KERNEL ValuesInputIterator values_input, ValuesOutputIterator values_output, unsigned int size, + Decomposer decomposer, unsigned int bit, unsigned int current_radix_bits) { @@ -68,6 +70,7 @@ ROCPRIM_KERNEL values_input, values_output, size, + decomposer, bit, current_radix_bits); } @@ -77,13 +80,15 @@ template + class ValuesOutputIterator, + class Decomposer> inline hipError_t radix_sort_block_sort(KeysInputIterator keys_input, KeysOutputIterator keys_output, ValuesInputIterator values_input, ValuesOutputIterator values_output, unsigned int size, unsigned int& sort_items_per_block, + Decomposer decomposer, unsigned int bit, unsigned int end_bit, hipStream_t stream, @@ -120,20 +125,19 @@ inline hipError_t radix_sort_block_sort(KeysInputIterator keys_input, // Start point for time measurements std::chrono::high_resolution_clock::time_point start; if(debug_synchronous) + { start = std::chrono::high_resolution_clock::now(); + } - hipLaunchKernelGGL(HIP_KERNEL_NAME(radix_sort_block_sort_kernel), - dim3(sort_number_of_blocks), - dim3(params.block_size), - 0, - stream, - keys_input, - keys_output, - values_input, - values_output, - size, - bit, - current_radix_bits); + radix_sort_block_sort_kernel + <<>>(keys_input, + keys_output, + values_input, + values_output, + size, + decomposer, + bit, + current_radix_bits); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("radix_sort_block_sort_kernel", size, start) return hipSuccess; } diff --git a/rocprim/include/rocprim/device/specialization/device_radix_merge_sort.hpp b/rocprim/include/rocprim/device/specialization/device_radix_merge_sort.hpp 
index f55173bbc..252b82a77 100644 --- a/rocprim/include/rocprim/device/specialization/device_radix_merge_sort.hpp +++ b/rocprim/include/rocprim/device/specialization/device_radix_merge_sort.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2021-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -21,6 +21,7 @@ #ifndef ROCPRIM_DEVICE_SPECIALIZATION_DEVICE_RADIX_MERGE_SORT_HPP_ #define ROCPRIM_DEVICE_SPECIALIZATION_DEVICE_RADIX_MERGE_SORT_HPP_ +#include "../../type_traits.hpp" #include "../detail/device_radix_sort.hpp" #include "../device_merge_sort.hpp" #include "device_radix_block_sort.hpp" @@ -32,6 +33,133 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { +template +auto invoke_merge_sort_block_merge( + void* temporary_storage, + size_t& storage_size, + KeysIterator keys_output, + ValuesIterator values_output, + const OffsetT size, + unsigned int sort_items_per_block, + identity_decomposer decomposer, + unsigned int bit, + unsigned int current_radix_bits, + const hipStream_t stream, + bool debug_synchronous, + typename std::iterator_traits::value_type* keys_buffer, + typename std::iterator_traits::value_type* values_buffer) + -> std::enable_if_t< + !is_floating_point::value_type>::value, + hipError_t> +{ + using key_type = typename std::iterator_traits::value_type; + (void)decomposer; + if(current_radix_bits == sizeof(key_type) * 8) + { + return merge_sort_block_merge(temporary_storage, + storage_size, + keys_output, + values_output, + size, + sort_items_per_block, + radix_merge_compare(), + stream, + debug_synchronous, + keys_buffer, + values_buffer); + } + else + { + return merge_sort_block_merge( + temporary_storage, + storage_size, + keys_output, + values_output, + size, + sort_items_per_block, + radix_merge_compare(bit, 
current_radix_bits), + stream, + debug_synchronous, + keys_buffer, + values_buffer); + } +} + +template +auto invoke_merge_sort_block_merge( + void* temporary_storage, + size_t& storage_size, + KeysIterator keys_output, + ValuesIterator values_output, + const OffsetT size, + unsigned int sort_items_per_block, + identity_decomposer decomposer, + unsigned int bit, + unsigned int current_radix_bits, + const hipStream_t stream, + bool debug_synchronous, + typename std::iterator_traits::value_type* keys_buffer, + typename std::iterator_traits::value_type* values_buffer) + -> std::enable_if_t< + is_floating_point::value_type>::value, + hipError_t> +{ + using key_type = typename std::iterator_traits::value_type; + (void)decomposer; + (void)bit; + (void)current_radix_bits; + return merge_sort_block_merge(temporary_storage, + storage_size, + keys_output, + values_output, + size, + sort_items_per_block, + radix_merge_compare(), + stream, + debug_synchronous, + keys_buffer, + values_buffer); +} + +template +auto invoke_merge_sort_block_merge( + void* temporary_storage, + size_t& storage_size, + KeysIterator keys_output, + ValuesIterator values_output, + const OffsetT size, + unsigned int sort_items_per_block, + Decomposer decomposer, + unsigned int bit, + unsigned int current_radix_bits, + const hipStream_t stream, + bool debug_synchronous, + typename std::iterator_traits::value_type* keys_buffer, + typename std::iterator_traits::value_type* values_buffer) + -> std::enable_if_t::value, hipError_t> +{ + using key_type = typename std::iterator_traits::value_type; + return merge_sort_block_merge( + temporary_storage, + storage_size, + keys_output, + values_output, + size, + sort_items_per_block, + radix_merge_compare(bit, + current_radix_bits, + decomposer), + stream, + debug_synchronous, + keys_buffer, + values_buffer); +} + /// In device_radix_sort, we use this device_radix_sort_merge_sort specialization only /// for low input sizes (< 1M elements). 
template -inline hipError_t radix_sort_merge_impl( + class ValuesOutputIterator, + class Decomposer> +hipError_t radix_sort_merge_impl( void* temporary_storage, size_t& storage_size, KeysInputIterator keys_input, @@ -50,6 +179,7 @@ inline hipError_t radix_sort_merge_impl( typename std::iterator_traits::value_type* values_buffer, ValuesOutputIterator values_output, unsigned int size, + Decomposer decomposer, unsigned int bit, unsigned int end_bit, hipStream_t stream, @@ -102,14 +232,16 @@ inline hipError_t radix_sort_merge_impl( if(temporary_storage == nullptr) { - return merge_sort_block_merge( + return invoke_merge_sort_block_merge( temporary_storage, storage_size, keys_output, values_output, size, sort_items_per_block, - radix_merge_compare(), + decomposer, + bit, + current_radix_bits, stream, debug_synchronous, keys_buffer, @@ -128,6 +260,7 @@ inline hipError_t radix_sort_merge_impl( values_output, size, sort_items_per_block, + decomposer, bit, end_bit, stream, @@ -140,36 +273,20 @@ inline hipError_t radix_sort_merge_impl( // ^ sort_items_per_block is now updated if(size > sort_items_per_block) { - if(current_radix_bits == sizeof(key_type) * 8) - { - return merge_sort_block_merge( - temporary_storage, - storage_size, - keys_output, - values_output, - size, - sort_items_per_block, - radix_merge_compare(), - stream, - debug_synchronous, - keys_buffer, - values_buffer); - } - else - { - return merge_sort_block_merge( - temporary_storage, - storage_size, - keys_output, - values_output, - size, - sort_items_per_block, - radix_merge_compare(bit, current_radix_bits), - stream, - debug_synchronous, - keys_buffer, - values_buffer); - } + return invoke_merge_sort_block_merge( + temporary_storage, + storage_size, + keys_output, + values_output, + size, + sort_items_per_block, + decomposer, + bit, + current_radix_bits, + stream, + debug_synchronous, + keys_buffer, + values_buffer); } return hipSuccess; } diff --git a/rocprim/include/rocprim/intrinsics/thread.hpp 
b/rocprim/include/rocprim/intrinsics/thread.hpp index d5949b601..9c9444b3e 100644 --- a/rocprim/include/rocprim/intrinsics/thread.hpp +++ b/rocprim/include/rocprim/intrinsics/thread.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -341,6 +341,13 @@ namespace detail void memory_fence_device() { ::__threadfence(); + // Hotfix: On GFX10 (Navi 10/RDNA1, Navi 20/RDNA2) ISA and GFX11 ISA (Navi 30 GPUs), + // the compiler emits the L0 and L1 invalidate in the wrong order. + // + // See: https://github.com/llvm/llvm-project/pull/81450 +#if defined(__GFX10__) || defined(__GFX11__) + asm volatile("buffer_gl0_inv"); +#endif } } diff --git a/rocprim/include/rocprim/intrinsics/warp_shuffle.hpp b/rocprim/include/rocprim/intrinsics/warp_shuffle.hpp index c72a3217d..86c107cd4 100644 --- a/rocprim/include/rocprim/intrinsics/warp_shuffle.hpp +++ b/rocprim/include/rocprim/intrinsics/warp_shuffle.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -24,7 +24,8 @@ #include #include "../config.hpp" -#include "thread.hpp" + #include "../detail/various.hpp" + #include "thread.hpp" /// \addtogroup warpmodule /// @{ @@ -34,24 +35,6 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -#ifdef __HIP_CPU_RT__ -// TODO: consider adding macro checks relaying to std::bit_cast when compiled -// using C++20. 
-template -typename std::enable_if_t< - sizeof(To) == sizeof(From) && - std::is_trivially_copyable_v && - std::is_trivially_copyable_v, - To> -// constexpr support needs compiler magic -bit_cast(const From& src) noexcept -{ - To dst; - std::memcpy(&dst, &src, sizeof(To)); - return dst; -} -#endif - template ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if::value && (sizeof(T) % sizeof(int) == 0), T>::type @@ -60,11 +43,8 @@ warp_shuffle_op(const T& input, ShuffleOp&& op) constexpr int words_no = (sizeof(T) + sizeof(int) - 1) / sizeof(int); struct V { int words[words_no]; }; -#ifdef __HIP_CPU_RT__ - V a = bit_cast(input); -#else - V a = __builtin_bit_cast(V, input); -#endif + + auto a = ::rocprim::detail::bit_cast(input); ROCPRIM_UNROLL for(int i = 0; i < words_no; i++) @@ -72,11 +52,7 @@ warp_shuffle_op(const T& input, ShuffleOp&& op) a.words[i] = op(a.words[i]); } -#ifdef __HIP_CPU_RT__ - return bit_cast(a); -#else - return __builtin_bit_cast(T, a); -#endif + return ::rocprim::detail::bit_cast(a); } template @@ -254,6 +230,39 @@ T warp_shuffle_xor(const T& input, const int lane_mask, const int width = device ); } +namespace detail +{ + +/// \brief Shuffle XOR for any data type using warp_swizzle. +/// +/// i-th thread in warp obtains \p input from i^lane_mask-th +/// thread in warp. Makes use of of the swizzle instruction for powers of 2 till 16. +/// Defaults to warp_shuffle_xor. +/// +/// Note: The optional \p width parameter must be a power of 2; results are +/// undefined if it is not a power of 2, or it is greater than device_warp_size(). 
+/// +/// \param v - input to pass to other threads +/// \param mask - mask used for calculating source lane id +/// \param width - logical warp width +template +ROCPRIM_DEVICE ROCPRIM_INLINE V warp_swizzle_shuffle(V& v, + const int mask, + const int width = device_warp_size()) +{ + switch(mask) + { + case 1: return warp_swizzle(v); + case 2: return warp_swizzle(v); + case 4: return warp_swizzle(v); + case 8: return warp_swizzle(v); + case 16: return warp_swizzle(v); + default: return warp_shuffle_xor(v, mask, width); + } +} + +} // namespace detail + /// \brief Permute items across the threads in a warp. /// /// The value from this thread in the warp is permuted to the dst_lane-th diff --git a/rocprim/include/rocprim/iterator.hpp b/rocprim/include/rocprim/iterator.hpp index 65d545828..c5215222d 100644 --- a/rocprim/include/rocprim/iterator.hpp +++ b/rocprim/include/rocprim/iterator.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -28,6 +28,7 @@ #include "iterator/constant_iterator.hpp" #include "iterator/counting_iterator.hpp" #include "iterator/discard_iterator.hpp" +#include "iterator/predicate_iterator.hpp" #ifndef __HIP_CPU_RT__ #include "iterator/texture_cache_iterator.hpp" #endif diff --git a/rocprim/include/rocprim/iterator/arg_index_iterator.hpp b/rocprim/include/rocprim/iterator/arg_index_iterator.hpp index f8241f79f..5b96abb00 100644 --- a/rocprim/include/rocprim/iterator/arg_index_iterator.hpp +++ b/rocprim/include/rocprim/iterator/arg_index_iterator.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -208,7 +208,8 @@ class arg_index_iterator offset_ = 0; } - friend std::ostream& operator<<(std::ostream& os, const arg_index_iterator& /* iter */) + [[deprecated]] friend std::ostream& operator<<(std::ostream& os, + const arg_index_iterator& /* iter */) { return os; } diff --git a/rocprim/include/rocprim/iterator/constant_iterator.hpp b/rocprim/include/rocprim/iterator/constant_iterator.hpp index cd1cb4f7e..468a56512 100644 --- a/rocprim/include/rocprim/iterator/constant_iterator.hpp +++ b/rocprim/include/rocprim/iterator/constant_iterator.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -204,7 +204,7 @@ class constant_iterator return distance_to(other) <= 0; } - friend std::ostream& operator<<(std::ostream& os, const constant_iterator& iter) + [[deprecated]] friend std::ostream& operator<<(std::ostream& os, const constant_iterator& iter) { os << "[" << iter.value_ << "]"; return os; diff --git a/rocprim/include/rocprim/iterator/counting_iterator.hpp b/rocprim/include/rocprim/iterator/counting_iterator.hpp index 168ee5d28..e8e02f57f 100644 --- a/rocprim/include/rocprim/iterator/counting_iterator.hpp +++ b/rocprim/include/rocprim/iterator/counting_iterator.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -207,7 +207,7 @@ class counting_iterator return distance_to(other) <= 0; } - friend std::ostream& operator<<(std::ostream& os, const counting_iterator& iter) + [[deprecated]] friend std::ostream& operator<<(std::ostream& os, const counting_iterator& iter) { os << "[" << iter.value_ << "]"; return os; diff --git a/rocprim/include/rocprim/iterator/discard_iterator.hpp b/rocprim/include/rocprim/iterator/discard_iterator.hpp index 906cfd4dd..0cfae3407 100644 --- a/rocprim/include/rocprim/iterator/discard_iterator.hpp +++ b/rocprim/include/rocprim/iterator/discard_iterator.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -202,7 +202,8 @@ class discard_iterator return index_ >= other.index_; } - friend std::ostream& operator<<(std::ostream& os, const discard_iterator& /* iter */) + [[deprecated]] friend std::ostream& operator<<(std::ostream& os, + const discard_iterator& /* iter */) { return os; } diff --git a/rocprim/include/rocprim/iterator/predicate_iterator.hpp b/rocprim/include/rocprim/iterator/predicate_iterator.hpp new file mode 100644 index 000000000..d24088e53 --- /dev/null +++ b/rocprim/include/rocprim/iterator/predicate_iterator.hpp @@ -0,0 +1,300 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_ITERATOR_PREDICATE_ITERATOR_HPP_ +#define ROCPRIM_ITERATOR_PREDICATE_ITERATOR_HPP_ + +#include "../config.hpp" + +#include +#include + +/// \addtogroup iteratormodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \class predicate_iterator +/// \brief A random-access iterator which can discard values assigned to it upon dereference based on a predicate. +/// +/// \par Overview +/// * ``predicate_iterator`` can be used to ignore certain input or output of algorithms. +/// * When writing to ``predicate_iterator``, it will only write to the underlying iterator if the predicate holds. +/// Otherwise it will discard the value. +/// * When reading from ``predicate_iterator``, it will only read from the underlying iterator if the predicate holds. +/// Otherwise it will return the default constructed value. 
+/// +/// \tparam DataIterator Type of the data iterator that will be forwarded upon dereference. +/// \tparam PredicateDataIterator Type of the test iterator used to test the predicate function. +/// \tparam UnaryPredicate Type of the predicate function that tests the test. +template +class predicate_iterator +{ +public: + /// \brief The type of the value that can be obtained by dereferencing the iterator. + using value_type = typename std::iterator_traits::value_type; + + /// \brief A reference type of the type iterated over (``value_type``). + using reference = typename std::iterator_traits::reference; + + /// \brief A pointer type of the type iterated over (``value_type``). + using pointer = typename std::iterator_traits::pointer; + + /// \brief A type used for identify distance between iterators. + using difference_type = typename std::iterator_traits::difference_type; + + /// \brief The category of the iterator. + using iterator_category = std::random_access_iterator_tag; + + /// \brief Assignable proxy for values in ``DataIterator``. + struct proxy + { + public: + /// \brief The return type on the dereference operator. This may be different than ``reference``. + using capture_t = decltype(*std::declval()); + + /// \brief Constructs a ``proxy`` object with the given reference and keep flag. + /// \param val The value or reference to be captured. + /// \param keep Boolean flag that indicates whether to keep the reference. + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE proxy(capture_t val, const bool keep) + : underlying_(val), keep_(keep) + {} + + /// \brief Assigns a value to the held reference if the keep flag is ``true``. + /// \param value The value to assign to the captured value. + /// \return A reference to the (possibly) modified ``proxy`` object. + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE proxy& operator=(const value_type& value) + { + if(keep_) + { + underlying_ = value; + } + return *this; + } + + /// \brief Converts the ``proxy`` to the underlying value type. 
+ /// \return The referenced value or the default-constructed value. + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE operator value_type() const + { + return keep_ ? underlying_ : value_type{}; + } + + private: + /// \brief The reference or value being held. + capture_t underlying_; + + /// \brief Boolean flag indicating whether to keep the reference or discard it. + bool keep_; + }; + + /// \brief Creates a new predicate_iterator. + /// + /// \param data_iterator The data iterator that will be forwarded whenever the predicate is true. + /// \param predicate_iterator The test iterator that is used to test the predicate on. + /// \param predicate Unary function used to select values obtained. + /// from range pointed by \p iterator. + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE predicate_iterator(DataIterator data_iterator, + PredicateDataIterator predicate_iterator, + UnaryPredicate predicate) + : data_it_(data_iterator), predicate_data_it_(predicate_iterator), predicate_(predicate) + {} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS + + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE predicate_iterator& operator++() + { + data_it_++; + predicate_data_it_++; + return *this; + } + + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE predicate_iterator operator++(int) + { + predicate_iterator old = *this; + data_it_++; + predicate_data_it_++; + return old; + } + + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE predicate_iterator& operator--() + { + data_it_--; + predicate_data_it_--; + return *this; + } + + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE predicate_iterator operator--(int) + { + predicate_iterator old = *this; + data_it_--; + predicate_data_it_--; + return old; + } + + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE proxy operator*() + { + return proxy(*data_it_, predicate_(*predicate_data_it_)); + } + + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE proxy operator->() + { + return *(*this); + } + + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE proxy operator[](difference_type distance) + { + return *(*this + distance); + } + + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE 
predicate_iterator operator+(difference_type distance) const + { + return predicate_iterator(data_it_ + distance, predicate_data_it_ + distance, predicate_); + } + + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE predicate_iterator& operator+=(difference_type distance) + { + data_it_ += distance; + predicate_data_it_ += distance; + return *this; + } + + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE predicate_iterator operator-(difference_type distance) const + { + return predicate_iterator(data_it_ - distance, predicate_data_it_ - distance, predicate_); + } + + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE predicate_iterator& operator-=(difference_type distance) + { + data_it_ -= distance; + predicate_data_it_ -= distance; + return *this; + } + + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE difference_type operator-(predicate_iterator other) const + { + return data_it_ - other.data_it_; + } + + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE bool operator==(predicate_iterator other) const + { + return data_it_ == other.data_it_; + } + + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE bool operator!=(predicate_iterator other) const + { + return data_it_ != other.data_it_; + } + + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE bool operator<(predicate_iterator other) const + { + return data_it_ < other.data_it_; + } + + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE bool operator<=(predicate_iterator other) const + { + return data_it_ <= other.data_it_; + } + + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE bool operator>(predicate_iterator other) const + { + return data_it_ > other.data_it_; + } + + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE bool operator>=(predicate_iterator other) const + { + return data_it_ >= other.data_it_; + } +#endif // DOXYGEN_SHOULD_SKIP_THIS + +private: + DataIterator data_it_; + PredicateDataIterator predicate_data_it_; + UnaryPredicate predicate_; +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template +ROCPRIM_HOST_DEVICE inline predicate_iterator + operator+( + typename predicate_iterator:: + difference_type distance, + const predicate_iterator& 
iterator) +{ + return iterator + distance; +} +#endif // DOXYGEN_SHOULD_SKIP_THIS + +/// \brief Constructs a ``predicate_iterator`` which can discard values assigned to it upon dereference based on a predicate. +/// +/// \tparam DataIterator Type of ``data_iterator``. +/// \tparam PredicateDataIterator Type of ``predicate_data_iterator``. +/// \tparam UnaryPredicate Type of ``predicate``. +/// +/// \param data_iterator The data iterator that will be forwarded whenever the predicate is true. +/// \param predicate_data_iterator The test iterator that is used to test the predicate on. +/// \param predicate The predicate function. +template +auto make_predicate_iterator(DataIterator data_iterator, + PredicateDataIterator predicate_data_iterator, + UnaryPredicate predicate) +{ + return predicate_iterator( + data_iterator, + predicate_data_iterator, + predicate); +} + +/// \brief Constructs a ``predicate_iterator`` which can discard values assigned to it upon dereference based on a predicate. +/// +/// \tparam DataIterator Type of ``data_iterator``. +/// \tparam UnaryPredicate Type of ``predicate``. +/// +/// \param data_iterator The data iterator that will be forwarded whenever the predicate is true. +/// \param predicate The predicate function. It will be tested on ``data_iterator``. +template +ROCPRIM_HOST_DEVICE inline predicate_iterator + make_predicate_iterator(DataIterator data_iterator, UnaryPredicate predicate) +{ + return make_predicate_iterator(data_iterator, + data_iterator, + predicate); +} + +/// \brief Constructs a ``predicate_iterator`` which can discard values assigned to it upon dereference based on a predicate. +/// +/// \tparam DataIterator Type of ``data_iterator``. +/// \tparam FlagIterator Type of ``flag_iterator``. Its ``value_type`` should be implicitly convertible to ``bool``. +/// +/// \param data_iterator The data iterator that will be forwarded when the corresponding flag is set to ``true``. +/// \param flag_iterator The flag iterator. 
+template +auto make_mask_iterator(DataIterator data_iterator, FlagIterator flag_iterator) +{ + return make_predicate_iterator(data_iterator, + flag_iterator, + [] ROCPRIM_HOST_DEVICE(bool value) { return value; }); +} + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group iteratormodule + +#endif // ROCPRIM_ITERATOR_PREDICATE_ITERATOR_HPP_ diff --git a/rocprim/include/rocprim/iterator/texture_cache_iterator.hpp b/rocprim/include/rocprim/iterator/texture_cache_iterator.hpp index af7a53cb3..b4b013d8c 100644 --- a/rocprim/include/rocprim/iterator/texture_cache_iterator.hpp +++ b/rocprim/include/rocprim/iterator/texture_cache_iterator.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -330,7 +330,8 @@ class texture_cache_iterator return (ptr - other.ptr) <= 0; } - friend std::ostream& operator<<(std::ostream& os, const texture_cache_iterator& /* iter */) + [[deprecated]] friend std::ostream& operator<<(std::ostream& os, + const texture_cache_iterator& /* iter */) { return os; } diff --git a/rocprim/include/rocprim/iterator/transform_iterator.hpp b/rocprim/include/rocprim/iterator/transform_iterator.hpp index e98b8f03d..4f04eda36 100644 --- a/rocprim/include/rocprim/iterator/transform_iterator.hpp +++ b/rocprim/include/rocprim/iterator/transform_iterator.hpp @@ -206,7 +206,8 @@ class transform_iterator return iterator_ >= other.iterator_; } - friend std::ostream& operator<<(std::ostream& os, const transform_iterator& /* iter */) + [[deprecated]] friend std::ostream& operator<<(std::ostream& os, + const transform_iterator& /* iter */) { return os; } diff --git a/rocprim/include/rocprim/iterator/zip_iterator.hpp b/rocprim/include/rocprim/iterator/zip_iterator.hpp index 
9f0dfaa88..63acd5eb2 100644 --- a/rocprim/include/rocprim/iterator/zip_iterator.hpp +++ b/rocprim/include/rocprim/iterator/zip_iterator.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -283,7 +283,7 @@ class zip_iterator return !(*this < other); } - friend std::ostream& operator<<(std::ostream& os, const zip_iterator& /* iter */) + [[deprecated]] friend std::ostream& operator<<(std::ostream& os, const zip_iterator& /* iter */) { return os; } diff --git a/rocprim/include/rocprim/rocprim.hpp b/rocprim/include/rocprim/rocprim.hpp index 5b3dd0e1d..8e7937a38 100644 --- a/rocprim/include/rocprim/rocprim.hpp +++ b/rocprim/include/rocprim/rocprim.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -36,6 +36,8 @@ #include "type_traits.hpp" #include "iterator.hpp" +#include "thread/radix_key_codec.hpp" + #include "warp/warp_reduce.hpp" #include "warp/warp_scan.hpp" #include "warp/warp_sort.hpp" @@ -52,6 +54,7 @@ #include "device/device_adjacent_difference.hpp" #include "device/device_binary_search.hpp" +#include "device/device_copy.hpp" #include "device/device_histogram.hpp" #include "device/device_memcpy.hpp" #include "device/device_merge.hpp" diff --git a/rocprim/include/rocprim/thread/radix_key_codec.hpp b/rocprim/include/rocprim/thread/radix_key_codec.hpp new file mode 100644 index 000000000..78a05b0fa --- /dev/null +++ b/rocprim/include/rocprim/thread/radix_key_codec.hpp @@ -0,0 +1,726 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_THREAD_RADIX_KEY_CODEC_HPP_ +#define ROCPRIM_THREAD_RADIX_KEY_CODEC_HPP_ + +#include +#include +#include + +#include "../config.hpp" +#include "../detail/various.hpp" +#include "../functional.hpp" +#include "../type_traits.hpp" +#include "../types.hpp" +#include "../types/tuple.hpp" + +/// \addtogroup threadmodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +// Encode and decode integral and floating point values for radix sort in such a way that preserves +// correct order of negative and positive keys (i.e. negative keys go before positive ones, +// which is not true for a simple reinterpretation of the key's bits). + +// Digit extractor takes into account that (+0.0 == -0.0) is true for floats, +// so both +0.0 and -0.0 are reflected into the same bit pattern for digit extraction. +// Maximum digit length is 32. 
+ +template +struct radix_key_codec_integral +{}; + +template +struct radix_key_codec_integral::value>::type> +{ + using bit_key_type = BitKey; + + ROCPRIM_HOST_DEVICE static bit_key_type encode(Key key) + { + return ::rocprim::detail::bit_cast(key); + } + + ROCPRIM_HOST_DEVICE static Key decode(bit_key_type bit_key) + { + return ::rocprim::detail::bit_cast(bit_key); + } + + template + ROCPRIM_HOST_DEVICE static unsigned int + extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) + { + unsigned int mask = (1u << length) - 1; + return static_cast(bit_key >> start) & mask; + } +}; + +template +struct radix_key_codec_integral< + Key, + BitKey, + typename std::enable_if::value>::type> +{ + using bit_key_type = BitKey; + + ROCPRIM_HOST_DEVICE static bit_key_type encode(Key key) + { + return ::rocprim::detail::bit_cast(key); + } + + ROCPRIM_HOST_DEVICE static Key decode(bit_key_type bit_key) + { + return ::rocprim::detail::bit_cast(bit_key); + } + + template + ROCPRIM_HOST_DEVICE static unsigned int + extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) + { + unsigned int mask = (1u << length) - 1; + return static_cast(bit_key >> start) & mask; + } +}; + +template +struct radix_key_codec_integral::value>::type> +{ + using bit_key_type = BitKey; + + static constexpr bit_key_type sign_bit = bit_key_type(1) << (sizeof(bit_key_type) * 8 - 1); + + ROCPRIM_HOST_DEVICE static bit_key_type encode(Key key) + { + const auto bit_key = ::rocprim::detail::bit_cast(key); + return sign_bit ^ bit_key; + } + + ROCPRIM_HOST_DEVICE static Key decode(bit_key_type bit_key) + { + bit_key ^= sign_bit; + return ::rocprim::detail::bit_cast(bit_key); + } + + template + ROCPRIM_HOST_DEVICE static unsigned int + extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) + { + unsigned int mask = (1u << length) - 1; + return static_cast(bit_key >> start) & mask; + } +}; + +template +struct radix_key_codec_integral::value>::type> +{ + 
using bit_key_type = BitKey; + + static constexpr bit_key_type sign_bit = bit_key_type(1) << (sizeof(bit_key_type) * 8 - 1); + + ROCPRIM_HOST_DEVICE static bit_key_type encode(Key key) + { + const auto bit_key = ::rocprim::detail::bit_cast(key); + return sign_bit ^ bit_key; + } + + ROCPRIM_HOST_DEVICE static Key decode(bit_key_type bit_key) + { + bit_key ^= sign_bit; + return ::rocprim::detail::bit_cast(bit_key); + } + + template + ROCPRIM_HOST_DEVICE static unsigned int + extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) + { + unsigned int mask = (1u << length) - 1; + return static_cast(bit_key >> start) & mask; + } +}; + +template +struct radix_key_codec_floating +{ + using bit_key_type = BitKey; + + static constexpr bit_key_type sign_bit = ::rocprim::detail::float_bit_mask::sign_bit; + + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE static bit_key_type encode(Key key) + { + bit_key_type bit_key = ::rocprim::detail::bit_cast(key); + bit_key ^= (sign_bit & bit_key) == 0 ? sign_bit : bit_key_type(-1); + return bit_key; + } + + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE static Key decode(bit_key_type bit_key) + { + bit_key ^= (sign_bit & bit_key) == 0 ? bit_key_type(-1) : sign_bit; + return ::rocprim::detail::bit_cast(bit_key); + } + + template + ROCPRIM_HOST_DEVICE static unsigned int + extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) + { + unsigned int mask = (1u << length) - 1; + + // radix_key_codec_floating::encode() maps 0.0 to 0x8000'0000, + // and -0.0 to 0x7FFF'FFFF. + // radix_key_codec::encode() then flips the bits if descending, yielding: + // value | descending | ascending | + // ----- | ----------- | ----------- | + // 0.0 | 0x7FFF'FFFF | 0x8000'0000 | + // -0.0 | 0x8000'0000 | 0x7FFF'FFFF | + // + // For ascending sort, both should be mapped to 0x8000'0000, + // and for descending sort, both should be mapped to 0x7FFF'FFFF. + if ROCPRIM_IF_CONSTEXPR(Descending) + { + bit_key = bit_key == sign_bit ? 
static_cast(~sign_bit) : bit_key; + } + else + { + bit_key = bit_key == static_cast(~sign_bit) ? sign_bit : bit_key; + } + return static_cast(bit_key >> start) & mask; + } +}; + +template +struct radix_key_codec_base +{ + // Non-fundamental keys (custom keys) will not use any specialization and thus they do not + // have any of the struct members that fundamental types have. +}; + +template +struct radix_key_codec_base::value>::type> + : radix_key_codec_integral::type> +{}; + +template +struct radix_key_codec_base::value>::type> + : radix_key_codec_integral +{}; + +template +struct radix_key_codec_base::value>::type> + : radix_key_codec_integral +{}; + +template<> +struct radix_key_codec_base +{ + using bit_key_type = unsigned char; + + ROCPRIM_HOST_DEVICE static bit_key_type encode(bool key) + { + return static_cast(key); + } + + ROCPRIM_HOST_DEVICE static bool decode(bit_key_type bit_key) + { + return static_cast(bit_key); + } + + template + ROCPRIM_HOST_DEVICE static unsigned int + extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) + { + unsigned int mask = (1u << length) - 1; + return static_cast(bit_key >> start) & mask; + } +}; + +template<> +struct radix_key_codec_base<::rocprim::half> + : radix_key_codec_floating<::rocprim::half, unsigned short> +{}; + +template<> +struct radix_key_codec_base<::rocprim::bfloat16> + : radix_key_codec_floating<::rocprim::bfloat16, unsigned short> +{}; + +template<> +struct radix_key_codec_base : radix_key_codec_floating +{}; + +template<> +struct radix_key_codec_base : radix_key_codec_floating +{}; + +template +struct radix_key_fundamental +{ + static constexpr bool value = false; +}; + +template +struct radix_key_fundamental< + T, + ::rocprim::detail::void_t::bit_key_type>> +{ + static constexpr bool value = true; +}; + +} // end namespace detail + +/// \brief Key encoder, decoder and bit-extractor for radix-based sorts. +/// +/// \tparam Key Type of the key used. 
+/// \tparam Descending Whether the sort is increasing or decreasing. +template::value> +class radix_key_codec : protected ::rocprim::detail::radix_key_codec_base +{ + using base_type = ::rocprim::detail::radix_key_codec_base; + +public: + /// \brief Type of the encoded key. + using bit_key_type = typename base_type::bit_key_type; + + /// \brief Encodes a key of type \p Key into \p bit_key_type. + /// + /// \tparam Decomposer Being \p Key a fundamental type, \p Decomposer should be + /// \p identity_decomposer. This is also the type by default. + /// \param [in] key Key to encode. + /// \param [in] decomposer [optional] Decomposer functor. + /// \return A \p bit_key_type encoded key. + template + ROCPRIM_HOST_DEVICE static bit_key_type encode(Key key, Decomposer decomposer = {}) + { + static_assert(std::is_same::value, + "Fundamental types don't use custom decomposer."); + bit_key_type bit_key = base_type::encode(key); + return Descending ? ~bit_key : bit_key; + } + + /// \brief Encodes in-place a key of type \p Key. + /// + /// \tparam Decomposer Being \p Key a fundamental type, \p Decomposer should be + /// \p identity_decomposer. This is also the type by default. + /// \param [in, out] key Key to encode. + /// \param [in] decomposer [optional] Decomposer functor. + template + ROCPRIM_HOST_DEVICE static void encode_inplace(Key& key, Decomposer decomposer = {}) + { + static_assert(std::is_same::value, + "Fundamental types don't use custom decomposer."); + key = ::rocprim::detail::bit_cast(encode(key)); + } + + /// \brief Decodes an encoded key of type \p bit_key_type back into \p Key. + /// + /// \tparam Decomposer Being \p Key a fundamental type, \p Decomposer should be + /// \p identity_decomposer. This is also the type by default. + /// \param [in] bit_key Key to decode. + /// \param [in] decomposer [optional] Decomposer functor. + /// \return A \p Key decoded key. 
+ template + ROCPRIM_HOST_DEVICE static Key decode(bit_key_type bit_key, Decomposer decomposer = {}) + { + static_assert(std::is_same::value, + "Fundamental types don't use custom decomposer."); + bit_key = Descending ? ~bit_key : bit_key; + return base_type::decode(bit_key); + } + + /// \brief Decodes in-place an encoded key of type \p Key. + /// + /// \tparam Decomposer Being \p Key a fundamental type, \p Decomposer should be + /// \p identity_decomposer. This is also the type by default. + /// \param [in, out] key Key to decode. + /// \param [in] decomposer [optional] Decomposer functor. + template + ROCPRIM_HOST_DEVICE static void decode_inplace(Key& key, Decomposer decomposer = {}) + { + static_assert(std::is_same::value, + "Fundamental types don't use custom decomposer."); + key = decode(::rocprim::detail::bit_cast(key)); + } + + /// \brief Extracts the specified bits from a given encoded key. + /// + /// \param [in] bit_key Encoded key. + /// \param [in] start Start bit of the sequence of bits to extract. + /// \param [in] radix_bits How many bits to extract. + /// \return Requested bits from the key. + ROCPRIM_HOST_DEVICE static unsigned int + extract_digit(bit_key_type bit_key, unsigned int start, unsigned int radix_bits) + { + return base_type::template extract_digit(bit_key, start, radix_bits); + } + + /// \brief Extracts the specified bits from a given in-place encoded key. + /// + /// \tparam Decomposer Being \p Key a fundamental type, \p Decomposer should be + /// \p identity_decomposer. This is also the type by default. + /// \param [in] key Key. + /// \param [in] start Start bit of the sequence of bits to extract. + /// \param [in] radix_bits How many bits to extract. + /// \param [in] decomposer [optional] Decomposer functor. + /// \return Requested bits from the key. 
+ template + ROCPRIM_HOST_DEVICE static unsigned int extract_digit(Key key, + unsigned int start, + unsigned int radix_bits, + Decomposer decomposer = {}) + { + static_assert(std::is_same::value, + "Fundamental types don't use custom decomposer."); + return extract_digit(::rocprim::detail::bit_cast(key), start, radix_bits); + } + + /// \brief Gives the default value for out-of-bound keys of type \p Key. + /// + /// \tparam Decomposer Being \p Key a fundamental type, \p Decomposer should be + /// \p identity_decomposer. This is also the type by default. + /// \param [in] decomposer [optional] Decomposer functor. + /// \return Out-of-bound keys' value. + template + ROCPRIM_HOST_DEVICE static Key get_out_of_bounds_key(Decomposer decomposer = {}) + { + static_assert(std::is_same::value, + "Fundamental types don't use custom decomposer."); + return decode(static_cast(-1)); + } +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // skip specializations +template +class radix_key_codec : protected detail::radix_key_codec_base +{ + using base_type = detail::radix_key_codec_base; + +public: + using bit_key_type = typename base_type::bit_key_type; + + template + ROCPRIM_HOST_DEVICE static bit_key_type encode(bool key, Decomposer decomposer = {}) + { + static_assert(std::is_same::value, + "Fundamental types don't use custom decomposer."); + return Descending != key; + } + + template + ROCPRIM_HOST_DEVICE static void encode_inplace(bool& key, Decomposer decomposer = {}) + { + static_assert(std::is_same::value, + "Fundamental types don't use custom decomposer."); + key = ::rocprim::detail::bit_cast(encode(key)); + } + + template + ROCPRIM_HOST_DEVICE static bool decode(bit_key_type bit_key, Decomposer decomposer = {}) + { + static_assert(std::is_same::value, + "Fundamental types don't use custom decomposer."); + const bool key_value = bit_key; + return Descending != key_value; + } + + template + ROCPRIM_HOST_DEVICE static void decode_inplace(bool& key, Decomposer decomposer = {}) + { + 
static_assert(std::is_same::value, + "Fundamental types don't use custom decomposer."); + key = decode(::rocprim::detail::bit_cast(key)); + } + + ROCPRIM_HOST_DEVICE static unsigned int + extract_digit(bit_key_type bit_key, unsigned int start, unsigned int radix_bits) + { + return base_type::template extract_digit(bit_key, start, radix_bits); + } + + template + ROCPRIM_HOST_DEVICE static unsigned int extract_digit(bool key, + unsigned int start, + unsigned int radix_bits, + Decomposer decomposer = {}) + { + static_assert(std::is_same::value, + "Fundamental types don't use custom decomposer."); + return extract_digit(::rocprim::detail::bit_cast(key), start, radix_bits); + } + + template + ROCPRIM_HOST_DEVICE static bool get_out_of_bounds_key(Decomposer decomposer = {}) + { + static_assert(std::is_same::value, + "Fundamental types don't use custom decomposer."); + return decode(static_cast(-1)); + } +}; +#endif // DOXYGEN_SHOULD_SKIP_THIS + +/// \brief Key encoder, decoder and bit-extractor for radix-based sorts with custom key types. +/// +/// \tparam Key Type of the key used. +/// \tparam Descending Whether the sort is increasing or decreasing. +template +class radix_key_codec +{ +public: + /// \brief The key in this case is a custom type, so \p bit_key_type cannot be the type of the + /// encoded key because it depends on the decomposer used. It is thus set as the type \p Key. + using bit_key_type = Key; + + /// \brief Encodes a key of type \p Key into \p bit_key_type. + /// + /// \tparam Decomposer Decomposer functor type. Being \p Key a custom key type, the decomposer + /// type must be other than the \p identity_decomposer. + /// \param [in] key Key to encode. + /// \param [in] decomposer [optional] \p Key is a custom key type, so a custom decomposer + /// functor that returns a \p ::rocprim::tuple of references to fundamental types from a + /// \p Key key is needed. + /// \return A \p bit_key_type encoded key. 
+ template + ROCPRIM_HOST_DEVICE static bit_key_type encode(Key key, Decomposer decomposer = {}) + { + encode_inplace(key, decomposer); + return static_cast(key); + } + + /// \brief Encodes in-place a key of type \p Key. + /// + /// \tparam Decomposer Decomposer functor type. Being \p Key a custom key type, the decomposer + /// type must be other than the \p identity_decomposer. + /// \param [in, out] key Key to encode. + /// \param [in] decomposer [optional] \p Key is a custom key type, so a custom decomposer + /// functor that returns a \p ::rocprim::tuple of references to fundamental types from a + /// \p Key key is needed. + template + ROCPRIM_HOST_DEVICE static void encode_inplace(Key& key, Decomposer decomposer = {}) + { + static_assert(!std::is_same::value, + "The decomposer of a custom-type key cannot be the identity decomposer."); + static_assert(::rocprim::detail::is_tuple_of_references::value, + "The decomposer must return a tuple of references."); + const auto per_element_encode = [](auto& tuple_element) + { + using element_type_t = std::decay_t; + using codec_t = radix_key_codec; + codec_t::encode_inplace(tuple_element); + }; + ::rocprim::detail::for_each_in_tuple(decomposer(key), per_element_encode); + } + + /// \brief Decodes an encoded key of type \p bit_key_type back into \p Key. + /// + /// \tparam Decomposer Decomposer functor type. Being \p Key a custom key type, the decomposer + /// type must be other than the \p identity_decomposer. + /// \param [in] bit_key Key to decode. + /// \param [in] decomposer [optional] \p Key is a custom key type, so a custom decomposer + /// functor that returns a \p ::rocprim::tuple of references to fundamental types from a + /// \p Key key is needed. + /// \return A \p Key decoded key. 
+ template + ROCPRIM_HOST_DEVICE static Key decode(bit_key_type bit_key, Decomposer decomposer = {}) + { + decode_inplace(bit_key, decomposer); + return static_cast(bit_key); + } + + /// \brief Decodes in-place an encoded key of type \p Key. + /// + /// \tparam Decomposer Decomposer functor type. Being \p Key a custom key type, the decomposer + /// type must be other than the \p identity_decomposer. + /// \param [in, out] key Key to decode. + /// \param [in] decomposer [optional] Decomposer functor. + template + ROCPRIM_HOST_DEVICE static void decode_inplace(Key& key, Decomposer decomposer = {}) + { + static_assert(!std::is_same::value, + "The decomposer of a custom-type key cannot be the identity decomposer."); + static_assert(::rocprim::detail::is_tuple_of_references::value, + "The decomposer must return a tuple of references."); + const auto per_element_decode = [](auto& tuple_element) + { + using element_type_t = std::decay_t; + using codec_t = radix_key_codec; + codec_t::decode_inplace(tuple_element); + }; + ::rocprim::detail::for_each_in_tuple(decomposer(key), per_element_decode); + } + + /// \brief Extracts the specified bits from a given encoded key. + /// + /// \return Requested bits from the key. + ROCPRIM_HOST_DEVICE static unsigned int extract_digit(bit_key_type, unsigned int, unsigned int) + { + static_assert( + sizeof(bit_key_type) == 0, + "Only fundamental types (integral and floating point) are supported as radix sort" + "keys without a decomposer. " + "For custom key types, use the extract_digit overloads with the decomposer argument"); + } + + /// \brief Extracts the specified bits from a given in-place encoded key. + /// + /// \tparam Decomposer Decomposer functor type. Being \p Key a custom key type, the decomposer + /// type must be other than the \p identity_decomposer. + /// \param [in] key Key. + /// \param [in] start Start bit of the sequence of bits to extract. + /// \param [in] radix_bits How many bits to extract. 
+ /// \param [in] decomposer \p Key is a custom key type, so a custom decomposer + /// functor that returns a \p ::rocprim::tuple of references to fundamental types from a + /// \p Key key is needed. + /// \return Requested bits from the key. + template + ROCPRIM_HOST_DEVICE static unsigned int + extract_digit(Key key, unsigned int start, unsigned int radix_bits, Decomposer decomposer) + { + static_assert(!std::is_same::value, + "The decomposer of a custom-type key cannot be the identity decomposer."); + static_assert(::rocprim::detail::is_tuple_of_references::value, + "The decomposer must return a tuple of references."); + constexpr size_t tuple_size + = ::rocprim::tuple_size>::value; + return extract_digit_from_key_impl(0u, + decomposer(key), + static_cast(start), + static_cast(start + radix_bits), + 0); + } + + /// \brief Gives the default value for out-of-bound keys of type \p Key. + /// + /// \tparam Decomposer Decomposer functor type. Being \p Key a custom key type, the decomposer + /// type must be other than the \p identity_decomposer. + /// \param [in] decomposer \p Key is a custom key type, so a custom decomposer + /// functor that returns a \p ::rocprim::tuple of references to fundamental types from a + /// \p Key key is needed. + /// \return Out-of-bound keys' value. 
+ template + ROCPRIM_HOST_DEVICE static Key get_out_of_bounds_key(Decomposer decomposer) + { + static_assert(!std::is_same::value, + "The decomposer of a custom-type key cannot be the identity decomposer."); + static_assert(std::is_default_constructible::value, + "The sorted Key type must be default constructible"); + Key key; + ::rocprim::detail::for_each_in_tuple( + decomposer(key), + [](auto& element) + { + using element_t = std::decay_t; + using codec_t = radix_key_codec; + using bit_key_type = typename codec_t::bit_key_type; + element = codec_t::decode(static_cast(-1)); + }); + return key; + } + +private: + template + ROCPRIM_HOST_DEVICE static auto + extract_digit_from_key_impl(unsigned int digit, + const ::rocprim::tuple& key_tuple, + const int start, + const int end, + const int previous_bits) + -> std::enable_if_t<(ElementIndex >= 0), unsigned int> + { + using T = std::decay_t<::rocprim::tuple_element_t>>; + using bit_key_type = typename radix_key_codec::bit_key_type; + constexpr int current_element_bits = 8 * sizeof(T); + + const int total_extracted_bits = end - start; + const int current_element_end_bit = previous_bits + current_element_bits; + if(start < current_element_end_bit && end > previous_bits) + { + // unsigned integral representation of the current tuple element + const auto element_value = ::rocprim::detail::bit_cast( + ::rocprim::get(key_tuple)); + + const int bits_extracted_previously = ::rocprim::max(0, previous_bits - start); + + // start of the bit range copied from the current tuple element + const int current_start_bit = ::rocprim::max(0, start - previous_bits); + + // end of the bit range copied from the current tuple element + const int current_end_bit = ::rocprim::min(current_element_bits, + current_start_bit + total_extracted_bits + - bits_extracted_previously); + + // number of bits extracted from the current tuple element + const int current_length = current_end_bit - current_start_bit; + + // bits extracted from the current tuple 
element, aligned to LSB + unsigned int current_extract = (element_value >> current_start_bit); + + if(current_length != 32) + { + current_extract &= (1u << current_length) - 1; + } + + digit |= current_extract << bits_extracted_previously; + } + return extract_digit_from_key_impl(digit, + key_tuple, + start, + end, + previous_bits + current_element_bits); + } + + /// + template + ROCPRIM_HOST_DEVICE static auto + extract_digit_from_key_impl(unsigned int digit, + const ::rocprim::tuple& /*key_tuple*/, + const int /*start*/, + const int /*end*/, + const int /*previous_bits*/) + -> std::enable_if_t<(ElementIndex < 0), unsigned int> + { + return digit; + } +}; + +namespace detail +{ + +template +using radix_key_codec [[deprecated("radix_key_codec is now public API.")]] += rocprim::radix_key_codec; + +} // namespace detail +END_ROCPRIM_NAMESPACE + +/// @} +// end of group threadmodule + +#endif // ROCPRIM_THREAD_RADIX_KEY_CODEC_HPP_ diff --git a/rocprim/include/rocprim/thread/thread_load.hpp b/rocprim/include/rocprim/thread/thread_load.hpp index d10f6cf2c..cb724f5ab 100644 --- a/rocprim/include/rocprim/thread/thread_load.hpp +++ b/rocprim/include/rocprim/thread/thread_load.hpp @@ -1,7 +1,7 @@ /****************************************************************************** * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. - * Modifications Copyright (c) 2021, Advanced Micro Devices, Inc. All rights reserved. + * Modifications Copyright (c) 2021-2024, Advanced Micro Devices, Inc. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: @@ -52,30 +52,32 @@ namespace detail template ROCPRIM_DEVICE __forceinline__ T AsmThreadLoad(void * ptr) { - T retval = 0; + T retval{}; __builtin_memcpy(&retval, ptr, sizeof(T)); return retval; } #if ROCPRIM_THREAD_LOAD_USE_CACHE_MODIFIERS == 1 -// Important for syncing. Check section 9.2.2 or 7.3 in the following document -// http://developer.amd.com/wordpress/media/2013/12/AMD_GCN3_Instruction_Set_Architecture_rev1.1.pdf -#define ROCPRIM_ASM_THREAD_LOAD(cache_modifier, \ - llvm_cache_modifier, \ - type, \ - interim_type, \ - asm_operator, \ - output_modifier, \ - wait_cmd) \ - template<> \ - ROCPRIM_DEVICE __forceinline__ type AsmThreadLoad(void * ptr) \ - { \ - interim_type retval; \ - asm volatile(#asm_operator " %0, %1 " llvm_cache_modifier : "=" #output_modifier(retval) : "v"(ptr)); \ - asm volatile("s_waitcnt " wait_cmd "(%0)" : : "I"(0x00)); \ - return retval; \ - } + // Important for syncing. 
Check section 9.2.2 or 7.3 in the following document + // http://developer.amd.com/wordpress/media/2013/12/AMD_GCN3_Instruction_Set_Architecture_rev1.1.pdf + #define ROCPRIM_ASM_THREAD_LOAD(cache_modifier, \ + llvm_cache_modifier, \ + type, \ + interim_type, \ + asm_operator, \ + output_modifier, \ + wait_cmd) \ + template<> \ + ROCPRIM_DEVICE __forceinline__ type AsmThreadLoad(void* ptr) \ + { \ + interim_type retval; \ + asm volatile(#asm_operator " %0, %1 " llvm_cache_modifier "\n\t" \ + "s_waitcnt " wait_cmd "(%2)" \ + : "=" #output_modifier(retval) \ + : "v"(ptr), "I"(0x00)); \ + return retval; \ + } // TODO Add specialization for custom larger data types #define ROCPRIM_ASM_THREAD_LOAD_GROUP(cache_modifier, llvm_cache_modifier, wait_cmd) \ diff --git a/rocprim/include/rocprim/type_traits.hpp b/rocprim/include/rocprim/type_traits.hpp index 8a40d79f3..4f31d2319 100644 --- a/rocprim/include/rocprim/type_traits.hpp +++ b/rocprim/include/rocprim/type_traits.hpp @@ -22,9 +22,14 @@ #define ROCPRIM_TYPE_TRAITS_HPP_ #include "config.hpp" +#include "functional.hpp" #include "types.hpp" +#include "types/tuple.hpp" + +#include #include +#include /// \addtogroup utilsmodule_typetraits /// @{ @@ -142,9 +147,10 @@ struct get_unsigned_bits_type #ifndef DOXYGEN_SHOULD_SKIP_THIS template -ROCPRIM_DEVICE ROCPRIM_INLINE -auto TwiddleIn(UnsignedBits key) - -> typename std::enable_if::value, UnsignedBits>::type +[[deprecated("TwiddleIn is deprecated." + "Use radix_key_codec instead.")]] ROCPRIM_DEVICE ROCPRIM_INLINE auto + TwiddleIn(UnsignedBits key) -> + typename std::enable_if::value, UnsignedBits>::type { static const UnsignedBits HIGH_BIT = UnsignedBits(1) << ((sizeof(UnsignedBits) * 8) - 1); UnsignedBits mask = (key & HIGH_BIT) ? 
UnsignedBits(-1) : HIGH_BIT; @@ -152,26 +158,29 @@ auto TwiddleIn(UnsignedBits key) } template -static ROCPRIM_DEVICE ROCPRIM_INLINE -auto TwiddleIn(UnsignedBits key) - -> typename std::enable_if::value, UnsignedBits>::type +[[deprecated("TwiddleIn is deprecated." + "Use radix_key_codec instead.")]] static ROCPRIM_DEVICE ROCPRIM_INLINE auto + TwiddleIn(UnsignedBits key) -> + typename std::enable_if::value, UnsignedBits>::type { return key ; }; template -static ROCPRIM_DEVICE ROCPRIM_INLINE -auto TwiddleIn(UnsignedBits key) - -> typename std::enable_if::value && is_signed::value, UnsignedBits>::type +[[deprecated("TwiddleIn is deprecated." + "Use radix_key_codec instead.")]] static ROCPRIM_DEVICE ROCPRIM_INLINE auto + TwiddleIn(UnsignedBits key) -> + typename std::enable_if::value && is_signed::value, UnsignedBits>::type { static const UnsignedBits HIGH_BIT = UnsignedBits(1) << ((sizeof(UnsignedBits) * 8) - 1); return key ^ HIGH_BIT; }; template -ROCPRIM_DEVICE ROCPRIM_INLINE -auto TwiddleOut(UnsignedBits key) - -> typename std::enable_if::value, UnsignedBits>::type +[[deprecated("TwiddleOut is deprecated." + "Use radix_key_codec instead.")]] ROCPRIM_DEVICE ROCPRIM_INLINE auto + TwiddleOut(UnsignedBits key) -> + typename std::enable_if::value, UnsignedBits>::type { static const UnsignedBits HIGH_BIT = UnsignedBits(1) << ((sizeof(UnsignedBits) * 8) - 1); UnsignedBits mask = (key & HIGH_BIT) ? HIGH_BIT : UnsignedBits(-1); @@ -179,17 +188,19 @@ auto TwiddleOut(UnsignedBits key) } template -static ROCPRIM_DEVICE ROCPRIM_INLINE -auto TwiddleOut(UnsignedBits key) - -> typename std::enable_if::value, UnsignedBits>::type +[[deprecated("TwiddleOut is deprecated." 
+ "Use radix_key_codec instead.")]] static ROCPRIM_DEVICE ROCPRIM_INLINE auto + TwiddleOut(UnsignedBits key) -> + typename std::enable_if::value, UnsignedBits>::type { return key; }; template -static ROCPRIM_DEVICE ROCPRIM_INLINE -auto TwiddleOut(UnsignedBits key) - -> typename std::enable_if::value && is_signed::value, UnsignedBits>::type +[[deprecated("TwiddleOut is deprecated." + "Use radix_key_codec instead.")]] static ROCPRIM_DEVICE ROCPRIM_INLINE auto + TwiddleOut(UnsignedBits key) -> + typename std::enable_if::value && is_signed::value, UnsignedBits>::type { static const UnsignedBits HIGH_BIT = UnsignedBits(1) << ((sizeof(UnsignedBits) * 8) - 1); return key ^ HIGH_BIT; @@ -266,6 +277,76 @@ struct invoke_result_impl(), std::declval(), std::declval()...)); }; +template +struct is_tuple_of_references +{ + static_assert(sizeof(T) == 0, "is_tuple_of_references is only implemented for rocprim::tuple"); +}; + +template +struct is_tuple_of_references<::rocprim::tuple> +{ +private: + template + ROCPRIM_HOST_DEVICE static constexpr bool is_tuple_of_references_impl() + { + using tuple_t = ::rocprim::tuple; + using element_t = ::rocprim::tuple_element_t; + return std::is_reference::value && is_tuple_of_references_impl(); + } + + template<> + ROCPRIM_HOST_DEVICE static constexpr bool is_tuple_of_references_impl() + { + return true; + } + +public: + static constexpr bool value = is_tuple_of_references_impl<0>(); +}; + +template +struct float_bit_mask; + +template<> +struct float_bit_mask +{ + static constexpr uint32_t sign_bit = 0x80000000; + static constexpr uint32_t exponent = 0x7F800000; + static constexpr uint32_t mantissa = 0x007FFFFF; + using bit_type = uint32_t; +}; + +template<> +struct float_bit_mask +{ + static constexpr uint64_t sign_bit = 0x8000000000000000; + static constexpr uint64_t exponent = 0x7FF0000000000000; + static constexpr uint64_t mantissa = 0x000FFFFFFFFFFFFF; + using bit_type = uint64_t; +}; + +template<> +struct float_bit_mask +{ + static 
constexpr uint16_t sign_bit = 0x8000; + static constexpr uint16_t exponent = 0x7F80; + static constexpr uint16_t mantissa = 0x007F; + using bit_type = uint16_t; +}; + +template<> +struct float_bit_mask +{ + static constexpr uint16_t sign_bit = 0x8000; + static constexpr uint16_t exponent = 0x7C00; + static constexpr uint16_t mantissa = 0x03FF; + using bit_type = uint16_t; +}; + +template +using void_t = void; + } // end namespace detail /// \brief Behaves like ``std::invoke_result``, but allows the use of invoke_result diff --git a/rocprim/include/rocprim/types.hpp b/rocprim/include/rocprim/types.hpp index 162bd7cf7..519857f34 100644 --- a/rocprim/include/rocprim/types.hpp +++ b/rocprim/include/rocprim/types.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -142,6 +142,12 @@ struct empty_binary_op constexpr empty_type operator()(const empty_type&, const empty_type&) const { return empty_type{}; } }; +/// \brief A decomposer that must be passed to the radix sort algorithms when +/// sorting keys that are arithmetic types. +/// To sort custom types, a custom decomposer should be provided. +struct identity_decomposer +{}; + /// \brief Half-precision floating point type using half = ::__half; /// \brief bfloat16 floating point type diff --git a/rocprim/include/rocprim/warp/detail/warp_sort_shuffle.hpp b/rocprim/include/rocprim/warp/detail/warp_sort_shuffle.hpp index ca543a1b7..d4bd2cee7 100644 --- a/rocprim/include/rocprim/warp/detail/warp_sort_shuffle.hpp +++ b/rocprim/include/rocprim/warp/detail/warp_sort_shuffle.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. 
All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -26,167 +26,122 @@ #include "../../config.hpp" #include "../../detail/various.hpp" -#include "../../intrinsics.hpp" #include "../../functional.hpp" +#include "../../intrinsics.hpp" BEGIN_ROCPRIM_NAMESPACE namespace detail { -template< - class Key, - unsigned int WarpSize, - class Value -> +template class warp_sort_shuffle { private: - template - ROCPRIM_DEVICE ROCPRIM_INLINE - typename std::enable_if warp)>::type - swap(Key& k, V& v, int mask, bool dir, BinaryFunction compare_function) + template + ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if warp)>::type + swap(Key& k, V& v, bool dir, BinaryFunction compare_function) { - (void) k; - (void) v; - (void) mask; - (void) dir; - (void) compare_function; + (void)k; + (void)v; + (void)dir; + (void)compare_function; } - template - ROCPRIM_DEVICE ROCPRIM_INLINE - typename std::enable_if<(WarpSize > warp)>::type - swap(Key& k, V& v, int mask, bool dir, BinaryFunction compare_function) + template + ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if<(WarpSize > warp)>::type + swap(Key& k, V& v, bool dir, BinaryFunction compare_function) { - Key k1 = warp_shuffle_xor(k, mask, WarpSize); - //V v1 = warp_shuffle_xor(v, mask, WarpSize); + Key k1 = warp_swizzle_shuffle(k, xor_mask, WarpSize); bool swap = compare_function(dir ? k : k1, dir ? 
k1 : k); - if (swap) + if(swap) { k = k1; - v = warp_shuffle_xor(v, mask, WarpSize); + v = warp_swizzle_shuffle(v, xor_mask, WarpSize); } } - template< - int warp, - class V, - class BinaryFunction, - unsigned int ItemsPerThread - > - ROCPRIM_DEVICE ROCPRIM_INLINE - typename std::enable_if warp)>::type - swap(Key (&k)[ItemsPerThread], - V (&v)[ItemsPerThread], - int mask, - bool dir, - BinaryFunction compare_function) + template + ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if warp)>::type swap( + Key (&k)[ItemsPerThread], V (&v)[ItemsPerThread], bool dir, BinaryFunction compare_function) { - (void) k; - (void) v; - (void) mask; - (void) dir; - (void) compare_function; + (void)k; + (void)v; + (void)dir; + (void)compare_function; } - template< - int warp, - class V, - class BinaryFunction, - unsigned int ItemsPerThread - > - ROCPRIM_DEVICE ROCPRIM_INLINE - typename std::enable_if<(WarpSize > warp)>::type - swap(Key (&k)[ItemsPerThread], - V (&v)[ItemsPerThread], - int mask, - bool dir, - BinaryFunction compare_function) + template + ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if<(WarpSize > warp)>::type swap( + Key (&k)[ItemsPerThread], V (&v)[ItemsPerThread], bool dir, BinaryFunction compare_function) { Key k1[ItemsPerThread]; ROCPRIM_UNROLL - for (unsigned int item = 0; item < ItemsPerThread; item++) + for(unsigned int item = 0; item < ItemsPerThread; item++) { - k1[item]= warp_shuffle_xor(k[item], mask, WarpSize); - //V v1 = warp_shuffle_xor(v, mask, WarpSize); - bool swap = compare_function(dir ? k[item] : k1[item], dir ? k1[item] : k[item]); - if (swap) - { - k[item] = k1[item]; - v[item] = warp_shuffle_xor(v[item], mask, WarpSize); - } + k1[item] = warp_swizzle_shuffle(k[item], xor_mask, WarpSize); + bool swap = compare_function(dir ? k[item] : k1[item], dir ? 
k1[item] : k[item]); + if(swap) + { + k[item] = k1[item]; + v[item] = warp_swizzle_shuffle(v[item], xor_mask, WarpSize); + } } } - template - ROCPRIM_DEVICE ROCPRIM_INLINE - typename std::enable_if warp)>::type - swap(Key& k, int mask, bool dir, BinaryFunction compare_function) + template + ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if warp)>::type + swap(Key& k, bool dir, BinaryFunction compare_function) { - (void) k; - (void) mask; - (void) dir; - (void) compare_function; + (void)k; + (void)dir; + (void)compare_function; } - template - ROCPRIM_DEVICE ROCPRIM_INLINE - typename std::enable_if<(WarpSize > warp)>::type - swap(Key& k, int mask, bool dir, BinaryFunction compare_function) + template + ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if<(WarpSize > warp)>::type + swap(Key& k, bool dir, BinaryFunction compare_function) { - Key k1 = warp_shuffle_xor(k, mask, WarpSize); + Key k1 = warp_swizzle_shuffle(k, xor_mask, WarpSize); bool swap = compare_function(dir ? k : k1, dir ? 
k1 : k); - if (swap) + if(swap) { k = k1; } } - template< - int warp, - class BinaryFunction, - unsigned int ItemsPerThread - > - ROCPRIM_DEVICE ROCPRIM_INLINE - typename std::enable_if warp)>::type - swap(Key (&k)[ItemsPerThread], int mask, bool dir, BinaryFunction compare_function) + template + ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if warp)>::type + swap(Key (&k)[ItemsPerThread], bool dir, BinaryFunction compare_function) { - (void) k; - (void) mask; - (void) dir; - (void) compare_function; + (void)k; + (void)dir; + (void)compare_function; } - template< - int warp, - class BinaryFunction, - unsigned int ItemsPerThread - > - ROCPRIM_DEVICE ROCPRIM_INLINE - typename std::enable_if<(WarpSize > warp)>::type - swap(Key (&k)[ItemsPerThread], int mask, bool dir, BinaryFunction compare_function) + template + ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if<(WarpSize > warp)>::type + swap(Key (&k)[ItemsPerThread], bool dir, BinaryFunction compare_function) { Key k1[ItemsPerThread]; ROCPRIM_UNROLL - for (unsigned int item = 0; item < ItemsPerThread; item++) + for(unsigned int item = 0; item < ItemsPerThread; item++) { - k1[item]= warp_shuffle_xor(k[item], mask, WarpSize); + k1[item] = warp_swizzle_shuffle(k[item], xor_mask, WarpSize); bool swap = compare_function(dir ? k[item] : k1[item], dir ? 
k1[item] : k[item]); - if (swap) + if(swap) { k[item] = k1[item]; } } } - template - ROCPRIM_DEVICE ROCPRIM_INLINE - void thread_swap(Key (&k)[ItemsPerThread], - unsigned int i, - unsigned int j, - bool dir, - BinaryFunction compare_function) + template + ROCPRIM_DEVICE ROCPRIM_INLINE void thread_swap(Key (&k)[ItemsPerThread], + unsigned int i, + unsigned int j, + bool dir, + BinaryFunction compare_function) { if(compare_function(k[i], k[j]) == dir) { @@ -196,14 +151,13 @@ class warp_sort_shuffle } } - template - ROCPRIM_DEVICE ROCPRIM_INLINE - void thread_swap(Key (&k)[ItemsPerThread], - V (&v)[ItemsPerThread], - unsigned int i, - unsigned int j, - bool dir, - BinaryFunction compare_function) + template + ROCPRIM_DEVICE ROCPRIM_INLINE void thread_swap(Key (&k)[ItemsPerThread], + V (&v)[ItemsPerThread], + unsigned int i, + unsigned int j, + bool dir, + BinaryFunction compare_function) { if(compare_function(k[i], k[j]) == dir) { @@ -216,35 +170,30 @@ class warp_sort_shuffle } } - template - ROCPRIM_DEVICE ROCPRIM_INLINE - void thread_shuffle(unsigned int group_size, - unsigned int offset, - bool dir, - BinaryFunction compare_function, - KeyValue&... kv) + template + ROCPRIM_DEVICE ROCPRIM_INLINE void thread_shuffle(unsigned int group_size, + unsigned int offset, + bool dir, + BinaryFunction compare_function, + KeyValue&... 
kv) { ROCPRIM_UNROLL - for(unsigned int base = 0; base < ItemsPerThread; base += 2 * offset) { + for(unsigned int base = 0; base < ItemsPerThread; base += 2 * offset) + { // The local direction must change every group_size items // and is flipped if dir is true const bool local_dir = ((base & group_size) > 0) != dir; - ROCPRIM_UNROLL -// Workaround to prevent the compiler thinking this is a 'Parallel Loop' on clang 15 -// because it leads to invalid code generation with `T` = `char` and `ItemsPerthread` = 4 -#if defined(__clang_major__) && __clang_major__ >= 15 - #pragma clang loop vectorize(disable) -#endif - for(unsigned i = 0; i < offset; ++i) { + for(unsigned i = 0; i < offset; ++i) + { thread_swap(kv..., base + i, base + i + offset, local_dir, compare_function); } } } - template - ROCPRIM_DEVICE ROCPRIM_INLINE - void thread_sort(bool dir, BinaryFunction compare_function, KeyValue&... kv) + template + ROCPRIM_DEVICE ROCPRIM_INLINE void + thread_sort(bool dir, BinaryFunction compare_function, KeyValue&... kv) { ROCPRIM_UNROLL for(unsigned int k = 2; k <= ItemsPerThread; k *= 2) @@ -257,10 +206,9 @@ class warp_sort_shuffle } } - template - ROCPRIM_DEVICE ROCPRIM_INLINE - typename std::enable_if<(WarpSize > warp)>::type - thread_merge(bool dir, BinaryFunction compare_function, KeyValue&... kv) + template + ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if<(WarpSize > warp)>::type + thread_merge(bool dir, BinaryFunction compare_function, KeyValue&... kv) { ROCPRIM_UNROLL for(unsigned int j = ItemsPerThread / 2; j > 0; j /= 2) @@ -269,100 +217,92 @@ class warp_sort_shuffle } } - template - ROCPRIM_DEVICE ROCPRIM_INLINE - typename std::enable_if warp)>::type - thread_merge(bool /*dir*/, BinaryFunction /*compare_function*/, KeyValue&... /*kv*/) - { - } + template + ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if warp)>::type + thread_merge(bool /*dir*/, BinaryFunction /*compare_function*/, KeyValue&... 
/*kv*/) + {} template - ROCPRIM_DEVICE ROCPRIM_INLINE - void bitonic_sort(BinaryFunction compare_function, KeyValue&... kv) + ROCPRIM_DEVICE ROCPRIM_INLINE void bitonic_sort(BinaryFunction compare_function, + KeyValue&... kv) { - static_assert( - sizeof...(KeyValue) < 3, - "KeyValue parameter pack can 1 or 2 elements (key, or key and value)" - ); - - unsigned int id = detail::logical_lane_id(); - swap< 2>(kv..., 1, get_bit(id, 1) != get_bit(id, 0), compare_function); - - swap< 4>(kv..., 2, get_bit(id, 2) != get_bit(id, 1), compare_function); - swap< 4>(kv..., 1, get_bit(id, 2) != get_bit(id, 0), compare_function); - - swap< 8>(kv..., 4, get_bit(id, 3) != get_bit(id, 2), compare_function); - swap< 8>(kv..., 2, get_bit(id, 3) != get_bit(id, 1), compare_function); - swap< 8>(kv..., 1, get_bit(id, 3) != get_bit(id, 0), compare_function); - - swap<16>(kv..., 8, get_bit(id, 4) != get_bit(id, 3), compare_function); - swap<16>(kv..., 4, get_bit(id, 4) != get_bit(id, 2), compare_function); - swap<16>(kv..., 2, get_bit(id, 4) != get_bit(id, 1), compare_function); - swap<16>(kv..., 1, get_bit(id, 4) != get_bit(id, 0), compare_function); - - swap<32>(kv..., 16, get_bit(id, 5) != get_bit(id, 4), compare_function); - swap<32>(kv..., 8, get_bit(id, 5) != get_bit(id, 3), compare_function); - swap<32>(kv..., 4, get_bit(id, 5) != get_bit(id, 2), compare_function); - swap<32>(kv..., 2, get_bit(id, 5) != get_bit(id, 1), compare_function); - swap<32>(kv..., 1, get_bit(id, 5) != get_bit(id, 0), compare_function); - - swap<32>(kv..., 32, get_bit(id, 5) != 0, compare_function); - swap<16>(kv..., 16, get_bit(id, 4) != 0, compare_function); - swap< 8>(kv..., 8, get_bit(id, 3) != 0, compare_function); - swap< 4>(kv..., 4, get_bit(id, 2) != 0, compare_function); - swap< 2>(kv..., 2, get_bit(id, 1) != 0, compare_function); - swap< 0>(kv..., 1, get_bit(id, 0) != 0, compare_function); + static_assert(sizeof...(KeyValue) < 3, + "KeyValue parameter pack can 1 or 2 elements (key, or key and 
value)"); + + const unsigned int id = detail::logical_lane_id(); + + swap<2, 1>(kv..., get_bit(id, 1) != get_bit(id, 0), compare_function); + + swap<4, 2>(kv..., get_bit(id, 2) != get_bit(id, 1), compare_function); + swap<4, 1>(kv..., get_bit(id, 2) != get_bit(id, 0), compare_function); + + swap<8, 4>(kv..., get_bit(id, 3) != get_bit(id, 2), compare_function); + swap<8, 2>(kv..., get_bit(id, 3) != get_bit(id, 1), compare_function); + swap<8, 1>(kv..., get_bit(id, 3) != get_bit(id, 0), compare_function); + + swap<16, 8>(kv..., get_bit(id, 4) != get_bit(id, 3), compare_function); + swap<16, 4>(kv..., get_bit(id, 4) != get_bit(id, 2), compare_function); + swap<16, 2>(kv..., get_bit(id, 4) != get_bit(id, 1), compare_function); + swap<16, 1>(kv..., get_bit(id, 4) != get_bit(id, 0), compare_function); + + swap<32, 16>(kv..., get_bit(id, 5) != get_bit(id, 4), compare_function); + swap<32, 8>(kv..., get_bit(id, 5) != get_bit(id, 3), compare_function); + swap<32, 4>(kv..., get_bit(id, 5) != get_bit(id, 2), compare_function); + swap<32, 2>(kv..., get_bit(id, 5) != get_bit(id, 1), compare_function); + swap<32, 1>(kv..., get_bit(id, 5) != get_bit(id, 0), compare_function); + + swap<32, 32>(kv..., get_bit(id, 5) != 0, compare_function); + swap<16, 16>(kv..., get_bit(id, 4) != 0, compare_function); + swap<8, 8>(kv..., get_bit(id, 3) != 0, compare_function); + swap<4, 4>(kv..., get_bit(id, 2) != 0, compare_function); + swap<2, 2>(kv..., get_bit(id, 1) != 0, compare_function); + swap<0, 1>(kv..., get_bit(id, 0) != 0, compare_function); } - template< - unsigned int ItemsPerThread, - class BinaryFunction, - class... KeyValue - > - ROCPRIM_DEVICE ROCPRIM_INLINE - void bitonic_sort(BinaryFunction compare_function, KeyValue&... kv) + template + ROCPRIM_DEVICE ROCPRIM_INLINE void bitonic_sort(BinaryFunction compare_function, + KeyValue&... 
kv) { - static_assert( - sizeof...(KeyValue) < 3, - "KeyValue parameter pack can 1 or 2 elements (key, or key and value)" - ); + static_assert(sizeof...(KeyValue) < 3, + "KeyValue parameter pack can 1 or 2 elements (key, or key and value)"); static_assert(detail::is_power_of_two(ItemsPerThread), "ItemsPerThread must be power of 2"); - unsigned int id = detail::logical_lane_id(); + const unsigned int id = detail::logical_lane_id(); + thread_sort(get_bit(id, 0) != 0, compare_function, kv...); - swap< 2>(kv..., 1, get_bit(id, 1) != get_bit(id, 0), compare_function); + swap<2, 1>(kv..., get_bit(id, 1) != get_bit(id, 0), compare_function); thread_merge<2, ItemsPerThread>(get_bit(id, 1) != 0, compare_function, kv...); - swap< 4>(kv..., 2, get_bit(id, 2) != get_bit(id, 1), compare_function); - swap< 4>(kv..., 1, get_bit(id, 2) != get_bit(id, 0), compare_function); + swap<4, 2>(kv..., get_bit(id, 2) != get_bit(id, 1), compare_function); + swap<4, 1>(kv..., get_bit(id, 2) != get_bit(id, 0), compare_function); thread_merge<4, ItemsPerThread>(get_bit(id, 2) != 0, compare_function, kv...); - swap< 8>(kv..., 4, get_bit(id, 3) != get_bit(id, 2), compare_function); - swap< 8>(kv..., 2, get_bit(id, 3) != get_bit(id, 1), compare_function); - swap< 8>(kv..., 1, get_bit(id, 3) != get_bit(id, 0), compare_function); + swap<8, 4>(kv..., get_bit(id, 3) != get_bit(id, 2), compare_function); + swap<8, 2>(kv..., get_bit(id, 3) != get_bit(id, 1), compare_function); + swap<8, 1>(kv..., get_bit(id, 3) != get_bit(id, 0), compare_function); thread_merge<8, ItemsPerThread>(get_bit(id, 3) != 0, compare_function, kv...); - swap<16>(kv..., 8, get_bit(id, 4) != get_bit(id, 3), compare_function); - swap<16>(kv..., 4, get_bit(id, 4) != get_bit(id, 2), compare_function); - swap<16>(kv..., 2, get_bit(id, 4) != get_bit(id, 1), compare_function); - swap<16>(kv..., 1, get_bit(id, 4) != get_bit(id, 0), compare_function); + swap<16, 8>(kv..., get_bit(id, 4) != get_bit(id, 3), compare_function); + swap<16, 
4>(kv..., get_bit(id, 4) != get_bit(id, 2), compare_function); + swap<16, 2>(kv..., get_bit(id, 4) != get_bit(id, 1), compare_function); + swap<16, 1>(kv..., get_bit(id, 4) != get_bit(id, 0), compare_function); thread_merge<16, ItemsPerThread>(get_bit(id, 4) != 0, compare_function, kv...); - swap<32>(kv..., 16, get_bit(id, 5) != get_bit(id, 4), compare_function); - swap<32>(kv..., 8, get_bit(id, 5) != get_bit(id, 3), compare_function); - swap<32>(kv..., 4, get_bit(id, 5) != get_bit(id, 2), compare_function); - swap<32>(kv..., 2, get_bit(id, 5) != get_bit(id, 1), compare_function); - swap<32>(kv..., 1, get_bit(id, 5) != get_bit(id, 0), compare_function); + swap<32, 16>(kv..., get_bit(id, 5) != get_bit(id, 4), compare_function); + swap<32, 8>(kv..., get_bit(id, 5) != get_bit(id, 3), compare_function); + swap<32, 4>(kv..., get_bit(id, 5) != get_bit(id, 2), compare_function); + swap<32, 2>(kv..., get_bit(id, 5) != get_bit(id, 1), compare_function); + swap<32, 1>(kv..., get_bit(id, 5) != get_bit(id, 0), compare_function); thread_merge<32, ItemsPerThread>(get_bit(id, 5) != 0, compare_function, kv...); - swap<32>(kv..., 32, get_bit(id, 5) != 0, compare_function); - swap<16>(kv..., 16, get_bit(id, 4) != 0, compare_function); - swap< 8>(kv..., 8, get_bit(id, 3) != 0, compare_function); - swap< 4>(kv..., 4, get_bit(id, 2) != 0, compare_function); - swap< 2>(kv..., 2, get_bit(id, 1) != 0, compare_function); - swap< 0>(kv..., 1, get_bit(id, 0) != 0, compare_function); + swap<32, 32>(kv..., get_bit(id, 5) != 0, compare_function); + swap<16, 16>(kv..., get_bit(id, 4) != 0, compare_function); + swap<8, 8>(kv..., get_bit(id, 3) != 0, compare_function); + swap<4, 4>(kv..., get_bit(id, 2) != 0, compare_function); + swap<2, 2>(kv..., get_bit(id, 1) != 0, compare_function); + swap<0, 1>(kv..., get_bit(id, 0) != 0, compare_function); thread_merge<1, ItemsPerThread>(false, compare_function, kv...); } @@ -372,61 +312,47 @@ class warp_sort_shuffle using storage_type = 
::rocprim::detail::empty_storage_type; template - ROCPRIM_DEVICE ROCPRIM_INLINE - void sort(Key& thread_value, BinaryFunction compare_function) + ROCPRIM_DEVICE ROCPRIM_INLINE void sort(Key& thread_value, BinaryFunction compare_function) { // sort by value only bitonic_sort(compare_function, thread_value); } template - ROCPRIM_DEVICE ROCPRIM_INLINE - void sort(Key& thread_value, storage_type& storage, - BinaryFunction compare_function) + ROCPRIM_DEVICE ROCPRIM_INLINE void + sort(Key& thread_value, storage_type& storage, BinaryFunction compare_function) { - (void) storage; + (void)storage; sort(thread_value, compare_function); } - template< - unsigned int ItemsPerThread, - class BinaryFunction - > - ROCPRIM_DEVICE ROCPRIM_INLINE - void sort(Key (&thread_values)[ItemsPerThread], - BinaryFunction compare_function) + template + ROCPRIM_DEVICE ROCPRIM_INLINE void sort(Key (&thread_values)[ItemsPerThread], + BinaryFunction compare_function) { // sort by value only bitonic_sort(compare_function, thread_values); } - template< - unsigned int ItemsPerThread, - class BinaryFunction - > - ROCPRIM_DEVICE ROCPRIM_INLINE - void sort(Key (&thread_values)[ItemsPerThread], - storage_type& storage, - BinaryFunction compare_function) + template + ROCPRIM_DEVICE ROCPRIM_INLINE void sort(Key (&thread_values)[ItemsPerThread], + storage_type& storage, + BinaryFunction compare_function) { - (void) storage; + (void)storage; sort(thread_values, compare_function); } template - ROCPRIM_DEVICE ROCPRIM_INLINE - typename std::enable_if<(sizeof(V) <= sizeof(int))>::type - sort(Key& thread_key, Value& thread_value, - BinaryFunction compare_function) + ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if<(sizeof(V) <= sizeof(int))>::type + sort(Key& thread_key, Value& thread_value, BinaryFunction compare_function) { bitonic_sort(compare_function, thread_key, thread_value); } template - ROCPRIM_DEVICE ROCPRIM_INLINE - typename std::enable_if::type - sort(Key& thread_key, Value& thread_value, - 
BinaryFunction compare_function) + ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if::type + sort(Key& thread_key, Value& thread_value, BinaryFunction compare_function) { // Instead of passing large values between lanes we pass indices and gather values after sorting. unsigned int v = detail::logical_lane_id(); @@ -435,43 +361,34 @@ class warp_sort_shuffle } template - ROCPRIM_DEVICE ROCPRIM_INLINE - void sort(Key& thread_key, Value& thread_value, - storage_type& storage, BinaryFunction compare_function) + ROCPRIM_DEVICE ROCPRIM_INLINE void sort(Key& thread_key, + Value& thread_value, + storage_type& storage, + BinaryFunction compare_function) { - (void) storage; + (void)storage; sort(compare_function, thread_key, thread_value); } - template< - unsigned int ItemsPerThread, - class BinaryFunction, - class V = Value - > - ROCPRIM_DEVICE ROCPRIM_INLINE - typename std::enable_if<(sizeof(V) <= sizeof(int))>::type - sort(Key (&thread_keys)[ItemsPerThread], - Value (&thread_values)[ItemsPerThread], - BinaryFunction compare_function) + template + ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if<(sizeof(V) <= sizeof(int))>::type + sort(Key (&thread_keys)[ItemsPerThread], + Value (&thread_values)[ItemsPerThread], + BinaryFunction compare_function) { bitonic_sort(compare_function, thread_keys, thread_values); } - template< - unsigned int ItemsPerThread, - class BinaryFunction, - class V = Value - > - ROCPRIM_DEVICE ROCPRIM_INLINE - typename std::enable_if::type - sort(Key (&thread_keys)[ItemsPerThread], - Value (&thread_values)[ItemsPerThread], - BinaryFunction compare_function) + template + ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if::type + sort(Key (&thread_keys)[ItemsPerThread], + Value (&thread_values)[ItemsPerThread], + BinaryFunction compare_function) { // Instead of passing large values between lanes we pass indices and gather values after sorting. 
unsigned int v[ItemsPerThread]; ROCPRIM_UNROLL - for (unsigned int item = 0; item < ItemsPerThread; item++) + for(unsigned int item = 0; item < ItemsPerThread; item++) { v[item] = ItemsPerThread * detail::logical_lane_id() + item; } @@ -480,14 +397,17 @@ class warp_sort_shuffle V copy[ItemsPerThread]; ROCPRIM_UNROLL - for(unsigned item = 0; item < ItemsPerThread; ++item) { + for(unsigned item = 0; item < ItemsPerThread; ++item) + { copy[item] = thread_values[item]; } ROCPRIM_UNROLL - for(unsigned int dst_item = 0; dst_item < ItemsPerThread; ++dst_item) { + for(unsigned int dst_item = 0; dst_item < ItemsPerThread; ++dst_item) + { ROCPRIM_UNROLL - for(unsigned src_item = 0; src_item < ItemsPerThread; ++src_item) { + for(unsigned src_item = 0; src_item < ItemsPerThread; ++src_item) + { V temp = warp_shuffle(copy[src_item], v[dst_item] / ItemsPerThread, WarpSize); if(v[dst_item] % ItemsPerThread == src_item) thread_values[dst_item] = temp; @@ -495,16 +415,13 @@ class warp_sort_shuffle } } - template< - unsigned int ItemsPerThread, - class BinaryFunction - > - ROCPRIM_DEVICE ROCPRIM_INLINE - void sort(Key (&thread_keys)[ItemsPerThread], - Value (&thread_values)[ItemsPerThread], - storage_type& storage, BinaryFunction compare_function) + template + ROCPRIM_DEVICE ROCPRIM_INLINE void sort(Key (&thread_keys)[ItemsPerThread], + Value (&thread_values)[ItemsPerThread], + storage_type& storage, + BinaryFunction compare_function) { - (void) storage; + (void)storage; sort(thread_keys, thread_values, compare_function); } }; diff --git a/rocprim/include/rocprim/warp/warp_exchange.hpp b/rocprim/include/rocprim/warp/warp_exchange.hpp index 6e54f9d3b..0ac6ac6a1 100644 --- a/rocprim/include/rocprim/warp/warp_exchange.hpp +++ b/rocprim/include/rocprim/warp/warp_exchange.hpp @@ -21,13 +21,17 @@ #ifndef ROCPRIM_WARP_WARP_EXCHANGE_HPP_ #define ROCPRIM_WARP_WARP_EXCHANGE_HPP_ +#include + #include "../config.hpp" #include "../detail/various.hpp" +#include "../functional.hpp" #include 
"../intrinsics.hpp" #include "../intrinsics/warp_shuffle.hpp" -#include "../functional.hpp" #include "../types.hpp" +#include +#include /// \addtogroup warpmodule /// @{ @@ -90,6 +94,148 @@ class warp_exchange T buffer[WarpSize * ItemsPerThread]; }; + template + ROCPRIM_DEVICE ROCPRIM_INLINE void Foreach(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + const int xor_bit_set) + { + // To prevent double work for IdX and IdX + NumEntries + if(NumEntries != 0 && (IdX / NumEntries) % 2 == 0) + { + const T send_val = (xor_bit_set ? input[IdX] : input[IdX + NumEntries]); + const T recv_val + = ::rocprim::detail::warp_swizzle_shuffle(send_val, NumEntries, WarpSize); + (xor_bit_set ? output[IdX] : output[IdX + NumEntries]) = recv_val; + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE void Foreach(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + const std::integer_sequence, + const bool xor_bit_set) + { + // To create a static inner loop that executes the code with + // the values [0, 1, ..., ItemsPerThread-1, ItemsPerThread] as IdX + int ignored[] = {((Foreach(input, output, xor_bit_set)), 0)...}; + (void)ignored; + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE void TransposeImpl(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + const unsigned int lane_id, + const std::integer_sequence) + { + // To create a static outer loop that executes the code with + // the values [ItemsPerThread/2, ItemsPerThread/4, ..., 1, 0] as NumEntries + int ignored[] + = {(Foreach<(1 << (MaxIter - Iter))>(input, + output, + std::make_integer_sequence{}, + lane_id & (1 << (MaxIter - Iter))), + 0)...}; + (void)ignored; + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE void Transpose(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + const unsigned int lane_id) + { + constexpr unsigned int n_iter = rocprim::Log2::VALUE; + TransposeImpl(input, + output, + lane_id, + std::make_integer_sequence{}); + } + + template + ROCPRIM_DEVICE 
ROCPRIM_INLINE void + blocked_striped_shuffle_efficient_impl(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread]) + { + static constexpr bool IS_ARCH_WARP = WarpSize == ::rocprim::device_warp_size(); + const unsigned int flat_lane_id = ::rocprim::detail::logical_lane_id(); + const unsigned int lane_id = IS_ARCH_WARP ? flat_lane_id : (flat_lane_id % WarpSize); + T temp[ItemsPerThread]; + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + temp[i] = input[i]; + } + Transpose(temp, temp, lane_id); + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + output[i] = temp[i]; + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE void + blocked_to_striped_shuffle_impl(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread]) + { + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + U work_array[ItemsPerThread]; + + ROCPRIM_UNROLL + for(unsigned int dst_idx = 0; dst_idx < ItemsPerThread; dst_idx++) + { + ROCPRIM_UNROLL + for(unsigned int src_idx = 0; src_idx < ItemsPerThread; src_idx++) + { + const auto value = ::rocprim::warp_shuffle( + input[src_idx], + flat_id / ItemsPerThread + dst_idx * (WarpSize / ItemsPerThread), + WarpSize); + if(src_idx == flat_id % ItemsPerThread) + { + work_array[dst_idx] = value; + } + } + } + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + output[i] = work_array[i]; + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE void + striped_to_blocked_shuffle_impl(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread]) + { + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + U work_array[ItemsPerThread]; + + ROCPRIM_UNROLL + for(unsigned int dst_idx = 0; dst_idx < ItemsPerThread; dst_idx++) + { + ROCPRIM_UNROLL + for(unsigned int src_idx = 0; src_idx < ItemsPerThread; src_idx++) + { + const auto value + = ::rocprim::warp_shuffle(input[src_idx], + (ItemsPerThread * flat_id + dst_idx) % WarpSize, + WarpSize); + if(flat_id / 
(WarpSize / ItemsPerThread) == src_idx) + { + work_array[dst_idx] = value; + } + } + } + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + output[i] = work_array[i]; + } + } + public: /// \brief Struct used to allocate a temporary memory that is required for thread @@ -165,6 +311,7 @@ class warp_exchange /// \brief Transposes a blocked arrangement of items to a striped arrangement /// across the warp, using warp shuffle operations. + /// Uses an optimized implementation for when WarpSize is equal to ItemsPerThread. /// Caution: this API is experimental. Performance might not be consistent. /// ItemsPerThread must be a divisor of WarpSize. /// @@ -193,36 +340,19 @@ class warp_exchange /// } /// \endcode template - ROCPRIM_DEVICE ROCPRIM_INLINE - void blocked_to_striped_shuffle(const T (&input)[ItemsPerThread], - U (&output)[ItemsPerThread]) + ROCPRIM_DEVICE ROCPRIM_INLINE void blocked_to_striped_shuffle(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread]) { - static_assert(WarpSize % ItemsPerThread == 0, - "ItemsPerThread must be a divisor of WarpSize to use blocked_to_striped_shuffle"); - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); - U work_array[ItemsPerThread]; - - ROCPRIM_UNROLL - for(unsigned int dst_idx = 0; dst_idx < ItemsPerThread; dst_idx++) + static_assert( + WarpSize % ItemsPerThread == 0, + "ItemsPerThread must be a divisor of WarpSize to use blocked_to_striped_shuffle"); + if(WarpSize == ItemsPerThread) { - ROCPRIM_UNROLL - for(unsigned int src_idx = 0; src_idx < ItemsPerThread; src_idx++) - { - const auto value = ::rocprim::warp_shuffle( - input[src_idx], - flat_id / ItemsPerThread + dst_idx * (WarpSize / ItemsPerThread), - WarpSize); - if(src_idx == flat_id % ItemsPerThread) - { - work_array[dst_idx] = value; - } - } + blocked_striped_shuffle_efficient_impl(input, output); } - - ROCPRIM_UNROLL - for(unsigned int i = 0; i < ItemsPerThread; i++) + else { - output[i] = work_array[i]; + 
blocked_to_striped_shuffle_impl(input, output); } } @@ -285,6 +415,7 @@ class warp_exchange /// \brief Transposes a striped arrangement of items to a blocked arrangement /// across the warp, using warp shuffle operations. + /// Uses an optimized implementation for when WarpSize is equal to ItemsPerThread. /// Caution: this API is experimental. Performance might not be consistent. /// ItemsPerThread must be a divisor of WarpSize. /// @@ -313,36 +444,20 @@ class warp_exchange /// } /// \endcode template - ROCPRIM_DEVICE ROCPRIM_INLINE - void striped_to_blocked_shuffle(const T (&input)[ItemsPerThread], - U (&output)[ItemsPerThread]) + ROCPRIM_DEVICE ROCPRIM_INLINE void striped_to_blocked_shuffle(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread]) { - static_assert(WarpSize % ItemsPerThread == 0, - "ItemsPerThread must be a divisor of WarpSize to use striped_to_blocked_shuffle"); - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); - U work_array[ItemsPerThread]; + static_assert( + WarpSize % ItemsPerThread == 0, + "ItemsPerThread must be a divisor of WarpSize to use striped_to_blocked_shuffle"); - ROCPRIM_UNROLL - for(unsigned int dst_idx = 0; dst_idx < ItemsPerThread; dst_idx++) + if(WarpSize == ItemsPerThread) { - ROCPRIM_UNROLL - for(unsigned int src_idx = 0; src_idx < ItemsPerThread; src_idx++) - { - const auto value - = ::rocprim::warp_shuffle(input[src_idx], - (ItemsPerThread * flat_id + dst_idx) % WarpSize, - WarpSize); - if(flat_id / (WarpSize / ItemsPerThread) == src_idx) - { - work_array[dst_idx] = value; - } - } + blocked_striped_shuffle_efficient_impl(input, output); } - - ROCPRIM_UNROLL - for(unsigned int i = 0; i < ItemsPerThread; i++) + else { - output[i] = work_array[i]; + striped_to_blocked_shuffle_impl(input, output); } } diff --git a/test/common_test_header.hpp b/test/common_test_header.hpp index 62f6afcfd..6e728cd80 100755 --- a/test/common_test_header.hpp +++ b/test/common_test_header.hpp @@ -1,6 +1,6 @@ // MIT License 
// -// Copyright (c) 2020-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -25,6 +25,8 @@ #include #include +#include +#include #include #include #include @@ -89,54 +91,65 @@ namespace test_common_utils { -inline int obtain_device_from_ctest() +inline char* __get_env(const char* name) { + char* env; #ifdef _MSC_VER -#pragma warning( push ) -#pragma warning( disable : 4996 ) // This function or variable may be unsafe. Consider using _dupenv_s instead. + size_t len; + errno_t err = _dupenv_s(&env, &len, name); + if(err) + { + return nullptr; + } +#else + env = std::getenv(name); #endif - static const std::string rg0 = "CTEST_RESOURCE_GROUP_0"; - if (std::getenv(rg0.c_str()) != nullptr) + return env; +} + +inline void clean_env(char* name) +{ +#ifdef _MSC_VER + if(name != nullptr) + { + free(name); + } +#endif + (void)name; +} + +inline int obtain_device_from_ctest() +{ + static const std::string rg0 = "CTEST_RESOURCE_GROUP_0"; + char* env = __get_env(rg0.c_str()); + int device = 0; + if(env != nullptr) { - std::string amdgpu_target = std::getenv(rg0.c_str()); + std::string amdgpu_target(env); std::transform( amdgpu_target.cbegin(), amdgpu_target.cend(), amdgpu_target.begin(), // Feeding std::toupper plainly results in implicitly truncating conversions between int and char triggering warnings. 
- [](unsigned char c){ return static_cast(std::toupper(c)); } - ); - std::string reqs = std::getenv((rg0 + "_" + amdgpu_target).c_str()); - return std::atoi(reqs.substr(reqs.find(':') + 1, reqs.find(',') - (reqs.find(':') + 1)).c_str()); + [](unsigned char c) { return static_cast(std::toupper(c)); }); + char* env_reqs = __get_env((rg0 + "_" + amdgpu_target).c_str()); + std::string reqs(env_reqs); + device = std::atoi( + reqs.substr(reqs.find(':') + 1, reqs.find(',') - (reqs.find(':') + 1)).c_str()); + clean_env(env_reqs); } - else - return 0; -#ifdef _MSC_VER -#pragma warning( pop ) -#endif + clean_env(env); + return device; } -#ifdef _MSC_VER - #pragma warning(push) - #pragma warning( \ - disable : 4996) // This function or variable may be unsafe. Consider using _dupenv_s instead. -#endif inline bool use_hmm() { - if(getenv("ROCPRIM_USE_HMM") == nullptr) - { - return false; - } - if(strcmp(getenv("ROCPRIM_USE_HMM"), "1") == 0) - { - return true; - } - return false; + char* env = __get_env("ROCPRIM_USE_HMM"); + const bool hmm = (env != nullptr) && (strcmp(env, "1") == 0); + clean_env(env); + return hmm; } -#ifdef _MSC_VER - #pragma warning(pop) -#endif // Helper for HMM allocations: HMM is requested through ROCPRIM_USE_HMM=1 environment variable template diff --git a/test/hipgraph/test_hipgraph_algs.cpp b/test/hipgraph/test_hipgraph_algs.cpp index 816cff1f4..ae98add4f 100644 --- a/test/hipgraph/test_hipgraph_algs.cpp +++ b/test/hipgraph/test_hipgraph_algs.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -24,10 +24,12 @@ #include "../common_test_header.hpp" // required rocprim headers -#include #include "../rocprim/test_seed.hpp" #include "../rocprim/test_utils.hpp" +#include +#include + // required STL headers #include @@ -107,9 +109,6 @@ TEST(TestHipGraphAlgs, SortAndSearch) hipStream_t stream = 0; HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); - // Begin graph capture - hipGraph_t graph = test_utils::createGraphHelper(stream); - // Get temporary storage size required for merge_sort. // Note: doing this inside a graph doesn't gain us any benefit, // since these calls run entirely on the host - however, it is @@ -139,11 +138,6 @@ TEST(TestHipGraphAlgs, SortAndSearch) debug_synchronous )); - // End graph capture (since we can't malloc the temp storage inside the graph) - // and execute the graph (to get the temp storage size) - hipGraphExec_t graph_instance; - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - // Allocate the temp storage // Note: a single store will be used for both the sort and search algorithms size_t temp_storage_bytes = std::max(sort_temp_storage_bytes, search_temp_storage_bytes); @@ -154,8 +148,8 @@ TEST(TestHipGraphAlgs, SortAndSearch) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_bytes)); HIP_CHECK(hipDeviceSynchronize()); - // Re-start graph capture - test_utils::resetGraphHelper(graph, graph_instance, stream); + // Begin graph capture + hipGraph_t graph = test_utils::createGraphHelper(stream); // Launch merge_sort HIP_CHECK( @@ -186,6 +180,7 @@ TEST(TestHipGraphAlgs, SortAndSearch) ); // End graph capture, but do not execute the graph yet. 
+ hipGraphExec_t graph_instance; graph_instance = test_utils::endCaptureGraphHelper(graph, stream); std::vector sort_input; diff --git a/test/hipgraph/test_hipgraph_basic.cpp b/test/hipgraph/test_hipgraph_basic.cpp index 12d2c8739..cc473d680 100644 --- a/test/hipgraph/test_hipgraph_basic.cpp +++ b/test/hipgraph/test_hipgraph_basic.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -24,7 +24,6 @@ #include "../common_test_header.hpp" // required rocprim headers -#include #include "../rocprim/test_seed.hpp" #include "../rocprim/test_utils.hpp" diff --git a/test/rocprim/CMakeLists.txt b/test/rocprim/CMakeLists.txt index 6ab37e809..20ad34ae6 100644 --- a/test/rocprim/CMakeLists.txt +++ b/test/rocprim/CMakeLists.txt @@ -227,14 +227,14 @@ add_rocprim_test("rocprim.basic_test" "test_basic.cpp;detail/get_rocprim_version add_rocprim_test("rocprim.arg_index_iterator" test_arg_index_iterator.cpp) add_rocprim_test("rocprim.temporary_storage_partitioning" test_temporary_storage_partitioning.cpp) add_rocprim_test_parallel("rocprim.block_adjacent_difference" test_block_adjacent_difference.cpp.in) -add_rocprim_test("rocprim.block_discontinuity" test_block_discontinuity.cpp) +add_rocprim_test_parallel("rocprim.block_discontinuity" test_block_discontinuity.cpp.in) add_rocprim_test("rocprim.block_exchange" test_block_exchange.cpp) add_rocprim_test("rocprim.block_histogram" test_block_histogram.cpp) add_rocprim_test("rocprim.block_load_store" test_block_load_store.cpp) add_rocprim_test("rocprim.block_sort_merge" test_block_sort_merge.cpp) add_rocprim_test("rocprim.block_sort_merge_stable" test_block_sort_merge_stable.cpp) add_rocprim_test_parallel("rocprim.block_radix_rank" 
test_block_radix_rank.cpp.in) -add_rocprim_test("rocprim.block_radix_sort" test_block_radix_sort.cpp) +add_rocprim_test_parallel("rocprim.block_radix_sort" test_block_radix_sort.cpp.in) add_rocprim_test("rocprim.block_reduce" test_block_reduce.cpp) add_rocprim_test("rocprim.block_run_length_decode" test_block_run_length_decode.cpp) add_rocprim_test_parallel("rocprim.block_scan" test_block_scan.cpp.in) @@ -261,6 +261,8 @@ add_rocprim_test("rocprim.device_segmented_scan" test_device_segmented_scan.cpp) add_rocprim_test("rocprim.device_select" test_device_select.cpp) add_rocprim_test("rocprim.device_transform" test_device_transform.cpp) add_rocprim_test("rocprim.discard_iterator" test_discard_iterator.cpp) +add_rocprim_test("rocprim.radix_key_codec" test_radix_key_codec.cpp) +add_rocprim_test("rocprim.predicate_iterator" test_predicate_iterator.cpp) add_rocprim_test("rocprim.reverse_iterator" test_reverse_iterator.cpp) if(NOT USE_HIP_CPU) add_rocprim_test("rocprim.texture_cache_iterator" test_texture_cache_iterator.cpp) diff --git a/test/rocprim/test_block_adjacent_difference.cpp.in b/test/rocprim/test_block_adjacent_difference.cpp.in index dd32cfd39..f0d1eded7 100644 --- a/test/rocprim/test_block_adjacent_difference.cpp.in +++ b/test/rocprim/test_block_adjacent_difference.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -53,6 +53,13 @@ struct Floating; #define warp_params BlockDiscParamsFloating #define name_suffix Floating +#elif ROCPRIM_TEST_SLICE == 2 + +struct FloatingHalf; +#define suite_name RocprimBlockAdjacentDifference +#define warp_params BlockDiscParamsFloatingHalf +#define name_suffix FloatingHalf + #endif #include "test_block_adjacent_difference.hpp" diff --git a/test/rocprim/test_block_adjacent_difference.kernels.hpp b/test/rocprim/test_block_adjacent_difference.kernels.hpp index 974751135..db7e52a65 100644 --- a/test/rocprim/test_block_adjacent_difference.kernels.hpp +++ b/test/rocprim/test_block_adjacent_difference.kernels.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -66,7 +66,8 @@ void flag_heads_kernel(Type* device_input, long long* device_heads) Type input[ItemsPerThread]; rocprim::block_load_direct_blocked(lid, device_input + block_offset, input); - rocprim::block_adjacent_difference bdiscontinuity; + rocprim::block_adjacent_difference adjacent_difference; + __shared__ typename decltype(adjacent_difference)::storage_type storage; FlagType head_flags[ItemsPerThread]; @@ -75,11 +76,15 @@ void flag_heads_kernel(Type* device_input, long long* device_heads) if(blockIdx.x % 2 == 1) { const Type tile_predecessor_item = device_input[block_offset - 1]; - bdiscontinuity.flag_heads(head_flags, tile_predecessor_item, input, FlagOpType()); + adjacent_difference.flag_heads(head_flags, + tile_predecessor_item, + input, + FlagOpType(), + storage); } else { - bdiscontinuity.flag_heads(head_flags, input, 
FlagOpType()); + adjacent_difference.flag_heads(head_flags, input, FlagOpType(), storage); } ROCPRIM_CLANG_SUPPRESS_WARNING_POP @@ -255,7 +260,8 @@ void flag_tails_kernel(Type* device_input, long long* device_tails) Type input[ItemsPerThread]; rocprim::block_load_direct_blocked(lid, device_input + block_offset, input); - rocprim::block_adjacent_difference bdiscontinuity; + rocprim::block_adjacent_difference adjacent_difference; + __shared__ typename decltype(adjacent_difference)::storage_type storage; FlagType tail_flags[ItemsPerThread]; @@ -264,11 +270,15 @@ void flag_tails_kernel(Type* device_input, long long* device_tails) if(blockIdx.x % 2 == 0) { const Type tile_successor_item = device_input[block_offset + items_per_block]; - bdiscontinuity.flag_tails(tail_flags, tile_successor_item, input, FlagOpType()); + adjacent_difference.flag_tails(tail_flags, + tile_successor_item, + input, + FlagOpType(), + storage); } else { - bdiscontinuity.flag_tails(tail_flags, input, FlagOpType()); + adjacent_difference.flag_tails(tail_flags, input, FlagOpType(), storage); } ROCPRIM_CLANG_SUPPRESS_WARNING_POP @@ -293,7 +303,8 @@ void flag_heads_and_tails_kernel(Type* device_input, long long* device_heads, lo Type input[ItemsPerThread]; rocprim::block_load_direct_blocked(lid, device_input + block_offset, input); - rocprim::block_adjacent_difference bdiscontinuity; + rocprim::block_adjacent_difference adjacent_difference; + __shared__ typename decltype(adjacent_difference)::storage_type storage; FlagType head_flags[ItemsPerThread]; FlagType tail_flags[ItemsPerThread]; @@ -303,22 +314,42 @@ void flag_heads_and_tails_kernel(Type* device_input, long long* device_heads, lo if(blockIdx.x % 4 == 0) { const Type tile_successor_item = device_input[block_offset + items_per_block]; - bdiscontinuity.flag_heads_and_tails(head_flags, tail_flags, tile_successor_item, input, FlagOpType()); + adjacent_difference.flag_heads_and_tails(head_flags, + tail_flags, + tile_successor_item, + input, + 
FlagOpType(), + storage); } else if(blockIdx.x % 4 == 1) { const Type tile_predecessor_item = device_input[block_offset - 1]; - const Type tile_successor_item = device_input[block_offset + items_per_block]; - bdiscontinuity.flag_heads_and_tails(head_flags, tile_predecessor_item, tail_flags, tile_successor_item, input, FlagOpType()); + const Type tile_successor_item = device_input[block_offset + items_per_block]; + adjacent_difference.flag_heads_and_tails(head_flags, + tile_predecessor_item, + tail_flags, + tile_successor_item, + input, + FlagOpType(), + storage); } else if(blockIdx.x % 4 == 2) { const Type tile_predecessor_item = device_input[block_offset - 1]; - bdiscontinuity.flag_heads_and_tails(head_flags, tile_predecessor_item, tail_flags, input, FlagOpType()); + adjacent_difference.flag_heads_and_tails(head_flags, + tile_predecessor_item, + tail_flags, + input, + FlagOpType(), + storage); } else if(blockIdx.x % 4 == 3) { - bdiscontinuity.flag_heads_and_tails(head_flags, tail_flags, input, FlagOpType()); + adjacent_difference.flag_heads_and_tails(head_flags, + tail_flags, + input, + FlagOpType(), + storage); } ROCPRIM_CLANG_SUPPRESS_WARNING_POP @@ -339,11 +370,15 @@ auto test_block_adjacent_difference() { using type = Type; // std::vector is a special case that will cause an error in hipMemcpy + // rocprim::half/rocprim::bfloat16 are special cases that cannot be compared '==' + // in ASSERT_EQ using stored_flag_type = typename std::conditional< - std::is_same::value, - int, - FlagType - >::type; + std::is_same::value, + int, + typename std::conditional::value + || std::is_same::value, + float, + FlagType>::type>::type; using flag_type = FlagType; using flag_op_type = FlagOpType; static constexpr size_t block_size = BlockSize; @@ -378,8 +413,11 @@ auto test_block_adjacent_difference() if(ii == 0) { expected_heads[i] = bi % 2 == 1 - ? apply(flag_op, input[i - 1], input[i], ii) - : flag_type(true); + ? 
apply(flag_op, + input[i - 1], + input[i], + ii) + : stored_flag_type(true); } else { @@ -450,11 +488,15 @@ auto test_block_adjacent_difference() { using type = Type; // std::vector is a special case that will cause an error in hipMemcpy + // rocprim::half/rocprim::bfloat16 are special cases that cannot be compared '==' + // in ASSERT_EQ using stored_flag_type = typename std::conditional< - std::is_same::value, - int, - FlagType - >::type; + std::is_same::value, + int, + typename std::conditional::value + || std::is_same::value, + float, + FlagType>::type>::type; using flag_type = FlagType; using flag_op_type = FlagOpType; static constexpr size_t block_size = BlockSize; @@ -489,8 +531,11 @@ auto test_block_adjacent_difference() if(ii == items_per_block - 1) { expected_tails[i] = bi % 2 == 0 - ? apply(flag_op, input[i], input[i + 1], ii + 1) - : flag_type(true); + ? apply(flag_op, + input[i], + input[i + 1], + ii + 1) + : stored_flag_type(true); } else { @@ -561,11 +606,15 @@ auto test_block_adjacent_difference() { using type = Type; // std::vector is a special case that will cause an error in hipMemcpy + // rocprim::half/rocprim::bfloat16 are special cases that cannot be compared '==' + // in ASSERT_EQ using stored_flag_type = typename std::conditional< - std::is_same::value, - int, - FlagType - >::type; + std::is_same::value, + int, + typename std::conditional::value + || std::is_same::value, + float, + FlagType>::type>::type; using flag_type = FlagType; using flag_op_type = FlagOpType; static constexpr size_t block_size = BlockSize; @@ -602,8 +651,11 @@ auto test_block_adjacent_difference() if(ii == 0) { expected_heads[i] = (bi % 4 == 1 || bi % 4 == 2) - ? apply(flag_op, input[i - 1], input[i], ii) - : flag_type(true); + ? apply(flag_op, + input[i - 1], + input[i], + ii) + : stored_flag_type(true); } else { @@ -612,8 +664,11 @@ auto test_block_adjacent_difference() if(ii == items_per_block - 1) { expected_tails[i] = (bi % 4 == 0 || bi % 4 == 1) - ? 
apply(flag_op, input[i], input[i + 1], ii + 1) - : flag_type(true); + ? apply(flag_op, + input[i], + input[i + 1], + ii + 1) + : stored_flag_type(true); } else { diff --git a/test/rocprim/test_block_discontinuity.cpp b/test/rocprim/test_block_discontinuity.cpp.in similarity index 83% rename from test/rocprim/test_block_discontinuity.cpp rename to test/rocprim/test_block_discontinuity.cpp.in index 34ff72a23..239ced550 100644 --- a/test/rocprim/test_block_discontinuity.cpp +++ b/test/rocprim/test_block_discontinuity.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -37,20 +37,29 @@ // Start stamping out tests struct RocprimBlockDiscontinuity; +#cmakedefine ROCPRIM_TEST_SLICE @ROCPRIM_TEST_SLICE@ + +#if ROCPRIM_TEST_SLICE == 0 + struct Integral; #define suite_name RocprimBlockDiscontinuity #define warp_params BlockDiscParamsIntegral #define name_suffix Integral -#include "test_block_discontinuity.hpp" - -#undef suite_name -#undef warp_params -#undef name_suffix +#elif ROCPRIM_TEST_SLICE == 1 struct Floating; #define suite_name RocprimBlockDiscontinuity #define warp_params BlockDiscParamsFloating #define name_suffix Floating +#elif ROCPRIM_TEST_SLICE == 2 + +struct FloatingHalf; +#define suite_name RocprimBlockDiscontinuity +#define warp_params BlockDiscParamsFloatingHalf +#define name_suffix FloatingHalf + +#endif + #include "test_block_discontinuity.hpp" diff --git a/test/rocprim/test_block_discontinuity.kernels.hpp b/test/rocprim/test_block_discontinuity.kernels.hpp index 0f2f79dd8..6a61e87a8 100644 --- a/test/rocprim/test_block_discontinuity.kernels.hpp +++ b/test/rocprim/test_block_discontinuity.kernels.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright 
(c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -75,17 +75,18 @@ void flag_heads_kernel(Type* device_input, long long* device_heads) Type input[ItemsPerThread]; rocprim::block_load_direct_blocked(lid, device_input + block_offset, input); - rocprim::block_discontinuity bdiscontinuity; + rocprim::block_discontinuity bdiscontinuity; + __shared__ typename decltype(bdiscontinuity)::storage_type storage; FlagType head_flags[ItemsPerThread]; if(blockIdx.x % 2 == 1) { const Type tile_predecessor_item = device_input[block_offset - 1]; - bdiscontinuity.flag_heads(head_flags, tile_predecessor_item, input, FlagOpType()); + bdiscontinuity.flag_heads(head_flags, tile_predecessor_item, input, FlagOpType(), storage); } else { - bdiscontinuity.flag_heads(head_flags, input, FlagOpType()); + bdiscontinuity.flag_heads(head_flags, input, FlagOpType(), storage); } rocprim::block_store_direct_blocked(lid, device_heads + block_offset, head_flags); @@ -109,17 +110,18 @@ void flag_tails_kernel(Type* device_input, long long* device_tails) Type input[ItemsPerThread]; rocprim::block_load_direct_blocked(lid, device_input + block_offset, input); - rocprim::block_discontinuity bdiscontinuity; + rocprim::block_discontinuity bdiscontinuity; + __shared__ typename decltype(bdiscontinuity)::storage_type storage; FlagType tail_flags[ItemsPerThread]; if(blockIdx.x % 2 == 0) { const Type tile_successor_item = device_input[block_offset + items_per_block]; - bdiscontinuity.flag_tails(tail_flags, tile_successor_item, input, FlagOpType()); + bdiscontinuity.flag_tails(tail_flags, tile_successor_item, input, FlagOpType(), storage); } else { - bdiscontinuity.flag_tails(tail_flags, input, FlagOpType()); + bdiscontinuity.flag_tails(tail_flags, input, 
FlagOpType(), storage); } rocprim::block_store_direct_blocked(lid, device_tails + block_offset, tail_flags); @@ -143,29 +145,46 @@ void flag_heads_and_tails_kernel(Type* device_input, long long* device_heads, lo Type input[ItemsPerThread]; rocprim::block_load_direct_blocked(lid, device_input + block_offset, input); - rocprim::block_discontinuity bdiscontinuity; + rocprim::block_discontinuity bdiscontinuity; + __shared__ typename decltype(bdiscontinuity)::storage_type storage; FlagType head_flags[ItemsPerThread]; FlagType tail_flags[ItemsPerThread]; if(blockIdx.x % 4 == 0) { const Type tile_successor_item = device_input[block_offset + items_per_block]; - bdiscontinuity.flag_heads_and_tails(head_flags, tail_flags, tile_successor_item, input, FlagOpType()); + bdiscontinuity.flag_heads_and_tails(head_flags, + tail_flags, + tile_successor_item, + input, + FlagOpType(), + storage); } else if(blockIdx.x % 4 == 1) { const Type tile_predecessor_item = device_input[block_offset - 1]; - const Type tile_successor_item = device_input[block_offset + items_per_block]; - bdiscontinuity.flag_heads_and_tails(head_flags, tile_predecessor_item, tail_flags, tile_successor_item, input, FlagOpType()); + const Type tile_successor_item = device_input[block_offset + items_per_block]; + bdiscontinuity.flag_heads_and_tails(head_flags, + tile_predecessor_item, + tail_flags, + tile_successor_item, + input, + FlagOpType(), + storage); } else if(blockIdx.x % 4 == 2) { const Type tile_predecessor_item = device_input[block_offset - 1]; - bdiscontinuity.flag_heads_and_tails(head_flags, tile_predecessor_item, tail_flags, input, FlagOpType()); + bdiscontinuity.flag_heads_and_tails(head_flags, + tile_predecessor_item, + tail_flags, + input, + FlagOpType(), + storage); } else if(blockIdx.x % 4 == 3) { - bdiscontinuity.flag_heads_and_tails(head_flags, tail_flags, input, FlagOpType()); + bdiscontinuity.flag_heads_and_tails(head_flags, tail_flags, input, FlagOpType(), storage); } 
rocprim::block_store_direct_blocked(lid, device_heads + block_offset, head_flags); @@ -185,11 +204,15 @@ auto test_block_discontinuity() { using type = Type; // std::vector is a special case that will cause an error in hipMemcpy + // rocprim::half/rocprim::bfloat16 are special cases that cannot be compared '==' + // in ASSERT_EQ using stored_flag_type = typename std::conditional< - std::is_same::value, - int, - FlagType - >::type; + std::is_same::value, + int, + typename std::conditional::value + || std::is_same::value, + float, + FlagType>::type>::type; using flag_type = FlagType; using flag_op_type = FlagOpType; static constexpr size_t block_size = BlockSize; @@ -223,9 +246,8 @@ auto test_block_discontinuity() const size_t i = bi * items_per_block + ii; if(ii == 0) { - expected_heads[i] = bi % 2 == 1 - ? apply(flag_op, input[i - 1], input[i], ii) - : flag_type(true); + expected_heads[i] = bi % 2 == 1 ? apply(flag_op, input[i - 1], input[i], ii) + : stored_flag_type(true); } else { @@ -296,11 +318,15 @@ auto test_block_discontinuity() { using type = Type; // std::vector is a special case that will cause an error in hipMemcpy + // rocprim::half/rocprim::bfloat16 are special cases that cannot be compared '==' + // in ASSERT_EQ using stored_flag_type = typename std::conditional< - std::is_same::value, - int, - FlagType - >::type; + std::is_same::value, + int, + typename std::conditional::value + || std::is_same::value, + float, + FlagType>::type>::type; using flag_type = FlagType; using flag_op_type = FlagOpType; static constexpr size_t block_size = BlockSize; @@ -334,9 +360,8 @@ auto test_block_discontinuity() const size_t i = bi * items_per_block + ii; if(ii == items_per_block - 1) { - expected_tails[i] = bi % 2 == 0 - ? apply(flag_op, input[i], input[i + 1], ii + 1) - : flag_type(true); + expected_tails[i] = bi % 2 == 0 ? 
apply(flag_op, input[i], input[i + 1], ii + 1) + : stored_flag_type(true); } else { @@ -407,11 +432,15 @@ auto test_block_discontinuity() { using type = Type; // std::vector is a special case that will cause an error in hipMemcpy + // rocprim::half/rocprim::bfloat16 are special cases that cannot be compared '==' + // in ASSERT_EQ using stored_flag_type = typename std::conditional< - std::is_same::value, - int, - FlagType - >::type; + std::is_same::value, + int, + typename std::conditional::value + || std::is_same::value, + float, + FlagType>::type>::type; using flag_type = FlagType; using flag_op_type = FlagOpType; static constexpr size_t block_size = BlockSize; @@ -448,8 +477,8 @@ auto test_block_discontinuity() if(ii == 0) { expected_heads[i] = (bi % 4 == 1 || bi % 4 == 2) - ? apply(flag_op, input[i - 1], input[i], ii) - : flag_type(true); + ? apply(flag_op, input[i - 1], input[i], ii) + : stored_flag_type(true); } else { @@ -458,8 +487,8 @@ auto test_block_discontinuity() if(ii == items_per_block - 1) { expected_tails[i] = (bi % 4 == 0 || bi % 4 == 1) - ? apply(flag_op, input[i], input[i + 1], ii + 1) - : flag_type(true); + ? apply(flag_op, input[i], input[i + 1], ii + 1) + : stored_flag_type(true); } else { diff --git a/test/rocprim/test_block_exchange.kernels.hpp b/test/rocprim/test_block_exchange.kernels.hpp index 5c56fff43..ff916e03d 100644 --- a/test/rocprim/test_block_exchange.kernels.hpp +++ b/test/rocprim/test_block_exchange.kernels.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -198,7 +198,11 @@ auto test_block_exchange(int /*device_id*/) -> typename std::enable_if values(size); - std::iota(values.begin(), values.end(), 0); + test_utils::iota_modulo(values.begin(), + values.end(), + 0, + std::min(test_utils::numeric_limits::max(), + test_utils::numeric_limits::max())); for(size_t bi = 0; bi < size / items_per_block; bi++) { for(size_t ti = 0; ti < block_size; ti++) @@ -209,7 +213,7 @@ auto test_block_exchange(int /*device_id*/) -> typename std::enable_if(values[i1]); + expected[i0] = values[i1]; } } } @@ -279,7 +283,11 @@ auto test_block_exchange(int /*device_id*/) -> typename std::enable_if values(size); - std::iota(values.begin(), values.end(), 0); + test_utils::iota_modulo(values.begin(), + values.end(), + 0, + std::min(test_utils::numeric_limits::max(), + test_utils::numeric_limits::max())); for(size_t bi = 0; bi < size / items_per_block; bi++) { for(size_t ti = 0; ti < block_size; ti++) @@ -367,7 +375,11 @@ auto test_block_exchange(int device_id) -> typename std::enable_if: // Calculate input and expected results on host std::vector values(size); - std::iota(values.begin(), values.end(), 0); + test_utils::iota_modulo(values.begin(), + values.end(), + 0, + std::min(test_utils::numeric_limits::max(), + test_utils::numeric_limits::max())); for(size_t bi = 0; bi < size / items_per_block; bi++) { for(size_t wi = 0; wi < warps_no; wi++) @@ -463,7 +475,11 @@ auto test_block_exchange(int device_id) -> typename std::enable_if: // Calculate input and expected results on host std::vector values(size); - std::iota(values.begin(), values.end(), 0); + test_utils::iota_modulo(values.begin(), + values.end(), + 0, + std::min(test_utils::numeric_limits::max(), + test_utils::numeric_limits::max())); for(size_t bi = 0; bi < size / items_per_block; bi++) { for(size_t wi = 0; wi < warps_no; 
wi++) @@ -557,7 +573,11 @@ auto test_block_exchange(int /*device_id*/) -> typename std::enable_if values(size); - std::iota(values.begin(), values.end(), 0); + test_utils::iota_modulo(values.begin(), + values.end(), + 0, + std::min(test_utils::numeric_limits::max(), + test_utils::numeric_limits::max())); for(size_t bi = 0; bi < size / items_per_block; bi++) { for(size_t ti = 0; ti < block_size; ti++) @@ -656,7 +676,11 @@ auto test_block_exchange(int /*device_id*/) -> typename std::enable_if values(size); - std::iota(values.begin(), values.end(), 0); + test_utils::iota_modulo(values.begin(), + values.end(), + 0, + std::min(test_utils::numeric_limits::max(), + test_utils::numeric_limits::max())); for(size_t bi = 0; bi < size / items_per_block; bi++) { for(size_t ti = 0; ti < block_size; ti++) diff --git a/test/rocprim/test_block_histogram.kernels.hpp b/test/rocprim/test_block_histogram.kernels.hpp index bd308b396..dd56ca87e 100644 --- a/test/rocprim/test_block_histogram.kernels.hpp +++ b/test/rocprim/test_block_histogram.kernels.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -32,6 +32,9 @@ #include "../common_test_header.hpp" #include "test_utils_types.hpp" +#include +#include + template< unsigned int BlockSize, unsigned int ItemsPerThread, @@ -107,7 +110,11 @@ void test_block_histogram_input_arrays() SCOPED_TRACE(testing::Message() << "with ItemsPerThread = " << items_per_thread); // Generate data - std::vector output = test_utils::get_random_data(size, 0, bin - 1, seed_value); + std::vector output = test_utils::get_random_data( + size, + 0, + std::min(std::numeric_limits::max(), bin - 1), + seed_value); // Output histogram results std::vector output_bin(bin_sizes, 0); diff --git a/test/rocprim/test_block_load_store.kernels.hpp b/test/rocprim/test_block_load_store.kernels.hpp index e066fe360..a9779d201 100644 --- a/test/rocprim/test_block_load_store.kernels.hpp +++ b/test/rocprim/test_block_load_store.kernels.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -25,62 +25,138 @@ typedef ::testing::Types< // block_load_direct - class_params, - class_params, - class_params, - class_params, - class_params, - class_params, - class_params, + class_params, + class_params, + class_params, + class_params, + class_params, + class_params, + class_params, - class_params, - class_params, - class_params, - class_params, - class_params, - class_params, - class_params, + class_params, + class_params, + class_params, + class_params, + class_params, + class_params, + class_params, - class_params, rocprim::block_load_method::block_load_direct, - rocprim::block_store_method::block_store_direct, 64U, 1>, - class_params, rocprim::block_load_method::block_load_direct, - rocprim::block_store_method::block_store_direct, 64U, 5>, - class_params, rocprim::block_load_method::block_load_direct, - rocprim::block_store_method::block_store_direct, 256U, 1>, - class_params, rocprim::block_load_method::block_load_direct, - rocprim::block_store_method::block_store_direct, 256U, 4>, + class_params, + rocprim::block_load_method::block_load_direct, + rocprim::block_store_method::block_store_direct, + 64U, + 1>, + class_params, + rocprim::block_load_method::block_load_direct, + rocprim::block_store_method::block_store_direct, + 64U, + 5>, + class_params, + rocprim::block_load_method::block_load_direct, + rocprim::block_store_method::block_store_direct, + 256U, + 1>, + class_params, + rocprim::block_load_method::block_load_direct, + rocprim::block_store_method::block_store_direct, + 256U, + 4>, // block_load_vectorize - class_params, - class_params, - class_params, - class_params, - class_params, - class_params, - class_params + class_params, + class_params, + class_params, + class_params, + class_params, + class_params, + class_params -> ClassParamsFirstPart; + > + ClassParamsFirstPart; typedef 
::testing::Types< diff --git a/test/rocprim/test_block_radix_sort.cpp b/test/rocprim/test_block_radix_sort.cpp.in similarity index 72% rename from test/rocprim/test_block_radix_sort.cpp rename to test/rocprim/test_block_radix_sort.cpp.in index bcc032b2b..1fdf44f33 100644 --- a/test/rocprim/test_block_radix_sort.cpp +++ b/test/rocprim/test_block_radix_sort.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -24,13 +24,12 @@ // required rocprim headers #include -#include #include - +#include // required test headers -#include "test_utils_types.hpp" #include "test_utils_sort_comparator.hpp" +#include "test_utils_types.hpp" // kernel definitions #include "test_block_radix_sort.kernels.hpp" @@ -38,20 +37,38 @@ // Start stamping out tests struct RocprimBlockRadixSort; +#cmakedefine ROCPRIM_TEST_SLICE @ROCPRIM_TEST_SLICE@ + +#if ROCPRIM_TEST_SLICE == 0 + struct Integral; #define suite_name RocprimBlockRadixSort #define warp_params BlockParamsIntegralExtended #define name_suffix Integral -#include "test_block_radix_sort.hpp" - -#undef suite_name -#undef warp_params -#undef name_suffix +#elif ROCPRIM_TEST_SLICE == 1 struct Floating; #define suite_name RocprimBlockRadixSort #define warp_params BlockParamsFloating #define name_suffix Floating +#elif ROCPRIM_TEST_SLICE == 2 + +typedef ::testing::Types< + block_param_type(test_utils::custom_test_type, int) + , block_param_type(test_utils::custom_test_type, int8_t) + , block_param_type(test_utils::custom_test_type, uint16_t) +#if ROCPRIM_HAS_INT128_SUPPORT + , block_param_type(test_utils::custom_test_type, double) +#endif +> BlockParamsCustom; + +struct Custom; +#define suite_name RocprimBlockRadixSort +#define warp_params 
BlockParamsCustom +#define name_suffix Custom + +#endif + #include "test_block_radix_sort.hpp" diff --git a/test/rocprim/test_block_radix_sort.kernels.hpp b/test/rocprim/test_block_radix_sort.kernels.hpp index d51b5d181..62770f19a 100644 --- a/test/rocprim/test_block_radix_sort.kernels.hpp +++ b/test/rocprim/test_block_radix_sort.kernels.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -40,69 +40,75 @@ static constexpr unsigned int end_radix[n_sizes] = { 0, 0, 0, 10, 11, 12, 0, 0, 0, 10, 11, 12 }; -template< - unsigned int BlockSize, - unsigned int ItemsPerThread, - class key_type -> -__global__ -__launch_bounds__(BlockSize) -void sort_key_kernel( - key_type* device_keys_output, - bool to_striped, - bool descending, - unsigned int start_bit, - unsigned int end_bit) +static constexpr unsigned int bits_per_pass_radix[n_sizes] = {4, 3, 1, 1, 3, 4, 4, 3, 1, 1, 3, 4}; + +template +__global__ __launch_bounds__(BlockSize) void sort_key_kernel(key_type* device_keys_output, + bool to_striped, + bool descending, + unsigned int start_bit, + unsigned int end_bit) { constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; const unsigned int lid = threadIdx.x; const unsigned int block_offset = blockIdx.x * items_per_block; key_type keys[ItemsPerThread]; -#ifdef __HIP_CPU_RT__ - // TODO: check if it's really neccessary - // Initialize contents, as non-hipcc compilers don't unconditionally zero out allocated memory - std::memset(keys, 0, ItemsPerThread * sizeof(key_type)); -#endif rocprim::block_load_direct_blocked(lid, device_keys_output + block_offset, keys); - rocprim::block_radix_sort bsort; + rocprim::block_radix_sort + bsort; + + 
test_utils::select_decomposer_t decomposer{}; if(to_striped) { if(descending) - bsort.sort_desc_to_striped(keys, start_bit, end_bit); + { + bsort.sort_desc_to_striped(keys, start_bit, end_bit, decomposer); + } else - bsort.sort_to_striped(keys, start_bit, end_bit); + { + bsort.sort_to_striped(keys, start_bit, end_bit, decomposer); + } rocprim::block_store_direct_striped(lid, device_keys_output + block_offset, keys); } else { if(descending) - bsort.sort_desc(keys, start_bit, end_bit); + { + bsort.sort_desc(keys, start_bit, end_bit, decomposer); + } else - bsort.sort(keys, start_bit, end_bit); + { + bsort.sort(keys, start_bit, end_bit, decomposer); + } rocprim::block_store_direct_blocked(lid, device_keys_output + block_offset, keys); } } -template< - unsigned int BlockSize, - unsigned int ItemsPerThread, - class key_type, - class value_type -> -__global__ -__launch_bounds__(BlockSize) -void sort_key_value_kernel( - key_type* device_keys_output, - value_type* device_values_output, - bool to_striped, - bool descending, - unsigned int start_bit, - unsigned int end_bit) +template +__global__ __launch_bounds__(BlockSize) void sort_key_value_kernel(key_type* device_keys_output, + value_type* device_values_output, + bool to_striped, + bool descending, + unsigned int start_bit, + unsigned int end_bit) { constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; const unsigned int lid = threadIdx.x; @@ -113,13 +119,20 @@ void sort_key_value_kernel( rocprim::block_load_direct_blocked(lid, device_keys_output + block_offset, keys); rocprim::block_load_direct_blocked(lid, device_values_output + block_offset, values); - rocprim::block_radix_sort bsort; + rocprim:: + block_radix_sort + bsort; + test_utils::select_decomposer_t decomposer{}; if(to_striped) { if(descending) - bsort.sort_desc_to_striped(keys, values, start_bit, end_bit); + { + bsort.sort_desc_to_striped(keys, values, start_bit, end_bit, decomposer); + } else - bsort.sort_to_striped(keys, values, start_bit, 
end_bit); + { + bsort.sort_to_striped(keys, values, start_bit, end_bit, decomposer); + } rocprim::block_store_direct_striped(lid, device_keys_output + block_offset, keys); rocprim::block_store_direct_striped(lid, device_values_output + block_offset, values); @@ -127,9 +140,13 @@ void sort_key_value_kernel( else { if(descending) - bsort.sort_desc(keys, values, start_bit, end_bit); + { + bsort.sort_desc(keys, values, start_bit, end_bit, decomposer); + } else - bsort.sort(keys, values, start_bit, end_bit); + { + bsort.sort(keys, values, start_bit, end_bit, decomposer); + } rocprim::block_store_direct_blocked(lid, device_keys_output + block_offset, keys); rocprim::block_store_direct_blocked(lid, device_values_output + block_offset, values); @@ -137,28 +154,27 @@ void sort_key_value_kernel( } // Test for radix sort -template< - class Key, - class Value, - unsigned int Method, - unsigned int BlockSize, - unsigned int ItemsPerThread, - bool Descending = false, - bool ToStriped = false, - unsigned int StartBit = 0, - unsigned int EndBit = sizeof(Key) * 8 -> -auto test_block_radix_sort() --> typename std::enable_if::type +template +auto test_block_radix_sort() -> typename std::enable_if::type { - using key_type = Key; - static constexpr size_t block_size = BlockSize; - static constexpr size_t items_per_thread = ItemsPerThread; - static constexpr bool descending = Descending; - static constexpr bool to_striped = ToStriped; - static constexpr unsigned int start_bit = StartBit; - static constexpr unsigned int end_bit = EndBit; - static constexpr size_t items_per_block = block_size * items_per_thread; + using key_type = Key; + static constexpr size_t block_size = BlockSize; + static constexpr size_t items_per_thread = ItemsPerThread; + static constexpr unsigned radix_bits_per_pass = RadixBitsPerPass; + static constexpr bool descending = Descending; + static constexpr bool to_striped = ToStriped; + static constexpr unsigned int start_bit = StartBit; + static constexpr unsigned 
int end_bit = EndBit; + static constexpr size_t items_per_block = block_size * items_per_thread; // Given block size not supported if(block_size > test_utils::get_max_block_size()) @@ -174,24 +190,25 @@ auto test_block_radix_sort() unsigned int seed_value = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + engine_type rng_engine(seed_value); + // Generate data - std::vector keys_output; + auto keys_output = std::make_unique(size); if(rocprim::is_floating_point::value) { - keys_output = test_utils::get_random_data(size, -100, +100, seed_value); + test_utils::generate_random_data_n(keys_output.get(), size, -100, +100, rng_engine); } else { - keys_output = test_utils::get_random_data( - size, - std::numeric_limits::min(), - std::numeric_limits::max(), - seed_value - ); + test_utils::generate_random_data_n(keys_output.get(), + size, + test_utils::numeric_limits::min(), + test_utils::numeric_limits::max(), + rng_engine); } // Calculate expected results on host - std::vector expected(keys_output); + std::vector expected(keys_output.get(), keys_output.get() + size); for(size_t i = 0; i < size / items_per_block; i++) { std::stable_sort( @@ -203,63 +220,61 @@ auto test_block_radix_sort() // Preparing device key_type* device_keys_output; - HIP_CHECK(test_common_utils::hipMallocHelper(&device_keys_output, keys_output.size() * sizeof(key_type))); - - HIP_CHECK( - hipMemcpy( - device_keys_output, keys_output.data(), - keys_output.size() * sizeof(typename decltype(keys_output)::value_type), - hipMemcpyHostToDevice - ) - ); - - // Running kernel - hipLaunchKernelGGL( - HIP_KERNEL_NAME(sort_key_kernel), - dim3(grid_size), dim3(block_size), 0, 0, - device_keys_output, to_striped, descending, start_bit, end_bit - ); + HIP_CHECK(test_common_utils::hipMallocHelper(&device_keys_output, size * sizeof(key_type))); + + HIP_CHECK(hipMemcpy(device_keys_output, + keys_output.get(), + size * 
sizeof(keys_output[0]), + hipMemcpyHostToDevice)); + + sort_key_kernel + <<>>(device_keys_output, + to_striped, + descending, + start_bit, + end_bit); HIP_CHECK(hipGetLastError()); // Getting results to host - HIP_CHECK( - hipMemcpy( - keys_output.data(), device_keys_output, - keys_output.size() * sizeof(typename decltype(keys_output)::value_type), - hipMemcpyDeviceToHost - ) - ); + HIP_CHECK(hipMemcpy(keys_output.get(), + device_keys_output, + size * sizeof(keys_output[0]), + hipMemcpyDeviceToHost)); // Verifying results - ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(keys_output, expected)); + ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(keys_output.get(), + keys_output.get() + size, + expected.begin(), + expected.end())); HIP_CHECK(hipFree(device_keys_output)); } } -template< - class Key, - class Value, - unsigned int Method, - unsigned int BlockSize, - unsigned int ItemsPerThread, - bool Descending = false, - bool ToStriped = false, - unsigned int StartBit = 0, - unsigned int EndBit = sizeof(Key) * 8 -> -auto test_block_radix_sort() --> typename std::enable_if::type +template +auto test_block_radix_sort() -> typename std::enable_if::type { - using key_type = Key; - using value_type = Value; - static constexpr size_t block_size = BlockSize; - static constexpr size_t items_per_thread = ItemsPerThread; - static constexpr bool descending = Descending; - static constexpr bool to_striped = ToStriped; - static constexpr unsigned int start_bit = (rocprim::is_unsigned::value == false) ? 0 : StartBit; - static constexpr unsigned int end_bit = (rocprim::is_unsigned::value == false) ? 
sizeof(Key) * 8 : EndBit; + using key_type = Key; + using value_type = Value; + static constexpr size_t block_size = BlockSize; + static constexpr size_t items_per_thread = ItemsPerThread; + static constexpr unsigned radix_bits_per_pass = RadixBitsPerPass; + static constexpr bool descending = Descending; + static constexpr bool to_striped = ToStriped; + static constexpr unsigned int start_bit + = (rocprim::is_unsigned::value == false) ? 0 : StartBit; + static constexpr unsigned int end_bit + = (rocprim::is_unsigned::value == false) ? sizeof(Key) * 8 : EndBit; static constexpr size_t items_per_block = block_size * items_per_thread; // Given block size not supported @@ -276,20 +291,21 @@ auto test_block_radix_sort() seed_type seed_value = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + engine_type rng_engine(seed_value); + // Generate data - std::vector keys_output; + auto keys_output = std::make_unique(size); if(rocprim::is_floating_point::value) { - keys_output = test_utils::get_random_data(size, -100, +100, seed_value); + test_utils::generate_random_data_n(keys_output.get(), size, -100, +100, rng_engine); } else { - keys_output = test_utils::get_random_data( - size, - std::numeric_limits::min(), - std::numeric_limits::max(), - seed_value - ); + test_utils::generate_random_data_n(keys_output.get(), + size, + test_utils::numeric_limits::min(), + test_utils::numeric_limits::max(), + rng_engine); } std::vector values_output = test_utils::get_random_data(size, 0, 100, seed_value); @@ -321,17 +337,14 @@ auto test_block_radix_sort() } key_type* device_keys_output; - HIP_CHECK(test_common_utils::hipMallocHelper(&device_keys_output, keys_output.size() * sizeof(key_type))); + HIP_CHECK(test_common_utils::hipMallocHelper(&device_keys_output, size * sizeof(key_type))); value_type* device_values_output; HIP_CHECK(test_common_utils::hipMallocHelper(&device_values_output, 
values_output.size() * sizeof(value_type))); - HIP_CHECK( - hipMemcpy( - device_keys_output, keys_output.data(), - keys_output.size() * sizeof(typename decltype(keys_output)::value_type), - hipMemcpyHostToDevice - ) - ); + HIP_CHECK(hipMemcpy(device_keys_output, + keys_output.get(), + size * sizeof(keys_output[0]), + hipMemcpyHostToDevice)); HIP_CHECK( hipMemcpy( @@ -342,21 +355,24 @@ auto test_block_radix_sort() ); // Running kernel - hipLaunchKernelGGL( - HIP_KERNEL_NAME(sort_key_value_kernel), - dim3(grid_size), dim3(block_size), 0, 0, - device_keys_output, device_values_output, to_striped, descending, start_bit, end_bit - ); + sort_key_value_kernel + <<>>(device_keys_output, + device_values_output, + to_striped, + descending, + start_bit, + end_bit); HIP_CHECK(hipGetLastError()); // Getting results to host - HIP_CHECK( - hipMemcpy( - keys_output.data(), device_keys_output, - keys_output.size() * sizeof(typename decltype(keys_output)::value_type), - hipMemcpyDeviceToHost - ) - ); + HIP_CHECK(hipMemcpy(keys_output.get(), + device_keys_output, + size * sizeof(keys_output[0]), + hipMemcpyDeviceToHost)); HIP_CHECK( hipMemcpy( @@ -366,7 +382,10 @@ auto test_block_radix_sort() ) ); - ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(keys_output, keys_expected)); + ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(keys_output.get(), + keys_output.get() + size, + keys_expected.begin(), + keys_expected.end())); ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(values_output, values_expected)); HIP_CHECK(hipFree(device_keys_output)); @@ -390,7 +409,16 @@ struct static_for static void run() { - test_block_radix_sort(); + test_block_radix_sort(); static_for::run(); } }; diff --git a/test/rocprim/test_block_reduce.hpp b/test/rocprim/test_block_reduce.hpp index 14326bc99..e7f736e70 100644 --- a/test/rocprim/test_block_reduce.hpp +++ b/test/rocprim/test_block_reduce.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -38,7 +38,6 @@ typed_test_def(suite_name_single, name_suffix, Reduce) using binary_op_type_host = typename test_utils::select_plus_operator_host::type; binary_op_type_host binary_op_host; using acc_type = typename test_utils::select_plus_operator_host::acc_type; - using cast_type = typename test_utils::select_plus_operator_host::cast_type; constexpr size_t block_size = TestFixture::block_size; @@ -70,7 +69,7 @@ typed_test_def(suite_name_single, name_suffix, Reduce) auto idx = i * block_size + j; value = binary_op_host(value, output[idx]); } - expected_reductions[i] = static_cast(value); + expected_reductions[i] = static_cast(value); } // Preparing device @@ -106,8 +105,7 @@ typed_test_def(suite_name_single, name_suffix, ReduceMultiplies) using T = typename TestFixture::input_type; using binary_op_type = rocprim::multiplies; constexpr size_t block_size = TestFixture::block_size; - using cast_type = typename test_utils::select_plus_operator_host::cast_type; - + // Given block size not supported if(block_size > test_utils::get_max_block_size()) { @@ -137,7 +135,7 @@ typed_test_def(suite_name_single, name_suffix, ReduceMultiplies) auto idx = i * block_size + j; value *= static_cast(output[idx]); } - expected_reductions[i] = static_cast(value); + expected_reductions[i] = static_cast(value); } // Preparing device @@ -268,7 +266,6 @@ typed_test_def(suite_name_single, name_suffix, ReduceValid) using binary_op_type_host = typename test_utils::select_plus_operator_host::type; binary_op_type_host binary_op_host; using acc_type = typename test_utils::select_plus_operator_host::acc_type; - using cast_type = typename test_utils::select_plus_operator_host::cast_type; constexpr size_t block_size = TestFixture::block_size; @@ -302,7 +299,7 
@@ typed_test_def(suite_name_single, name_suffix, ReduceValid) auto idx = i * block_size + j; value = binary_op_host(value, output[idx]); } - expected_reductions[i] = static_cast(value); + expected_reductions[i] = static_cast(value); } // Preparing device diff --git a/test/rocprim/test_block_run_length_decode.cpp b/test/rocprim/test_block_run_length_decode.cpp index c16853af7..4afd78166 100644 --- a/test/rocprim/test_block_run_length_decode.cpp +++ b/test/rocprim/test_block_run_length_decode.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -57,6 +57,8 @@ class HipcubBlockRunLengthDecodeTest : public ::testing::Test using HipcubBlockRunLengthDecodeTestParams = ::testing::Types, + Params, + Params, Params, Params, Params, diff --git a/test/rocprim/test_block_scan.hpp b/test/rocprim/test_block_scan.hpp index 4766d92d9..f713d4081 100644 --- a/test/rocprim/test_block_scan.hpp +++ b/test/rocprim/test_block_scan.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -33,7 +33,7 @@ typed_test_def(suite_name_single, name_suffix, InclusiveScan) using binary_op_type_host = typename test_utils::select_plus_operator_host::type; binary_op_type_host binary_op_host; using acc_type = typename test_utils::select_plus_operator_host::acc_type; - using cast_type = typename test_utils::select_plus_operator_host::cast_type; + constexpr size_t block_size = TestFixture::block_size; int device_id = test_common_utils::obtain_device_from_ctest(); @@ -67,7 +67,7 @@ typed_test_def(suite_name_single, name_suffix, InclusiveScan) { auto idx = i * block_size + j; accumulator = binary_op_host(output[idx], accumulator); - expected[idx] = static_cast(accumulator); + expected[idx] = static_cast(accumulator); } } @@ -97,7 +97,7 @@ typed_test_def(suite_name_single, name_suffix, InclusiveScanReduce) using binary_op_type_host = typename test_utils::select_plus_operator_host::type; binary_op_type_host binary_op_host; using acc_type = typename test_utils::select_plus_operator_host::acc_type; - using cast_type = typename test_utils::select_plus_operator_host::cast_type; + constexpr size_t block_size = TestFixture::block_size; int device_id = test_common_utils::obtain_device_from_ctest(); @@ -133,7 +133,7 @@ typed_test_def(suite_name_single, name_suffix, InclusiveScanReduce) { auto idx = i * block_size + j; accumulator = binary_op_host(output[idx], accumulator); - expected[idx] = static_cast(accumulator); + expected[idx] = static_cast(accumulator); } expected_reductions[i] = expected[(i+1) * block_size - 1]; } @@ -172,7 +172,6 @@ typed_test_def(suite_name_single, name_suffix, InclusiveScanPrefixCallback) using binary_op_type_host = typename test_utils::select_plus_operator_host::type; binary_op_type_host binary_op_host; using acc_type = typename test_utils::select_plus_operator_host::acc_type; - using 
cast_type = typename test_utils::select_plus_operator_host::cast_type; constexpr size_t block_size = TestFixture::block_size; @@ -210,7 +209,7 @@ typed_test_def(suite_name_single, name_suffix, InclusiveScanPrefixCallback) { auto idx = i * block_size + j; accumulator = binary_op_host(output[idx], accumulator); - expected[idx] = static_cast(accumulator); + expected[idx] = static_cast(accumulator); } expected_block_prefixes[i] = expected[(i+1) * block_size - 1]; } @@ -249,7 +248,6 @@ typed_test_def(suite_name_single, name_suffix, ExclusiveScan) using binary_op_type_host = typename test_utils::select_plus_operator_host::type; binary_op_type_host binary_op_host; using acc_type = typename test_utils::select_plus_operator_host::acc_type; - using cast_type = typename test_utils::select_plus_operator_host::cast_type; constexpr size_t block_size = TestFixture::block_size; @@ -286,7 +284,7 @@ typed_test_def(suite_name_single, name_suffix, ExclusiveScan) { auto idx = i * block_size + j; accumulator = binary_op_host(output[idx-1], accumulator); - expected[idx] = static_cast(accumulator); + expected[idx] = static_cast(accumulator); } } @@ -316,7 +314,6 @@ typed_test_def(suite_name_single, name_suffix, ExclusiveScanReduce) using binary_op_type_host = typename test_utils::select_plus_operator_host::type; binary_op_type_host binary_op_host; using acc_type = typename test_utils::select_plus_operator_host::acc_type; - using cast_type = typename test_utils::select_plus_operator_host::cast_type; constexpr size_t block_size = TestFixture::block_size; @@ -356,7 +353,7 @@ typed_test_def(suite_name_single, name_suffix, ExclusiveScanReduce) { auto idx = i * block_size + j; accumulator = binary_op_host(output[idx-1], accumulator); - expected[idx] = static_cast(accumulator); + expected[idx] = static_cast(accumulator); } acc_type accumulator_reductions(0); expected_reductions[i] = 0; @@ -364,7 +361,7 @@ typed_test_def(suite_name_single, name_suffix, ExclusiveScanReduce) { auto idx = i * 
block_size + j; accumulator_reductions = binary_op_host(accumulator_reductions, output[idx]); - expected_reductions[i] = static_cast(accumulator_reductions); + expected_reductions[i] = static_cast(accumulator_reductions); } } @@ -402,7 +399,6 @@ typed_test_def(suite_name_single, name_suffix, ExclusiveScanPrefixCallback) using binary_op_type_host = typename test_utils::select_plus_operator_host::type; binary_op_type_host binary_op_host; using acc_type = typename test_utils::select_plus_operator_host::acc_type; - using cast_type = typename test_utils::select_plus_operator_host::cast_type; constexpr size_t block_size = TestFixture::block_size; @@ -441,7 +437,7 @@ typed_test_def(suite_name_single, name_suffix, ExclusiveScanPrefixCallback) { auto idx = i * block_size + j; accumulator = binary_op_host(output[idx-1], accumulator); - expected[idx] = static_cast(accumulator); + expected[idx] = static_cast(accumulator); } acc_type accumulator_block_prefixes = block_prefix; @@ -450,7 +446,7 @@ typed_test_def(suite_name_single, name_suffix, ExclusiveScanPrefixCallback) { auto idx = i * block_size + j; accumulator_block_prefixes = binary_op_host(output[idx], accumulator_block_prefixes); - expected_block_prefixes[i] = static_cast(accumulator_block_prefixes); + expected_block_prefixes[i] = static_cast(accumulator_block_prefixes); } } diff --git a/test/rocprim/test_block_sort.hpp b/test/rocprim/test_block_sort.hpp index 654764fd9..7757434be 100644 --- a/test/rocprim/test_block_sort.hpp +++ b/test/rocprim/test_block_sort.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -22,9 +22,9 @@ #include "test_utils_sort_comparator.hpp" -block_sort_test_suite_type_def(suite_name, name_suffix) +block_sort_test_suite_type_def(suite_name, name_suffix); - typed_test_suite_def(suite_name, name_suffix, block_params); +typed_test_suite_def(suite_name, name_suffix, block_params); // using header guards for these test functions because this file is included multiple times: // once for the integrals test suite and once for the floating point test suite. diff --git a/test/rocprim/test_device_adjacent_difference.cpp b/test/rocprim/test_device_adjacent_difference.cpp index fa97eecb6..97a3ee954 100644 --- a/test/rocprim/test_device_adjacent_difference.cpp +++ b/test/rocprim/test_device_adjacent_difference.cpp @@ -27,6 +27,7 @@ #include +#include "rocprim/types.hpp" #include #include #include @@ -237,6 +238,9 @@ using custom_size_limit_config using RocprimDeviceAdjacentDifferenceTestsParams = ::testing::Types< // Tests with default configuration DeviceAdjacentDifferenceParams, + DeviceAdjacentDifferenceParams, + DeviceAdjacentDifferenceParams, + DeviceAdjacentDifferenceParams, DeviceAdjacentDifferenceParams, DeviceAdjacentDifferenceParams, DeviceAdjacentDifferenceParams, @@ -247,13 +251,15 @@ using RocprimDeviceAdjacentDifferenceTestsParams = ::testing::Types< true, api_variant::in_place, false>, - // this is changed to not use identity iterator - // because the function doesn't work with it, should be changed back, when fixed DeviceAdjacentDifferenceParams, + true>, + // Tests for void value_type + DeviceAdjacentDifferenceParams, + DeviceAdjacentDifferenceParams, + DeviceAdjacentDifferenceParams, // Tests for supported config structs DeviceAdjacentDifferenceParams( d_input); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if(TestFixture::use_graphs) - graph = 
test_utils::createGraphHelper(stream); - // Allocate temporary storage std::size_t temp_storage_size; void* d_temp_storage = nullptr; @@ -364,19 +365,24 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceTests, AdjacentDifference) stream, debug_synchronous)); - if(TestFixture::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - ASSERT_GT(temp_storage_size, 0); HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size)); + hipGraph_t graph; + hipGraphExec_t graph_instance; + // We might call the API multiple times, with almost the same parameter // (in-place and out-of-place) // we should be able to use the same amount of temp storage for and get the same // results (maybe with different types) for both. auto run_and_verify = [&](const auto output_it, auto* d_output) { + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + // Run HIP_CHECK(dispatch_adjacent_difference(left_tag, alias_tag, @@ -416,6 +422,11 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceTests, AdjacentDifference) output, expected, std::max(test_utils::precision, test_utils::precision)); + + if(TestFixture::use_graphs) + { + test_utils::cleanupGraphHelper(graph, graph_instance); + } }; // if api_variant is not in_place we should check the non aliased function call @@ -424,9 +435,6 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceTests, AdjacentDifference) output_type* d_output = nullptr; HIP_CHECK(test_common_utils::hipMallocHelper(&d_output, size * sizeof(*d_output))); - if(TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - const auto output_it = test_utils::wrap_in_identity_iterator(d_output); @@ -438,15 +446,11 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceTests, AdjacentDifference) // if api_variant is not no_alias we should check the inplace function call if(aliasing != api_variant::no_alias) { - if(TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, 
graph_instance, stream); - ASSERT_NO_FATAL_FAILURE(run_and_verify(input_it, d_input)); } if(TestFixture::use_graphs) { - test_utils::cleanupGraphHelper(graph, graph_instance); HIP_CHECK(hipStreamDestroy(stream)); } @@ -651,11 +655,6 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceLargeTests, LargeIndices) static constexpr auto left_tag = rocprim::detail::bool_constant{}; static constexpr auto aliasing_tag = std::integral_constant{}; - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) - graph = test_utils::createGraphHelper(stream); - // Allocate temporary storage std::size_t temp_storage_size; void* d_temp_storage = nullptr; @@ -670,16 +669,20 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceLargeTests, LargeIndices) stream, debug_synchronous)); - if (TestFixture::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - ASSERT_GT(temp_storage_size, 0); HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size)); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + + // Capture the memset in the graph so that relaunching will have expected result + HIP_CHECK(hipMemsetAsync(d_incorrect_flag, 0, sizeof(*d_incorrect_flag), stream)); + HIP_CHECK(hipMemsetAsync(d_counter, 0, sizeof(*d_counter), stream)); + // Run HIP_CHECK(dispatch_adjacent_difference(left_tag, aliasing_tag, @@ -692,8 +695,11 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceLargeTests, LargeIndices) stream, debug_synchronous)); - if (TestFixture::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + } // Copy output to host flag_type incorrect_flag; @@ -711,8 +717,10 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceLargeTests, LargeIndices) hipFree(d_incorrect_flag); 
hipFree(d_counter); - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } } } diff --git a/test/rocprim/test_device_batch_memcpy.cpp b/test/rocprim/test_device_batch_memcpy.cpp index 64986a796..b3957afce 100644 --- a/test/rocprim/test_device_batch_memcpy.cpp +++ b/test/rocprim/test_device_batch_memcpy.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -27,6 +27,7 @@ #include "test_utils_types.hpp" #include "rocprim/detail/various.hpp" +#include "rocprim/device/device_copy.hpp" #include "rocprim/device/device_memcpy.hpp" #include "rocprim/intrinsics/thread.hpp" @@ -44,6 +45,7 @@ template @@ -51,6 +53,7 @@ struct DeviceBatchMemcpyParams { using value_type = ValueType; using size_type = SizeType; + static constexpr bool isMemCpy = IsMemCpy; static constexpr bool shuffled = Shuffled; static constexpr uint32_t num_buffers = NumBuffers; static constexpr uint32_t max_size = MaxSize; @@ -61,6 +64,7 @@ struct DeviceBatchMemcpyTests : public ::testing::Test { using value_type = typename Params::value_type; using size_type = typename Params::size_type; + static constexpr bool isMemCpy = Params::isMemCpy; static constexpr bool shuffled = Params::shuffled; static constexpr uint32_t num_buffers = Params::num_buffers; static constexpr uint32_t max_size = Params::max_size; @@ -68,36 +72,42 @@ struct DeviceBatchMemcpyTests : public ::testing::Test typedef ::testing::Types< // Ignore copy/move - DeviceBatchMemcpyParams, uint32_t, false>, - DeviceBatchMemcpyParams, uint32_t, false>, - DeviceBatchMemcpyParams, uint32_t, false>, + DeviceBatchMemcpyParams, uint32_t, true, false>, + DeviceBatchMemcpyParams, uint32_t, true, 
false>, + DeviceBatchMemcpyParams, uint32_t, true, false>, // Unshuffled inputs and outputs // Variable value_type - DeviceBatchMemcpyParams, - DeviceBatchMemcpyParams, - DeviceBatchMemcpyParams, + DeviceBatchMemcpyParams, + DeviceBatchMemcpyParams, + DeviceBatchMemcpyParams, + DeviceBatchMemcpyParams, + DeviceBatchMemcpyParams, + DeviceBatchMemcpyParams, // size_type: uint16_t - DeviceBatchMemcpyParams, + DeviceBatchMemcpyParams, // size_type: int64_t - DeviceBatchMemcpyParams, - DeviceBatchMemcpyParams, + DeviceBatchMemcpyParams, + DeviceBatchMemcpyParams, // weird amount of buffers - DeviceBatchMemcpyParams, - DeviceBatchMemcpyParams, - DeviceBatchMemcpyParams, + DeviceBatchMemcpyParams, + DeviceBatchMemcpyParams, + DeviceBatchMemcpyParams, // Shuffled inputs and outputs // Variable value_type - DeviceBatchMemcpyParams, - DeviceBatchMemcpyParams, - DeviceBatchMemcpyParams, + DeviceBatchMemcpyParams, + DeviceBatchMemcpyParams, + DeviceBatchMemcpyParams, + DeviceBatchMemcpyParams, + DeviceBatchMemcpyParams, + DeviceBatchMemcpyParams, // size_type: uint16_t - DeviceBatchMemcpyParams, + DeviceBatchMemcpyParams, // size_type: int64_t - DeviceBatchMemcpyParams, - DeviceBatchMemcpyParams> + DeviceBatchMemcpyParams, + DeviceBatchMemcpyParams> DeviceBatchMemcpyTestsParams; TYPED_TEST_SUITE(DeviceBatchMemcpyTests, DeviceBatchMemcpyTestsParams); @@ -145,6 +155,161 @@ std::vector shuffled_exclusive_scan(const std::vector& input, RandomGenera return result; } +template::type = 0> +void init_input(ContainerMemCpy& h_input_for_memcpy, + ContainerCopy& /*h_input_for_copy*/, + std::mt19937_64& rng, + byte_offset_type total_num_bytes) +{ + std::independent_bits_engine bits_engine{rng}; + + const size_t num_ints = rocprim::detail::ceiling_div(total_num_bytes, sizeof(uint64_t)); + h_input_for_memcpy = std::vector(num_ints * sizeof(uint64_t)); + + // generate_n for uninitialized memory, pragmatically use placement-new, since there are no + // uint64_t objects alive yet in the 
storage. + std::for_each( + reinterpret_cast(h_input_for_memcpy.data()), + reinterpret_cast(h_input_for_memcpy.data() + num_ints * sizeof(uint64_t)), + [&bits_engine](uint64_t& elem) { ::new(&elem) uint64_t{bits_engine()}; }); +} + +template::type = 0> +void init_input(ContainerMemCpy& /*h_input_for_memcpy*/, + ContainerCopy& h_input_for_copy, + std::mt19937_64& rng, + byte_offset_type total_num_bytes) +{ + using value_type = typename ContainerCopy::value_type; + + std::independent_bits_engine bits_engine{rng}; + + const size_t num_ints = rocprim::detail::ceiling_div(total_num_bytes, sizeof(uint64_t)); + const size_t num_of_elements + = rocprim::detail::ceiling_div(num_ints * sizeof(uint64_t), sizeof(value_type)); + h_input_for_copy = std::vector(num_of_elements); + + // generate_n for uninitialized memory, pragmatically use placement-new, since there are no + // uint64_t objects alive yet in the storage. + std::for_each(reinterpret_cast(h_input_for_copy.data()), + reinterpret_cast(h_input_for_copy.data()) + num_ints, + [&bits_engine](uint64_t& elem) { ::new(&elem) uint64_t{bits_engine()}; }); +} + +template::type = 0> +void batch_copy(void* temporary_storage, + size_t& storage_size, + InputBufferItType sources, + OutputBufferItType destinations, + BufferSizeItType sizes, + uint32_t num_copies, + hipStream_t stream) +{ + HIP_CHECK(rocprim::batch_memcpy(temporary_storage, + storage_size, + sources, + destinations, + sizes, + num_copies, + stream)); +} + +template::type = 0> +void batch_copy(void* temporary_storage, + size_t& storage_size, + InputBufferItType sources, + OutputBufferItType destinations, + BufferSizeItType sizes, + uint32_t num_copies, + hipStream_t stream) +{ + HIP_CHECK(rocprim::batch_copy(temporary_storage, + storage_size, + sources, + destinations, + sizes, + num_copies, + stream)); +} + +template::type = 0> +void check_result(ContainerMemCpy& h_input_for_memcpy, + ContainerCopy& /*h_input_for_copy*/, + ptr d_output, + byte_offset_type 
total_num_bytes, + byte_offset_type /*total_num_elements*/, + int32_t num_buffers, + OffsetContainer& src_offsets, + OffsetContainer& dst_offsets, + SizesContainer& h_buffer_num_bytes) +{ + using value_type = typename ContainerCopy::value_type; + std::vector h_output = std::vector(total_num_bytes); + HIP_CHECK(hipMemcpy(h_output.data(), d_output, total_num_bytes, hipMemcpyDeviceToHost)); + for(int32_t i = 0; i < num_buffers; ++i) + { + ASSERT_EQ(std::memcmp(h_input_for_memcpy.data() + src_offsets[i] * sizeof(value_type), + h_output.data() + dst_offsets[i] * sizeof(value_type), + h_buffer_num_bytes[i]), + 0) + << "with index = " << i; + } +} + +template::type = 0> +void check_result(ContainerMemCpy& /*h_input_for_memcpy*/, + ContainerCopy& h_input_for_copy, + ptr d_output, + byte_offset_type total_num_bytes, + byte_offset_type total_num_elements, + int32_t num_buffers, + OffsetContainer& src_offsets, + OffsetContainer& dst_offsets, + SizesContainer& h_buffer_num_bytes) +{ + using value_type = typename ContainerCopy::value_type; + std::vector h_output = std::vector(total_num_elements); + HIP_CHECK(hipMemcpy(h_output.data(), d_output, total_num_bytes, hipMemcpyDeviceToHost)); + for(int32_t i = 0; i < num_buffers; ++i) + { + ASSERT_EQ(std::memcmp(h_input_for_copy.data() + src_offsets[i], + h_output.data() + dst_offsets[i], + h_buffer_num_bytes[i]), + 0) + << "with index = " << i; + } +} + TYPED_TEST(DeviceBatchMemcpyTests, SizeAndTypeVariation) { using value_type = typename TestFixture::value_type; @@ -155,6 +320,7 @@ TYPED_TEST(DeviceBatchMemcpyTests, SizeAndTypeVariation) constexpr int32_t num_buffers = TestFixture::num_buffers; constexpr int32_t max_size = TestFixture::max_size; constexpr bool shuffled = TestFixture::shuffled; + constexpr bool isMemCpy = TestFixture::isMemCpy; constexpr int32_t wlev_min_size = rocprim::batch_memcpy_config<>::wlev_size_threshold; constexpr int32_t blev_min_size = rocprim::batch_memcpy_config<>::blev_size_threshold; @@ -198,13 +364,6 
@@ TYPED_TEST(DeviceBatchMemcpyTests, SizeAndTypeVariation) // Shuffle the sizes so that size classes aren't clustered std::shuffle(h_buffer_num_elements.begin(), h_buffer_num_elements.end(), rng); - // Get the byte size of each buffer - std::vector h_buffer_num_bytes(num_buffers); - for(size_t i = 0; i < num_buffers; ++i) - { - h_buffer_num_bytes[i] = h_buffer_num_elements[i] * sizeof(value_type); - } - // And the total byte size const byte_offset_type total_num_bytes = total_num_elements * sizeof(value_type); @@ -219,12 +378,13 @@ TYPED_TEST(DeviceBatchMemcpyTests, SizeAndTypeVariation) size_t temp_storage_bytes = 0; - HIP_CHECK(rocprim::batch_memcpy(nullptr, - temp_storage_bytes, - d_buffer_srcs, - d_buffer_dsts, - d_buffer_sizes, - num_buffers)); + batch_copy(nullptr, + temp_storage_bytes, + d_buffer_srcs, + d_buffer_dsts, + d_buffer_sizes, + num_buffers, + hipStreamDefault); void* d_temp_storage = nullptr; @@ -239,16 +399,9 @@ TYPED_TEST(DeviceBatchMemcpyTests, SizeAndTypeVariation) HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_bytes)); // Generate data. - std::independent_bits_engine bits_engine{rng}; - - const size_t num_ints = rocprim::detail::ceiling_div(total_num_bytes, sizeof(uint64_t)); - auto h_input = std::make_unique(num_ints * sizeof(uint64_t)); - - // generate_n for uninitialized memory, pragmatically use placement-new, since there are no - // uint64_t objects alive yet in the storage. - std::for_each(reinterpret_cast(h_input.get()), - reinterpret_cast(h_input.get() + num_ints * sizeof(uint64_t)), - [&bits_engine](uint64_t& elem) { ::new(&elem) uint64_t{bits_engine()}; }); + std::vector h_input_for_memcpy; + std::vector h_input_for_copy; + init_input(h_input_for_memcpy, h_input_for_copy, rng, total_num_bytes); // Generate the source and shuffled destination offsets. 
std::vector src_offsets; @@ -274,6 +427,13 @@ TYPED_TEST(DeviceBatchMemcpyTests, SizeAndTypeVariation) dst_offsets.begin() + 1); } + // Get the byte size of each buffer + std::vector h_buffer_num_bytes(num_buffers); + for(size_t i = 0; i < num_buffers; ++i) + { + h_buffer_num_bytes[i] = h_buffer_num_elements[i] * sizeof(value_type); + } + // Generate the source and destination pointers. std::vector h_buffer_srcs(num_buffers); std::vector h_buffer_dsts(num_buffers); @@ -285,7 +445,25 @@ TYPED_TEST(DeviceBatchMemcpyTests, SizeAndTypeVariation) } // Prepare the batch memcpy. - HIP_CHECK(hipMemcpy(d_input, h_input.get(), total_num_bytes, hipMemcpyHostToDevice)); + if(isMemCpy) + { + HIP_CHECK( + hipMemcpy(d_input, h_input_for_memcpy.data(), total_num_bytes, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_buffer_sizes, + h_buffer_num_bytes.data(), + h_buffer_num_bytes.size() * sizeof(*d_buffer_sizes), + hipMemcpyHostToDevice)); + } + else + { + HIP_CHECK( + hipMemcpy(d_input, h_input_for_copy.data(), total_num_bytes, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_buffer_sizes, + h_buffer_num_elements.data(), + h_buffer_num_elements.size() * sizeof(*d_buffer_sizes), + hipMemcpyHostToDevice)); + } + HIP_CHECK(hipMemcpy(d_buffer_srcs, h_buffer_srcs.data(), h_buffer_srcs.size() * sizeof(*d_buffer_srcs), @@ -294,31 +472,26 @@ TYPED_TEST(DeviceBatchMemcpyTests, SizeAndTypeVariation) h_buffer_dsts.data(), h_buffer_dsts.size() * sizeof(*d_buffer_dsts), hipMemcpyHostToDevice)); - HIP_CHECK(hipMemcpy(d_buffer_sizes, - h_buffer_num_bytes.data(), - h_buffer_num_bytes.size() * sizeof(*d_buffer_sizes), - hipMemcpyHostToDevice)); // Run batched memcpy. - HIP_CHECK(rocprim::batch_memcpy(d_temp_storage, - temp_storage_bytes, - d_buffer_srcs, - d_buffer_dsts, - d_buffer_sizes, - num_buffers, - hipStreamDefault)); - // Verify results. 
- auto h_output = std::make_unique(total_num_bytes); - HIP_CHECK(hipMemcpy(h_output.get(), d_output, total_num_bytes, hipMemcpyDeviceToHost)); + batch_copy(d_temp_storage, + temp_storage_bytes, + d_buffer_srcs, + d_buffer_dsts, + d_buffer_sizes, + num_buffers, + hipStreamDefault); - for(int32_t i = 0; i < num_buffers; ++i) - { - ASSERT_EQ(std::memcmp(h_input.get() + src_offsets[i] * sizeof(value_type), - h_output.get() + dst_offsets[i] * sizeof(value_type), - h_buffer_num_bytes[i]), - 0) - << "with index = " << i; - } + // Verify results. + check_result(h_input_for_memcpy, + h_input_for_copy, + d_output, + total_num_bytes, + total_num_elements, + num_buffers, + src_offsets, + dst_offsets, + h_buffer_num_bytes); HIP_CHECK(hipFree(d_temp_storage)); HIP_CHECK(hipFree(d_buffer_sizes)); diff --git a/test/rocprim/test_device_binary_search.cpp b/test/rocprim/test_device_binary_search.cpp index 332be50f8..dfe36a607 100644 --- a/test/rocprim/test_device_binary_search.cpp +++ b/test/rocprim/test_device_binary_search.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2019-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2019-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -157,11 +157,6 @@ TYPED_TEST(RocprimDeviceBinarySearch, LowerBound) haystack.begin(); } - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::params::use_graphs) - graph = test_utils::createGraphHelper(stream); - void * d_temporary_storage = nullptr; size_t temporary_storage_bytes; HIP_CHECK(rocprim::lower_bound(d_temporary_storage, @@ -175,16 +170,16 @@ TYPED_TEST(RocprimDeviceBinarySearch, LowerBound) stream, debug_synchronous)); - if (TestFixture::params::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - ASSERT_GT(temporary_storage_bytes, 0); HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - if (TestFixture::params::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(TestFixture::params::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + HIP_CHECK(rocprim::lower_bound(d_temporary_storage, temporary_storage_bytes, d_haystack, @@ -196,9 +191,12 @@ TYPED_TEST(RocprimDeviceBinarySearch, LowerBound) stream, debug_synchronous)); - if (TestFixture::params::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::params::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - + } + std::vector output(needles_size); HIP_CHECK( hipMemcpy( @@ -213,8 +211,10 @@ TYPED_TEST(RocprimDeviceBinarySearch, LowerBound) HIP_CHECK(hipFree(d_needles)); HIP_CHECK(hipFree(d_output)); - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected)); } @@ -303,11 +303,6 @@ TYPED_TEST(RocprimDeviceBinarySearch, UpperBound) haystack.begin(); } - hipGraph_t graph; - 
hipGraphExec_t graph_instance; - if (TestFixture::params::use_graphs) - graph = test_utils::createGraphHelper(stream); - void * d_temporary_storage = nullptr; size_t temporary_storage_bytes; HIP_CHECK(rocprim::upper_bound(d_temporary_storage, @@ -320,16 +315,17 @@ TYPED_TEST(RocprimDeviceBinarySearch, UpperBound) compare_op, stream, debug_synchronous)); - if (TestFixture::params::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - + ASSERT_GT(temporary_storage_bytes, 0); HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - if (TestFixture::params::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(TestFixture::params::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + HIP_CHECK(rocprim::upper_bound(d_temporary_storage, temporary_storage_bytes, d_haystack, @@ -341,9 +337,12 @@ TYPED_TEST(RocprimDeviceBinarySearch, UpperBound) stream, debug_synchronous)); - if (TestFixture::params::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::params::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - + } + std::vector output(needles_size); HIP_CHECK( hipMemcpy( @@ -358,15 +357,19 @@ TYPED_TEST(RocprimDeviceBinarySearch, UpperBound) HIP_CHECK(hipFree(d_needles)); HIP_CHECK(hipFree(d_output)); - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected)); } } - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } } TYPED_TEST(RocprimDeviceBinarySearch, BinarySearch) @@ -447,11 +450,6 @@ TYPED_TEST(RocprimDeviceBinarySearch, BinarySearch) expected[i] = std::binary_search(haystack.begin(), haystack.end(), needles[i], compare_op); } - hipGraph_t graph; - 
hipGraphExec_t graph_instance; - if (TestFixture::params::use_graphs) - graph = test_utils::createGraphHelper(stream); - void * d_temporary_storage = nullptr; size_t temporary_storage_bytes; HIP_CHECK(rocprim::binary_search(d_temporary_storage, @@ -465,16 +463,16 @@ TYPED_TEST(RocprimDeviceBinarySearch, BinarySearch) stream, debug_synchronous)); - if (TestFixture::params::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - ASSERT_GT(temporary_storage_bytes, 0); HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - if (TestFixture::params::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(TestFixture::params::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + HIP_CHECK(rocprim::binary_search(d_temporary_storage, temporary_storage_bytes, d_haystack, @@ -486,9 +484,12 @@ TYPED_TEST(RocprimDeviceBinarySearch, BinarySearch) stream, debug_synchronous)); - if (TestFixture::params::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::params::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - + } + std::vector output(needles_size); HIP_CHECK( hipMemcpy( @@ -503,13 +504,17 @@ TYPED_TEST(RocprimDeviceBinarySearch, BinarySearch) HIP_CHECK(hipFree(d_needles)); HIP_CHECK(hipFree(d_output)); - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected)); } } - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } } diff --git a/test/rocprim/test_device_histogram.cpp b/test/rocprim/test_device_histogram.cpp index 1c9f38131..ff47dadf3 100644 --- a/test/rocprim/test_device_histogram.cpp +++ b/test/rocprim/test_device_histogram.cpp @@ -1,6 +1,6 @@ // MIT License // -// 
Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -60,8 +60,8 @@ std::vector> get_dims() // Generate values ouside the desired histogram range (+-10%) // (correctly handling test cases like uchar [0, 256), ushort [0, 65536)) template -inline auto get_random_samples(size_t size, U min, U max, int seed_value) - -> typename std::enable_if::value, std::vector>::type +inline auto get_random_samples(size_t size, U min, U max, int seed_value) -> + typename std::enable_if::value, std::vector>::type { const long long min1 = static_cast(min); const long long max1 = static_cast(max); @@ -75,8 +75,8 @@ inline auto get_random_samples(size_t size, U min, U max, int seed_value) } template -inline auto get_random_samples(size_t size, U min, U max, int seed_value) - -> typename std::enable_if::value, std::vector>::type +inline auto get_random_samples(size_t size, U min, U max, int seed_value) -> + typename std::enable_if::value, std::vector>::type { const double min1 = static_cast(min); const double max1 = static_cast(max); @@ -129,6 +129,10 @@ class RocprimDeviceHistogramEven : public ::testing::Test { using custom_config1 = rocprim::histogram_config>; typedef ::testing::Types, + params1, + //params1, + params1, + params1, params1, params1, params1, @@ -143,7 +147,6 @@ typedef ::testing::Types, TYPED_TEST_SUITE(RocprimDeviceHistogramEven, Params1); -template void testHistogramEvenIncorrectInput() { int device_id = test_common_utils::obtain_device_from_ctest(); @@ -151,39 +154,23 @@ void testHistogramEvenIncorrectInput() HIP_CHECK(hipSetDevice(device_id)); hipStream_t stream = 0; - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (UseGraphs) - { - // Default stream does not support hipGraph stream capture, so 
create one - HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); - graph = test_utils::createGraphHelper(stream); - } - + size_t temporary_storage_bytes = 0; int * d_input = nullptr; int * d_histogram = nullptr; - hipError_t result = rocprim::histogram_even( - nullptr, temporary_storage_bytes, - d_input, 123, + // This check happens on host so there is nothing to capture for hipGraph. + hipError_t result = rocprim::histogram_even(nullptr, + temporary_storage_bytes, + d_input, + 123, d_histogram, - 1, 1, 2, stream - ); - - if (UseGraphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - - ASSERT_EQ( - result, - hipErrorInvalidValue - ); + 1, + 1, + 2, + stream); - if (UseGraphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - HIP_CHECK(hipStreamDestroy(stream)); - } + ASSERT_EQ(result, hipErrorInvalidValue); } TEST(RocprimDeviceHistogramEven, IncorrectInput) @@ -191,11 +178,6 @@ TEST(RocprimDeviceHistogramEven, IncorrectInput) testHistogramEvenIncorrectInput(); } -TEST(RocprimDeviceHistogramEven, IncorrectInputWithGraphs) -{ - testHistogramEvenIncorrectInput(); -} - TYPED_TEST(RocprimDeviceHistogramEven, Even) { int device_id = test_common_utils::obtain_device_from_ctest(); @@ -206,8 +188,8 @@ TYPED_TEST(RocprimDeviceHistogramEven, Even) using counter_type = typename TestFixture::params::counter_type; using level_type = typename TestFixture::params::level_type; constexpr unsigned int bins = TestFixture::params::bins; - constexpr level_type lower_level = TestFixture::params::lower_level; - constexpr level_type upper_level = TestFixture::params::upper_level; + const level_type lower_level = static_cast(TestFixture::params::lower_level); + const level_type upper_level = static_cast(TestFixture::params::upper_level); hipStream_t stream = 0; if (TestFixture::params::use_graphs) @@ -255,7 +237,7 @@ TYPED_TEST(RocprimDeviceHistogramEven, Even) // Calculate expected results on host std::vector 
histogram_expected(bins, 0); - const level_type scale = (upper_level - lower_level) / bins; + const level_type scale = static_cast((upper_level - lower_level) / bins); for(size_t row = 0; row < rows; row++) { for(size_t column = 0; column < columns; column++) @@ -272,11 +254,6 @@ TYPED_TEST(RocprimDeviceHistogramEven, Even) using config = typename TestFixture::params::config; - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::params::use_graphs) - graph = test_utils::createGraphHelper(stream); - size_t temporary_storage_bytes = 0; if(rows == 1) { @@ -303,17 +280,17 @@ TYPED_TEST(RocprimDeviceHistogramEven, Even) ); } - if (TestFixture::params::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - ASSERT_GT(temporary_storage_bytes, 0U); void * d_temporary_storage; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - if (TestFixture::params::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(TestFixture::params::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + if(rows == 1) { HIP_CHECK( @@ -339,9 +316,12 @@ TYPED_TEST(RocprimDeviceHistogramEven, Even) ); } - if (TestFixture::params::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::params::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - + } + std::vector histogram(bins); HIP_CHECK( hipMemcpy( @@ -360,13 +340,17 @@ TYPED_TEST(RocprimDeviceHistogramEven, Even) ASSERT_EQ(histogram[i], histogram_expected[i]); } - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } } } - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } } template void testHistogramRangeIncorrectInput() { int device_id = test_common_utils::obtain_device_from_ctest(); 
@@ -421,20 +404,13 @@ void testHistogramRangeIncorrectInput() HIP_CHECK(hipSetDevice(device_id)); hipStream_t stream = 0; - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (UseGraphs) - { - // Default stream does not support hipGraph stream capture, so create one - HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); - graph = test_utils::createGraphHelper(stream); - } - + size_t temporary_storage_bytes = 0; int * d_input = nullptr; int * d_histogram = nullptr; int * d_levels = nullptr; + // This check happens on host so there is nothing to capture for hipGraph. hipError_t result = rocprim::histogram_range( nullptr, temporary_storage_bytes, d_input, 123, @@ -442,19 +418,7 @@ void testHistogramRangeIncorrectInput() 1, d_levels, stream ); - if (UseGraphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - - ASSERT_EQ( - result, - hipErrorInvalidValue - ); - - if (UseGraphs) - { - test_utils::cleanupGraphHelper(graph, graph_instance); - HIP_CHECK(hipStreamDestroy(stream)); - } + ASSERT_EQ(result, hipErrorInvalidValue); } TEST(RocprimDeviceHistogramRange, RangeIncorrectInput) @@ -462,12 +426,6 @@ TEST(RocprimDeviceHistogramRange, RangeIncorrectInput) testHistogramRangeIncorrectInput(); } -TEST(RocprimDeviceHistogramRange, RangeIncorrectInputWithGraphs) -{ - testHistogramRangeIncorrectInput(); -} - - TYPED_TEST(RocprimDeviceHistogramRange, Range) { int device_id = test_common_utils::obtain_device_from_ctest(); @@ -572,11 +530,6 @@ TYPED_TEST(RocprimDeviceHistogramRange, Range) using config = typename TestFixture::params::config; - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::params::use_graphs) - graph = test_utils::createGraphHelper(stream); - size_t temporary_storage_bytes = 0; if(rows == 1) { @@ -605,17 +558,17 @@ TYPED_TEST(RocprimDeviceHistogramRange, Range) debug_synchronous)); } - if (TestFixture::params::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, 
true, true); - ASSERT_GT(temporary_storage_bytes, 0U); void * d_temporary_storage; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - if (TestFixture::params::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(TestFixture::params::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + if(rows == 1) { HIP_CHECK(rocprim::histogram_range(d_temporary_storage, @@ -643,8 +596,11 @@ TYPED_TEST(RocprimDeviceHistogramRange, Range) debug_synchronous)); } - if (TestFixture::params::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::params::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + } std::vector histogram(bins); HIP_CHECK( @@ -665,8 +621,10 @@ TYPED_TEST(RocprimDeviceHistogramRange, Range) ASSERT_EQ(histogram[i], histogram_expected[i]); } - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } } } @@ -856,11 +814,6 @@ TYPED_TEST(RocprimDeviceHistogramMultiEven, MultiEven) using config = typename TestFixture::params::config; - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::params::use_graphs) - graph = test_utils::createGraphHelper(stream); - size_t temporary_storage_bytes = 0; if(rows == 1) { @@ -893,16 +846,16 @@ TYPED_TEST(RocprimDeviceHistogramMultiEven, MultiEven) debug_synchronous))); } - if (TestFixture::params::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - ASSERT_GT(temporary_storage_bytes, 0U); void * d_temporary_storage; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - if (TestFixture::params::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); + hipGraph_t graph; + if(TestFixture::params::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } if(rows == 1) 
{ @@ -935,9 +888,12 @@ TYPED_TEST(RocprimDeviceHistogramMultiEven, MultiEven) debug_synchronous))); } - if (TestFixture::params::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::params::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - + } + std::vector histogram[active_channels]; for(unsigned int channel = 0; channel < active_channels; channel++) { @@ -965,13 +921,17 @@ TYPED_TEST(RocprimDeviceHistogramMultiEven, MultiEven) } } - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } } } - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } } template histogram[active_channels]; for(unsigned int channel = 0; channel < active_channels; channel++) { @@ -1264,8 +1222,10 @@ TYPED_TEST(RocprimDeviceHistogramMultiRange, MultiRange) HIP_CHECK(hipFree(d_temporary_storage)); HIP_CHECK(hipFree(d_input)); - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } for(unsigned int channel = 0; channel < active_channels; channel++) { @@ -1279,6 +1239,8 @@ TYPED_TEST(RocprimDeviceHistogramMultiRange, MultiRange) } } - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } } diff --git a/test/rocprim/test_device_merge.cpp b/test/rocprim/test_device_merge.cpp index 2e608d500..1369d0b98 100644 --- a/test/rocprim/test_device_merge.cpp +++ b/test/rocprim/test_device_merge.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -72,6 +72,7 @@ typedef ::testing::Types< DeviceMergeParams>, DeviceMergeParams, DeviceMergeParams, + DeviceMergeParams, DeviceMergeParams, DeviceMergeParams, DeviceMergeParams>, @@ -188,40 +189,35 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKey) test_utils::bounds_checking_iterator d_keys_checking_output( d_keys_output, out_of_bounds.device_pointer(), - size1 + size2 - ); + size1 + size2); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; // Get size of d_temp_storage - HIP_CHECK( - rocprim::merge( - d_temp_storage, temp_storage_size_bytes, - d_keys_input1, d_keys_input2, - d_keys_checking_output, - keys_input1.size(), keys_input2.size(), - compare_op, stream, debug_synchronous - ) - ); + HIP_CHECK(rocprim::merge(d_temp_storage, + temp_storage_size_bytes, + d_keys_input1, + d_keys_input2, + d_keys_checking_output, + keys_input1.size(), + keys_input2.size(), + compare_op, + stream, + debug_synchronous)); - if (TestFixture::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); // allocate temporary storage HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + // Run HIP_CHECK( rocprim::merge( @@ -233,8 +229,11 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKey) ) ); - if (TestFixture::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { graph_instance = 
test_utils::endCaptureGraphHelper(graph, stream, true, true); + } HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -395,14 +394,8 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKeyValue) test_utils::bounds_checking_iterator d_values_checking_output( d_values_output, out_of_bounds.device_pointer(), - size1 + size2 - ); + size1 + size2); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; @@ -419,18 +412,18 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKeyValue) ) ); - if (TestFixture::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); // allocate temporary storage HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + // Run HIP_CHECK( rocprim::merge( @@ -444,9 +437,12 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKeyValue) ) ); - if (TestFixture::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); - + } + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -535,13 +531,10 @@ void testMergeMismatchedIteratorTypes() static constexpr bool debug_synchronous = false; hipStream_t stream = 0; // default - hipGraph_t graph; - hipGraphExec_t graph_instance; if (UseGraphs) { // Default stream does not support hipGraph stream capture, so create one HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); - graph = test_utils::createGraphHelper(stream); } size_t temp_storage_size_bytes = 0; @@ -556,16 +549,16 @@ void 
testMergeMismatchedIteratorTypes() stream, debug_synchronous)); - if (UseGraphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - ASSERT_GT(temp_storage_size_bytes, 0); void* d_temp_storage = nullptr; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); - if (UseGraphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); + hipGraph_t graph; + if(UseGraphs) + { + graph = test_utils::createGraphHelper(stream); + } HIP_CHECK(rocprim::merge(d_temp_storage, temp_storage_size_bytes, @@ -578,8 +571,11 @@ void testMergeMismatchedIteratorTypes() hipStreamDefault, debug_synchronous)); - if (UseGraphs) + hipGraphExec_t graph_instance; + if(UseGraphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + } std::vector keys_output(expected_keys_output.size()); HIP_CHECK(hipMemcpy(keys_output.data(), diff --git a/test/rocprim/test_device_merge_sort.cpp b/test/rocprim/test_device_merge_sort.cpp index 75e5b7932..4a1288230 100644 --- a/test/rocprim/test_device_merge_sort.cpp +++ b/test/rocprim/test_device_merge_sort.cpp @@ -1,6 +1,6 @@ /// MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -144,17 +144,8 @@ TYPED_TEST(RocprimDeviceSortTests, SortKey) // Calculate expected results on host std::vector expected(input); - std::stable_sort( - expected.begin(), - expected.end(), - compare_op - ); + std::stable_sort(expected.begin(), expected.end(), compare_op); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; @@ -167,9 +158,6 @@ TYPED_TEST(RocprimDeviceSortTests, SortKey) ) ); - if (TestFixture::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -177,9 +165,12 @@ TYPED_TEST(RocprimDeviceSortTests, SortKey) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + // Run HIP_CHECK( rocprim::merge_sort( @@ -189,9 +180,12 @@ TYPED_TEST(RocprimDeviceSortTests, SortKey) ) ); - if (TestFixture::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - + } + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -314,33 +308,26 @@ TYPED_TEST(RocprimDeviceSortTests, SortKeyValue) { expected[i] = key_value(keys_input[i], values_input[i]); } - std::stable_sort( - expected.begin(), - expected.end(), - [compare_op](const key_value& a, const key_value& b) { return compare_op(a.first, b.first); } - ); + std::stable_sort(expected.begin(), + 
expected.end(), + [compare_op](const key_value& a, const key_value& b) + { return compare_op(a.first, b.first); }); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; // Get size of d_temp_storage - HIP_CHECK( - rocprim::merge_sort( - d_temp_storage, temp_storage_size_bytes, - d_keys_input, d_keys_output, - d_values_input, d_values_output, keys_input.size(), - compare_op, stream, debug_synchronous - ) - ); + HIP_CHECK(rocprim::merge_sort(d_temp_storage, + temp_storage_size_bytes, + d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + keys_input.size(), + compare_op, + stream, + debug_synchronous)); - if (TestFixture::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -348,9 +335,12 @@ TYPED_TEST(RocprimDeviceSortTests, SortKeyValue) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + // Run HIP_CHECK( rocprim::merge_sort( @@ -361,9 +351,12 @@ TYPED_TEST(RocprimDeviceSortTests, SortKeyValue) ) ); - if (TestFixture::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); - + } + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); diff --git a/test/rocprim/test_device_partition.cpp b/test/rocprim/test_device_partition.cpp index 44a2c92bc..77f2d5629 100644 --- a/test/rocprim/test_device_partition.cpp +++ b/test/rocprim/test_device_partition.cpp @@ -1,6 +1,6 @@ // MIT License // -// 
Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -60,17 +60,18 @@ class RocprimDevicePartitionTests : public ::testing::Test static constexpr bool use_graphs = Params::use_graphs; }; -typedef ::testing::Types< - DevicePartitionParams, - DevicePartitionParams, - DevicePartitionParams, - DevicePartitionParams, - DevicePartitionParams, - DevicePartitionParams, - DevicePartitionParams, - DevicePartitionParams>, - DevicePartitionParams -> RocprimDevicePartitionTestsParams; +typedef ::testing::Types, + DevicePartitionParams, + DevicePartitionParams, + DevicePartitionParams, + DevicePartitionParams, + DevicePartitionParams, + DevicePartitionParams, + DevicePartitionParams, + DevicePartitionParams, + DevicePartitionParams>, + DevicePartitionParams> + RocprimDevicePartitionTestsParams; TYPED_TEST_SUITE(RocprimDevicePartitionTests, RocprimDevicePartitionTestsParams); @@ -137,11 +138,6 @@ TYPED_TEST(RocprimDevicePartitionTests, Flagged) } std::reverse(expected_rejected.begin(), expected_rejected.end()); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes; // Get size of d_temp_storage @@ -156,9 +152,6 @@ TYPED_TEST(RocprimDevicePartitionTests, Flagged) stream, debug_synchronous)); - if (TestFixture::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -166,9 +159,12 @@ TYPED_TEST(RocprimDevicePartitionTests, Flagged) void* d_temp_storage = nullptr; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); - if 
(TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + // Run HIP_CHECK(rocprim::partition( d_temp_storage, @@ -181,9 +177,12 @@ TYPED_TEST(RocprimDevicePartitionTests, Flagged) stream, debug_synchronous)); - if (TestFixture::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - + } + // Check if number of selected value is as expected_selected unsigned int selected_count_output = 0; HIP_CHECK(hipMemcpy(&selected_count_output, @@ -213,14 +212,18 @@ TYPED_TEST(RocprimDevicePartitionTests, Flagged) hipFree(d_output); hipFree(d_selected_count_output); hipFree(d_temp_storage); - - if (TestFixture::use_graphs) + + if(TestFixture::use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } } } - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } } TYPED_TEST(RocprimDevicePartitionTests, PredicateEmptyInput) @@ -257,17 +260,10 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateEmptyInput) hipMemcpyHostToDevice)); test_utils::out_of_bounds_flag out_of_bounds; - test_utils::bounds_checking_iterator d_checking_output( - d_output, - out_of_bounds.device_pointer(), - 0 - ); + test_utils::bounds_checking_iterator d_checking_output(d_output, + out_of_bounds.device_pointer(), + 0); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes; // Get size of d_temp_storage @@ -281,16 +277,16 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateEmptyInput) stream, debug_synchronous)); - if (TestFixture::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - // allocate temporary storage void* d_temp_storage = 
nullptr; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + // Run HIP_CHECK(rocprim::partition(d_temp_storage, temp_storage_size_bytes, @@ -302,8 +298,11 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateEmptyInput) stream, debug_synchronous)); - if (TestFixture::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + } HIP_CHECK(hipDeviceSynchronize()); ASSERT_FALSE(out_of_bounds.get()); @@ -392,11 +391,6 @@ TYPED_TEST(RocprimDevicePartitionTests, Predicate) } std::reverse(expected_rejected.begin(), expected_rejected.end()); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes; // Get size of d_temp_storage @@ -411,9 +405,6 @@ TYPED_TEST(RocprimDevicePartitionTests, Predicate) stream, debug_synchronous)); - if (TestFixture::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -421,9 +412,12 @@ TYPED_TEST(RocprimDevicePartitionTests, Predicate) void* d_temp_storage = nullptr; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + // Run HIP_CHECK(rocprim::partition( d_temp_storage, @@ -436,8 +430,11 @@ TYPED_TEST(RocprimDevicePartitionTests, Predicate) stream, debug_synchronous)); - if (TestFixture::use_graphs) + hipGraphExec_t 
graph_instance; + if(TestFixture::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + } HIP_CHECK(hipDeviceSynchronize()); @@ -470,13 +467,17 @@ TYPED_TEST(RocprimDevicePartitionTests, Predicate) hipFree(d_selected_count_output); hipFree(d_temp_storage); - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } } } - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } } TYPED_TEST(RocprimDevicePartitionTests, PredicateTwoWay) @@ -548,11 +549,6 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateTwoWay) } } - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes; // Get size of d_temp_storage @@ -567,9 +563,6 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateTwoWay) stream, debug_synchronous)); - if (TestFixture::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -577,9 +570,12 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateTwoWay) void* d_temp_storage = nullptr; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + // Run HIP_CHECK(rocprim::partition_two_way( d_temp_storage, @@ -593,8 +589,11 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateTwoWay) stream, debug_synchronous)); - if (TestFixture::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + } HIP_CHECK(hipDeviceSynchronize()); @@ -629,13 +628,17 @@ 
TYPED_TEST(RocprimDevicePartitionTests, PredicateTwoWay) hipFree(d_selected_count_output); hipFree(d_temp_storage); - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } } } - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } } namespace { @@ -728,17 +731,13 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateThreeWay) static_cast(second_partiton_point - partion_point) }; - const auto expected = [&]{ + const auto expected = [&] + { auto result = std::vector(copy.size()); std::copy(copy.cbegin(), copy.cend(), result.begin()); return result; }(); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes; // Get size of d_temp_storage @@ -757,9 +756,6 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateThreeWay) stream, debug_synchronous)); - if (TestFixture::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -767,9 +763,12 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateThreeWay) void* d_temp_storage = nullptr; HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size_bytes)); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + // Run HIP_CHECK(rocprim::partition_three_way( d_temp_storage, @@ -786,8 +785,11 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateThreeWay) stream, debug_synchronous)); - if (TestFixture::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + } HIP_CHECK(hipDeviceSynchronize()); @@ -838,14 +840,18 @@ 
TYPED_TEST(RocprimDevicePartitionTests, PredicateThreeWay) hipFree(d_selected_counts); hipFree(d_temp_storage); - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } } } } - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } } namespace @@ -1130,11 +1136,6 @@ TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartition) size_t* d_count_output{}; HIP_CHECK(test_common_utils::hipMallocHelper(&d_count_output, sizeof(*d_count_output))); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (use_graphs) - graph = test_utils::createGraphHelper(stream); - void* d_temporary_storage{}; size_t temporary_storage_size{}; HIP_CHECK(rocprim::partition(d_temporary_storage, @@ -1147,15 +1148,15 @@ TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartition) stream, debug_synchronous)); - if (use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - ASSERT_NE(0, temporary_storage_size); HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_size)); - if (use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + HIP_CHECK(rocprim::partition(d_temporary_storage, temporary_storage_size, input_iterator, @@ -1166,9 +1167,12 @@ TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartition) stream, debug_synchronous)); - if (use_graphs) + hipGraphExec_t graph_instance; + if(use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - + } + size_t count_output{}; HIP_CHECK(hipMemcpyWithStream(&count_output, d_count_output, @@ -1192,12 +1196,16 @@ TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartition) HIP_CHECK(hipFree(d_count_output)); HIP_CHECK(hipFree(d_incorrect_flag)); - if (use_graphs) + if(use_graphs) + { 
test_utils::cleanupGraphHelper(graph, graph_instance); + } } - if (use_graphs) + if(use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } } TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartitionTwoWay) @@ -1248,11 +1256,6 @@ TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartitionTwoWay) size_t* d_count_output{}; HIP_CHECK(test_common_utils::hipMallocHelper(&d_count_output, sizeof(*d_count_output))); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (use_graphs) - graph = test_utils::createGraphHelper(stream); - void* d_temporary_storage{}; size_t temporary_storage_size{}; HIP_CHECK(rocprim::partition_two_way(d_temporary_storage, @@ -1266,15 +1269,15 @@ TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartitionTwoWay) stream, debug_synchronous)); - if (use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - ASSERT_NE(0, temporary_storage_size); HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_size)); - if (use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + HIP_CHECK(rocprim::partition_two_way(d_temporary_storage, temporary_storage_size, input_iterator, @@ -1286,9 +1289,12 @@ TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartitionTwoWay) stream, debug_synchronous)); - if (use_graphs) + hipGraphExec_t graph_instance; + if(use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - + } + size_t count_output{}; HIP_CHECK(hipMemcpyWithStream(&count_output, d_count_output, @@ -1320,12 +1326,16 @@ TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartitionTwoWay) HIP_CHECK(hipFree(d_incorrect_select_flag)); HIP_CHECK(hipFree(d_incorrect_reject_flag)); - if (use_graphs) + if(use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } } - if (use_graphs) + if(use_graphs) + { 
HIP_CHECK(hipStreamDestroy(stream)); + } } TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartitionThreeWay) @@ -1373,11 +1383,6 @@ TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartitionThreeWay) size_t* d_count_output{}; HIP_CHECK(test_common_utils::hipMallocHelper(&d_count_output, 2 * sizeof(*d_count_output))); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (use_graphs) - graph = test_utils::createGraphHelper(stream); - void* d_temporary_storage{}; size_t temporary_storage_size{}; HIP_CHECK(rocprim::partition_three_way(d_temporary_storage, @@ -1393,15 +1398,15 @@ TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartitionThreeWay) stream, debug_synchronous)); - if (use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - ASSERT_NE(0, temporary_storage_size); HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_size)); - if (use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + HIP_CHECK(rocprim::partition_three_way(d_temporary_storage, temporary_storage_size, input_iterator, @@ -1415,9 +1420,12 @@ TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartitionThreeWay) stream, debug_synchronous)); - if (use_graphs) + hipGraphExec_t graph_instance; + if(use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - + } + size_t count_output[2]{}; HIP_CHECK(hipMemcpyWithStream(&count_output, d_count_output, @@ -1445,10 +1453,14 @@ TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartitionThreeWay) HIP_CHECK(hipFree(d_count_output)); HIP_CHECK(hipFree(d_incorrect_flag)); - if (use_graphs) + if(use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } } - if (use_graphs) + if(use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } } diff --git 
a/test/rocprim/test_device_radix_sort.cpp.in b/test/rocprim/test_device_radix_sort.cpp.in index 65d34f6a6..c70343d03 100644 --- a/test/rocprim/test_device_radix_sort.cpp.in +++ b/test/rocprim/test_device_radix_sort.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -54,6 +54,8 @@ INSTANTIATE(params<__int128_t, __int128_t>) INSTANTIATE(params<__uint128_t, __uint128_t>) #endif + INSTANTIATE(params) + INSTANTIATE(params) INSTANTIATE(params) INSTANTIATE(params) INSTANTIATE(params) @@ -104,4 +106,15 @@ // test with graphs INSTANTIATE(params) +#elif ROCPRIM_TEST_TYPE_SLICE == 2 + // custom types using a custom decomposer (ascending + descending) + INSTANTIATE(params, int>) + INSTANTIATE(params, int, true>) + INSTANTIATE(params, float>) + INSTANTIATE(params, int8_t, true>) + + // start_bit and end_bit + INSTANTIATE(params, int, false, 7, 55>) + INSTANTIATE(params, int, true, 0, 32>) + INSTANTIATE(params, float, false, 64, 99>) #endif diff --git a/test/rocprim/test_device_radix_sort.hpp b/test/rocprim/test_device_radix_sort.hpp index 668c16532..a6a6ac88a 100644 --- a/test/rocprim/test_device_radix_sort.hpp +++ b/test/rocprim/test_device_radix_sort.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -29,10 +29,15 @@ #include // required test headers +#include "test_seed.hpp" #include "test_utils_custom_float_type.hpp" +#include "test_utils_custom_test_types.hpp" #include "test_utils_sort_comparator.hpp" #include "test_utils_types.hpp" +#include +#include + template +auto generate_key_input(KeyIter keys_input, size_t size, engine_type& rng_engine) + -> std::enable_if_t< + rocprim::is_floating_point::value_type>::value> +{ + using key_type = typename std::iterator_traits::value_type; + test_utils::generate_random_data_n(keys_input, + size, + static_cast(-1000), + static_cast(+1000), + rng_engine); + test_utils::add_special_values(keys_input, size, rng_engine); +} + +template +auto generate_key_input(KeyIter keys_input, size_t size, engine_type& rng_engine) + -> std::enable_if_t< + !rocprim::is_floating_point::value_type>::value> +{ + using key_type = typename std::iterator_traits::value_type; + test_utils::generate_random_data_n(keys_input, + size, + std::numeric_limits::min(), + std::numeric_limits::max(), + rng_engine); +} + +// Working around custom_float_test_type, which is both a float and a custom_test_type +template +constexpr bool is_custom_not_float_test_type + = test_utils::is_custom_test_type::value && !rocprim::is_floating_point::value; + +template +auto invoke_sort_keys(void* d_temporary_storage, + size_t& temporary_storage_bytes, + Key* d_keys_input, + Key* d_keys_output, + size_t size, + unsigned int start_bit, + unsigned int end_bit, + hipStream_t stream, + bool debug_synchronous) + -> std::enable_if_t, hipError_t> +{ + return rocprim::radix_sort_keys(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + size, + start_bit, + end_bit, + stream, + debug_synchronous); +} + +template +auto invoke_sort_keys(void* d_temporary_storage, + size_t& 
temporary_storage_bytes, + Key* d_keys_input, + Key* d_keys_output, + size_t size, + unsigned int start_bit, + unsigned int end_bit, + hipStream_t stream, + bool debug_synchronous) + -> std::enable_if_t, hipError_t> +{ + return rocprim::radix_sort_keys_desc(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + size, + start_bit, + end_bit, + stream, + debug_synchronous); +} + +template +auto invoke_sort_keys(void* d_temporary_storage, + size_t& temporary_storage_bytes, + Key* d_keys_input, + Key* d_keys_output, + size_t size, + unsigned int start_bit, + unsigned int end_bit, + hipStream_t stream, + bool debug_synchronous) + -> std::enable_if_t, hipError_t> +{ + using decomposer_t = test_utils::custom_test_type_decomposer; + if(start_bit == 0 && end_bit == rocprim::detail::decomposer_max_bits::value) + { + return rocprim::radix_sort_keys(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + size, + decomposer_t{}, + stream, + debug_synchronous); + } + else + { + return rocprim::radix_sort_keys(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + size, + decomposer_t{}, + start_bit, + end_bit, + stream, + debug_synchronous); + } +} + +template +auto invoke_sort_keys(void* d_temporary_storage, + size_t& temporary_storage_bytes, + Key* d_keys_input, + Key* d_keys_output, + size_t size, + unsigned int start_bit, + unsigned int end_bit, + hipStream_t stream, + bool debug_synchronous) + -> std::enable_if_t, hipError_t> +{ + using decomposer_t = test_utils::custom_test_type_decomposer; + if(start_bit == 0 && end_bit == rocprim::detail::decomposer_max_bits::value) + { + return rocprim::radix_sort_keys_desc(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + size, + decomposer_t{}, + stream, + debug_synchronous); + } + else + { + return rocprim::radix_sort_keys_desc(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + size, + 
decomposer_t{}, + start_bit, + end_bit, + stream, + debug_synchronous); + } +} + template -inline void sort_keys() +void sort_keys() { int device_id = test_common_utils::obtain_device_from_ctest(); SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); @@ -96,30 +257,18 @@ inline void sort_keys() for(size_t size : sizes) { if(size > (1 << 17) && !check_large_sizes) + { break; + } SCOPED_TRACE(testing::Message() << "with size = " << size); + engine_type rng_engine(seed_value); in_place = !in_place; // Generate data - std::vector keys_input; - if(rocprim::is_floating_point::value) - { - keys_input = test_utils::get_random_data(size, - static_cast(-1000), - static_cast(+1000), - seed_value); - test_utils::add_special_values(keys_input, seed_value); - } - else - { - keys_input - = test_utils::get_random_data(size, - std::numeric_limits::min(), - std::numeric_limits::max(), - seed_value); - } + auto keys_input = std::make_unique(size); + generate_key_input(keys_input.get(), size, rng_engine); key_type* d_keys_input; key_type* d_keys_output; @@ -134,12 +283,12 @@ inline void sort_keys() test_common_utils::hipMallocHelper(&d_keys_output, size * sizeof(key_type))); } HIP_CHECK(hipMemcpy(d_keys_input, - keys_input.data(), + keys_input.get(), size * sizeof(key_type), hipMemcpyHostToDevice)); // Calculate expected results on host - std::vector expected(keys_input); + std::vector expected(keys_input.get(), keys_input.get() + size); std::stable_sort( expected.begin(), expected.end(), @@ -151,62 +300,46 @@ inline void sort_keys() rocprim::default_config, 1024 * 512>; - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::params::use_graphs) - graph = test_utils::createGraphHelper(stream); - size_t temporary_storage_bytes; - HIP_CHECK(rocprim::radix_sort_keys(nullptr, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - size, - start_bit, - end_bit)); - - if (TestFixture::params::use_graphs) - graph_instance = 
test_utils::endCaptureGraphHelper(graph, stream, true, true); - + HIP_CHECK((invoke_sort_keys(nullptr, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + size, + start_bit, + end_bit, + stream, + debug_synchronous))); ASSERT_GT(temporary_storage_bytes, 0); void* d_temporary_storage; HIP_CHECK( test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - if (TestFixture::params::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - - if(descending) + hipGraph_t graph; + if(TestFixture::params::use_graphs) { - HIP_CHECK(rocprim::radix_sort_keys_desc(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - size, - start_bit, - end_bit, - stream, - debug_synchronous)); + graph = test_utils::createGraphHelper(stream); } - else + + HIP_CHECK((invoke_sort_keys(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + size, + start_bit, + end_bit, + stream, + debug_synchronous))); + + hipGraphExec_t graph_instance; + if(TestFixture::params::use_graphs) { - HIP_CHECK(rocprim::radix_sort_keys(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - size, - start_bit, - end_bit, - stream, - debug_synchronous)); + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); } - if (TestFixture::params::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - - std::vector keys_output(size); - HIP_CHECK(hipMemcpy(keys_output.data(), + auto keys_output = std::make_unique(size); + HIP_CHECK(hipMemcpy(keys_output.get(), d_keys_output, size * sizeof(key_type), hipMemcpyDeviceToHost)); @@ -218,19 +351,170 @@ inline void sort_keys() HIP_CHECK(hipFree(d_keys_output)); } - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); - - ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(keys_output, expected)); + } + + 
ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(keys_output.get(), + keys_output.get() + size, + expected.begin(), + expected.end())); } } - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } +} + +template +auto invoke_sort_pairs(void* d_temporary_storage, + size_t& temporary_storage_bytes, + Key* d_keys_input, + Key* d_keys_output, + Value* d_values_input, + Value* d_values_output, + size_t size, + unsigned int start_bit, + unsigned int end_bit, + hipStream_t stream, + bool debug_synchronous) + -> std::enable_if_t, hipError_t> +{ + return rocprim::radix_sort_pairs(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + size, + start_bit, + end_bit, + stream, + debug_synchronous); +} + +template +auto invoke_sort_pairs(void* d_temporary_storage, + size_t& temporary_storage_bytes, + Key* d_keys_input, + Key* d_keys_output, + Value* d_values_input, + Value* d_values_output, + size_t size, + unsigned int start_bit, + unsigned int end_bit, + hipStream_t stream, + bool debug_synchronous) + -> std::enable_if_t, hipError_t> +{ + return rocprim::radix_sort_pairs_desc(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + size, + start_bit, + end_bit, + stream, + debug_synchronous); +} + +template +auto invoke_sort_pairs(void* d_temporary_storage, + size_t& temporary_storage_bytes, + Key* d_keys_input, + Key* d_keys_output, + Value* d_values_input, + Value* d_values_output, + size_t size, + unsigned int start_bit, + unsigned int end_bit, + hipStream_t stream, + bool debug_synchronous) + -> std::enable_if_t, hipError_t> +{ + using decomposer_t = test_utils::custom_test_type_decomposer; + if(start_bit == 0 && end_bit == rocprim::detail::decomposer_max_bits::value) + { + return rocprim::radix_sort_pairs(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + 
d_keys_output, + d_values_input, + d_values_output, + size, + decomposer_t{}, + stream, + debug_synchronous); + } + else + { + return rocprim::radix_sort_pairs(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + size, + decomposer_t{}, + start_bit, + end_bit, + stream, + debug_synchronous); + } +} + +template +auto invoke_sort_pairs(void* d_temporary_storage, + size_t& temporary_storage_bytes, + Key* d_keys_input, + Key* d_keys_output, + Value* d_values_input, + Value* d_values_output, + size_t size, + unsigned int start_bit, + unsigned int end_bit, + hipStream_t stream, + bool debug_synchronous) + -> std::enable_if_t, hipError_t> +{ + using decomposer_t = test_utils::custom_test_type_decomposer; + if(start_bit == 0 && end_bit == rocprim::detail::decomposer_max_bits::value) + { + return rocprim::radix_sort_pairs_desc(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + size, + decomposer_t{}, + stream, + debug_synchronous); + } + else + { + return rocprim::radix_sort_pairs_desc(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + size, + decomposer_t{}, + start_bit, + end_bit, + stream, + debug_synchronous); + } } template -inline void sort_pairs() +void sort_pairs() { int device_id = test_common_utils::obtain_device_from_ctest(); SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); @@ -266,30 +550,18 @@ inline void sort_pairs() for(size_t size : sizes) { if(size > (1 << 17) && !check_large_sizes) + { break; + } SCOPED_TRACE(testing::Message() << "with size = " << size); + engine_type rng_engine(seed_value); in_place = !in_place; // Generate data - std::vector keys_input; - if(rocprim::is_floating_point::value) - { - keys_input = test_utils::get_random_data(size, - static_cast(-1000), - static_cast(+1000), - seed_value); - 
test_utils::add_special_values(keys_input, seed_value); - } - else - { - keys_input - = test_utils::get_random_data(size, - std::numeric_limits::min(), - std::numeric_limits::max(), - seed_value); - } + auto keys_input = std::make_unique(size); + generate_key_input(keys_input.get(), size, rng_engine); std::vector values_input(size); test_utils::iota(values_input.begin(), values_input.end(), 0); @@ -307,7 +579,7 @@ inline void sort_pairs() test_common_utils::hipMallocHelper(&d_keys_output, size * sizeof(key_type))); } HIP_CHECK(hipMemcpy(d_keys_input, - keys_input.data(), + keys_input.get(), size * sizeof(key_type), hipMemcpyHostToDevice)); @@ -361,66 +633,59 @@ inline void sort_pairs() hipGraph_t graph; hipGraphExec_t graph_instance; - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { graph = test_utils::createGraphHelper(stream); - + } + void* d_temporary_storage = nullptr; size_t temporary_storage_bytes; - HIP_CHECK(rocprim::radix_sort_pairs(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - d_values_input, - d_values_output, - size, - start_bit, - end_bit)); - - if (TestFixture::params::use_graphs) + HIP_CHECK((invoke_sort_pairs(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + size, + start_bit, + end_bit, + stream, + debug_synchronous))); + + if(TestFixture::params::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - + } + ASSERT_GT(temporary_storage_bytes, 0); HIP_CHECK( test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - if (TestFixture::params::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - - if(descending) + if(TestFixture::params::use_graphs) { - HIP_CHECK(rocprim::radix_sort_pairs_desc(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - d_values_input, - d_values_output, - size, - 
start_bit, - end_bit, - stream, - debug_synchronous)); + test_utils::resetGraphHelper(graph, graph_instance, stream); } - else + + HIP_CHECK((invoke_sort_pairs(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + size, + start_bit, + end_bit, + stream, + debug_synchronous))); + + if(TestFixture::params::use_graphs) { - HIP_CHECK(rocprim::radix_sort_pairs(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - d_values_input, - d_values_output, - size, - start_bit, - end_bit, - stream, - debug_synchronous)); + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); } - if (TestFixture::params::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - - std::vector keys_output(size); - HIP_CHECK(hipMemcpy(keys_output.data(), + auto keys_output = std::make_unique(size); + HIP_CHECK(hipMemcpy(keys_output.get(), d_keys_output, size * sizeof(key_type), hipMemcpyDeviceToHost)); @@ -440,20 +705,144 @@ inline void sort_pairs() HIP_CHECK(hipFree(d_values_output)); } - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } - ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(keys_output, keys_expected)); - ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(values_output, values_expected)); + ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(keys_output.get(), + keys_output.get() + size, + keys_expected.begin(), + keys_expected.end())); + ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(values_output.begin(), + values_output.end(), + values_expected.begin(), + values_expected.end())); } } - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } +} + +template +auto invoke_sort_keys(void* d_temporary_storage, + size_t& temporary_storage_bytes, + rocprim::double_buffer& d_keys, + 
size_t size, + unsigned int start_bit, + unsigned int end_bit, + hipStream_t stream, + bool debug_synchronous) + -> std::enable_if_t, hipError_t> +{ + return rocprim::radix_sort_keys(d_temporary_storage, + temporary_storage_bytes, + d_keys, + size, + start_bit, + end_bit, + stream, + debug_synchronous); +} + +template +auto invoke_sort_keys(void* d_temporary_storage, + size_t& temporary_storage_bytes, + rocprim::double_buffer& d_keys, + size_t size, + unsigned int start_bit, + unsigned int end_bit, + hipStream_t stream, + bool debug_synchronous) + -> std::enable_if_t, hipError_t> +{ + return rocprim::radix_sort_keys_desc(d_temporary_storage, + temporary_storage_bytes, + d_keys, + size, + start_bit, + end_bit, + stream, + debug_synchronous); +} + +template +auto invoke_sort_keys(void* d_temporary_storage, + size_t& temporary_storage_bytes, + rocprim::double_buffer& d_keys, + size_t size, + unsigned int start_bit, + unsigned int end_bit, + hipStream_t stream, + bool debug_synchronous) + -> std::enable_if_t, hipError_t> +{ + using decomposer_t = test_utils::custom_test_type_decomposer; + if(start_bit == 0 && end_bit == rocprim::detail::decomposer_max_bits::value) + { + return rocprim::radix_sort_keys(d_temporary_storage, + temporary_storage_bytes, + d_keys, + size, + decomposer_t{}, + stream, + debug_synchronous); + } + else + { + return rocprim::radix_sort_keys(d_temporary_storage, + temporary_storage_bytes, + d_keys, + size, + decomposer_t{}, + start_bit, + end_bit, + stream, + debug_synchronous); + } +} + +template +auto invoke_sort_keys(void* d_temporary_storage, + size_t& temporary_storage_bytes, + rocprim::double_buffer& d_keys, + size_t size, + unsigned int start_bit, + unsigned int end_bit, + hipStream_t stream, + bool debug_synchronous) + -> std::enable_if_t, hipError_t> +{ + using decomposer_t = test_utils::custom_test_type_decomposer; + if(start_bit == 0 && end_bit == rocprim::detail::decomposer_max_bits::value) + { + return 
rocprim::radix_sort_keys_desc(d_temporary_storage, + temporary_storage_bytes, + d_keys, + size, + decomposer_t{}, + stream, + debug_synchronous); + } + else + { + return rocprim::radix_sort_keys_desc(d_temporary_storage, + temporary_storage_bytes, + d_keys, + size, + decomposer_t{}, + start_bit, + end_bit, + stream, + debug_synchronous); + } } template -inline void sort_keys_double_buffer() +void sort_keys_double_buffer() { int device_id = test_common_utils::obtain_device_from_ctest(); SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); @@ -486,40 +875,29 @@ inline void sort_keys_double_buffer() for(size_t size : sizes) { if(size > (1 << 17) && !check_large_sizes) + { break; + } SCOPED_TRACE(testing::Message() << "with size = " << size); + engine_type rng_engine(seed_value); + // Generate data - std::vector keys_input; - if(rocprim::is_floating_point::value) - { - keys_input = test_utils::get_random_data(size, - static_cast(-1000), - static_cast(+1000), - seed_value); - test_utils::add_special_values(keys_input, seed_value); - } - else - { - keys_input - = test_utils::get_random_data(size, - std::numeric_limits::min(), - std::numeric_limits::max(), - seed_value); - } + auto keys_input = std::make_unique(size); + generate_key_input(keys_input.get(), size, rng_engine); key_type* d_keys_input; key_type* d_keys_output; HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size * sizeof(key_type))); HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_output, size * sizeof(key_type))); HIP_CHECK(hipMemcpy(d_keys_input, - keys_input.data(), + keys_input.get(), size * sizeof(key_type), hipMemcpyHostToDevice)); // Calculate expected results on host - std::vector expected(keys_input); + std::vector expected(keys_input.get(), keys_input.get() + size); std::stable_sort( expected.begin(), expected.end(), @@ -527,61 +905,49 @@ inline void sort_keys_double_buffer() rocprim::double_buffer d_keys(d_keys_input, d_keys_output); - hipGraph_t graph; - 
hipGraphExec_t graph_instance; - if (TestFixture::params::use_graphs) - graph = test_utils::createGraphHelper(stream); - size_t temporary_storage_bytes; - HIP_CHECK(rocprim::radix_sort_keys(nullptr, - temporary_storage_bytes, - d_keys, - size, - start_bit, - end_bit)); - - if (TestFixture::params::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - + HIP_CHECK( + (invoke_sort_keys(nullptr, + temporary_storage_bytes, + d_keys, + size, + start_bit, + end_bit, + stream, + debug_synchronous))); + ASSERT_GT(temporary_storage_bytes, 0); void* d_temporary_storage; HIP_CHECK( test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - if (TestFixture::params::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - - if(descending) + hipGraph_t graph; + if(TestFixture::params::use_graphs) { - HIP_CHECK(rocprim::radix_sort_keys_desc(d_temporary_storage, - temporary_storage_bytes, - d_keys, - size, - start_bit, - end_bit, - stream, - debug_synchronous)); + graph = test_utils::createGraphHelper(stream); } - else + + HIP_CHECK( + (invoke_sort_keys(d_temporary_storage, + temporary_storage_bytes, + d_keys, + size, + start_bit, + end_bit, + stream, + debug_synchronous))); + + hipGraphExec_t graph_instance; + if(TestFixture::params::use_graphs) { - HIP_CHECK(rocprim::radix_sort_keys(d_temporary_storage, - temporary_storage_bytes, - d_keys, - size, - start_bit, - end_bit, - stream, - debug_synchronous)); + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); } - if (TestFixture::params::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - HIP_CHECK(hipFree(d_temporary_storage)); - std::vector keys_output(size); - HIP_CHECK(hipMemcpy(keys_output.data(), + auto keys_output = std::make_unique(size); + HIP_CHECK(hipMemcpy(keys_output.get(), d_keys.current(), size * sizeof(key_type), hipMemcpyDeviceToHost)); @@ -589,19 +955,150 @@ 
inline void sort_keys_double_buffer() HIP_CHECK(hipFree(d_keys_input)); HIP_CHECK(hipFree(d_keys_output)); - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); - - ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(keys_output, expected)); + } + + ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(keys_output.get(), + keys_output.get() + size, + expected.begin(), + expected.end())); } } - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } +} + +template +auto invoke_sort_pairs(void* d_temporary_storage, + size_t& temporary_storage_bytes, + rocprim::double_buffer& d_keys, + rocprim::double_buffer& d_values, + size_t size, + unsigned int start_bit, + unsigned int end_bit, + hipStream_t stream, + bool debug_synchronous) + -> std::enable_if_t, hipError_t> +{ + return rocprim::radix_sort_pairs(d_temporary_storage, + temporary_storage_bytes, + d_keys, + d_values, + size, + start_bit, + end_bit, + stream, + debug_synchronous); +} + +template +auto invoke_sort_pairs(void* d_temporary_storage, + size_t& temporary_storage_bytes, + rocprim::double_buffer& d_keys, + rocprim::double_buffer& d_values, + size_t size, + unsigned int start_bit, + unsigned int end_bit, + hipStream_t stream, + bool debug_synchronous) + -> std::enable_if_t, hipError_t> +{ + return rocprim::radix_sort_pairs_desc(d_temporary_storage, + temporary_storage_bytes, + d_keys, + d_values, + size, + start_bit, + end_bit, + stream, + debug_synchronous); +} + +template +auto invoke_sort_pairs(void* d_temporary_storage, + size_t& temporary_storage_bytes, + rocprim::double_buffer& d_keys, + rocprim::double_buffer& d_values, + size_t size, + unsigned int start_bit, + unsigned int end_bit, + hipStream_t stream, + bool debug_synchronous) + -> std::enable_if_t, hipError_t> +{ + using decomposer_t = test_utils::custom_test_type_decomposer; + if(start_bit == 0 && end_bit == 
rocprim::detail::decomposer_max_bits::value) + { + return rocprim::radix_sort_pairs(d_temporary_storage, + temporary_storage_bytes, + d_keys, + d_values, + size, + decomposer_t{}, + stream, + debug_synchronous); + } + else + { + return rocprim::radix_sort_pairs(d_temporary_storage, + temporary_storage_bytes, + d_keys, + d_values, + size, + decomposer_t{}, + start_bit, + end_bit, + stream, + debug_synchronous); + } +} + +template +auto invoke_sort_pairs(void* d_temporary_storage, + size_t& temporary_storage_bytes, + rocprim::double_buffer& d_keys, + rocprim::double_buffer& d_values, + size_t size, + unsigned int start_bit, + unsigned int end_bit, + hipStream_t stream, + bool debug_synchronous) + -> std::enable_if_t, hipError_t> +{ + using decomposer_t = test_utils::custom_test_type_decomposer; + if(start_bit == 0 && end_bit == rocprim::detail::decomposer_max_bits::value) + { + return rocprim::radix_sort_pairs_desc(d_temporary_storage, + temporary_storage_bytes, + d_keys, + d_values, + size, + decomposer_t{}, + stream, + debug_synchronous); + } + else + { + return rocprim::radix_sort_pairs_desc(d_temporary_storage, + temporary_storage_bytes, + d_keys, + d_values, + size, + decomposer_t{}, + start_bit, + end_bit, + stream, + debug_synchronous); + } } template -inline void sort_pairs_double_buffer() +void sort_pairs_double_buffer() { int device_id = test_common_utils::obtain_device_from_ctest(); SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); @@ -635,28 +1132,17 @@ inline void sort_pairs_double_buffer() for(size_t size : sizes) { if(size > (1 << 17) && !check_large_sizes) + { break; + } SCOPED_TRACE(testing::Message() << "with size = " << size); + engine_type rng_engine(seed_value); + // Generate data - std::vector keys_input; - if(rocprim::is_floating_point::value) - { - keys_input = test_utils::get_random_data(size, - static_cast(-1000), - static_cast(+1000), - seed_value); - test_utils::add_special_values(keys_input, seed_value); - } - else - { 
- keys_input - = test_utils::get_random_data(size, - std::numeric_limits::min(), - std::numeric_limits::max(), - seed_value); - } + auto keys_input = std::make_unique(size); + generate_key_input(keys_input.get(), size, rng_engine); std::vector values_input(size); test_utils::iota(values_input.begin(), values_input.end(), 0); @@ -666,7 +1152,7 @@ inline void sort_pairs_double_buffer() HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size * sizeof(key_type))); HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_output, size * sizeof(key_type))); HIP_CHECK(hipMemcpy(d_keys_input, - keys_input.data(), + keys_input.get(), size * sizeof(key_type), hipMemcpyHostToDevice)); @@ -705,64 +1191,51 @@ inline void sort_pairs_double_buffer() rocprim::double_buffer d_keys(d_keys_input, d_keys_output); rocprim::double_buffer d_values(d_values_input, d_values_output); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::params::use_graphs) - graph = test_utils::createGraphHelper(stream); - void* d_temporary_storage = nullptr; size_t temporary_storage_bytes; - HIP_CHECK(rocprim::radix_sort_pairs(d_temporary_storage, - temporary_storage_bytes, - d_keys, - d_values, - size, - start_bit, - end_bit)); + HIP_CHECK( + (invoke_sort_pairs(d_temporary_storage, + temporary_storage_bytes, + d_keys, + d_values, + size, + start_bit, + end_bit, + stream, + debug_synchronous))); - if (TestFixture::params::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - ASSERT_GT(temporary_storage_bytes, 0); HIP_CHECK( test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - if (TestFixture::params::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - - if(descending) + hipGraph_t graph; + if(TestFixture::params::use_graphs) { - HIP_CHECK(rocprim::radix_sort_pairs_desc(d_temporary_storage, - temporary_storage_bytes, - d_keys, - d_values, - size, - start_bit, - end_bit, - stream, - 
debug_synchronous)); + graph = test_utils::createGraphHelper(stream); } - else + + HIP_CHECK( + (invoke_sort_pairs(d_temporary_storage, + temporary_storage_bytes, + d_keys, + d_values, + size, + start_bit, + end_bit, + stream, + debug_synchronous))); + + hipGraphExec_t graph_instance; + if(TestFixture::params::use_graphs) { - HIP_CHECK(rocprim::radix_sort_pairs(d_temporary_storage, - temporary_storage_bytes, - d_keys, - d_values, - size, - start_bit, - end_bit, - stream, - debug_synchronous)); + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); } - if (TestFixture::params::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - HIP_CHECK(hipFree(d_temporary_storage)); - std::vector keys_output(size); - HIP_CHECK(hipMemcpy(keys_output.data(), + auto keys_output = std::make_unique(size); + HIP_CHECK(hipMemcpy(keys_output.get(), d_keys.current(), size * sizeof(key_type), hipMemcpyDeviceToHost)); @@ -778,20 +1251,30 @@ inline void sort_pairs_double_buffer() HIP_CHECK(hipFree(d_values_input)); HIP_CHECK(hipFree(d_values_output)); - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); - - ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(keys_output, keys_expected)); - ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(values_output, values_expected)); + } + + ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(keys_output.get(), + keys_output.get() + size, + keys_expected.begin(), + keys_expected.end())); + ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(values_output.begin(), + values_output.end(), + values_expected.begin(), + values_expected.end())); } } - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } } template -inline void sort_keys_over_4g() +void sort_keys_over_4g() { using key_type = uint8_t; constexpr unsigned int start_bit = 0; @@ -832,11 
+1315,6 @@ inline void sort_keys_over_4g() key_type_storage_bytes, hipMemcpyHostToDevice)); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (UseGraphs) - graph = test_utils::createGraphHelper(stream); - size_t temporary_storage_bytes; HIP_CHECK(rocprim::radix_sort_keys(nullptr, temporary_storage_bytes, @@ -848,9 +1326,6 @@ inline void sort_keys_over_4g() stream, debug_synchronous)); - if (UseGraphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - ASSERT_GT(temporary_storage_bytes, 0); hipDeviceProp_t prop; @@ -866,8 +1341,11 @@ inline void sort_keys_over_4g() void* d_temporary_storage; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - if (UseGraphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); + hipGraph_t graph; + if(UseGraphs) + { + graph = test_utils::createGraphHelper(stream); + } HIP_CHECK(rocprim::radix_sort_keys(d_temporary_storage, temporary_storage_bytes, @@ -879,9 +1357,12 @@ inline void sort_keys_over_4g() stream, debug_synchronous)); - if (UseGraphs) + hipGraphExec_t graph_instance; + if(UseGraphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - + } + std::vector output(keys_input.size()); HIP_CHECK(hipMemcpy(output.data(), d_keys_input_output, diff --git a/test/rocprim/test_device_reduce.cpp b/test/rocprim/test_device_reduce.cpp index 3cb7458ea..6da75cd28 100644 --- a/test/rocprim/test_device_reduce.cpp +++ b/test/rocprim/test_device_reduce.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -98,25 +98,23 @@ typedef ::testing::Types< DeviceReduceParams, DeviceReduceParamsList(int, int, false, 512), DeviceReduceParamsList(float, float, false, 2048), + DeviceReduceParamsList(double, double, false, 2048), DeviceReduceParamsList(int, int, false, 4096), DeviceReduceParamsList(int, int, false, 2097152), DeviceReduceParamsList(int, int, false, 1073741824), DeviceReduceParams, DeviceReduceParams, - // #156 temporarily disable half test due to known issue with converting from double to half - // DeviceReduceParams, + DeviceReduceParams, DeviceReduceParams, DeviceReduceParams, test_utils::custom_test_type>, DeviceReduceParams, test_utils::custom_test_type>, DeviceReduceParams> RocprimDeviceReduceTestsParams; -typedef ::testing::Types< - DeviceReduceParams, - DeviceReduceParamsList(float, float, false, 2048), - // #156 temporarily disable half test due to known issue with converting from double to half - // DeviceReduceParams, - DeviceReduceParams> +typedef ::testing::Types, + DeviceReduceParamsList(float, float, false, 2048), + DeviceReduceParams, + DeviceReduceParams> RocprimDeviceReducePrecisionTestsParams; TYPED_TEST_SUITE(RocprimDeviceReduceTests, RocprimDeviceReduceTestsParams); @@ -145,11 +143,6 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceEmptyInput) const U initial_value = U(1234); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) - graph = test_utils::createGraphHelper(stream); - size_t temp_storage_size_bytes; // Get size of d_temp_storage HIP_CHECK( @@ -162,14 +155,14 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceEmptyInput) ) ); - if (TestFixture::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - void * d_temp_storage = nullptr; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, 
temp_storage_size_bytes)); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } // Run HIP_CHECK( @@ -182,9 +175,12 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceEmptyInput) ) ); - if (TestFixture::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - + } + HIP_CHECK(hipDeviceSynchronize()); U output; @@ -269,27 +265,20 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceSum) if(size == 0) expected = U(); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; // Get size of d_temp_storage - HIP_CHECK( - rocprim::reduce( - d_temp_storage, temp_storage_size_bytes, - d_input, - test_utils::wrap_in_identity_iterator(d_output), - input.size(), rocprim::plus(), stream, debug_synchronous - ) - ); + HIP_CHECK(rocprim::reduce( + d_temp_storage, + temp_storage_size_bytes, + d_input, + test_utils::wrap_in_identity_iterator(d_output), + input.size(), + rocprim::plus(), + stream, + debug_synchronous)); - if (TestFixture::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -297,9 +286,12 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceSum) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + // Run HIP_CHECK( rocprim::reduce( @@ -310,8 +302,11 @@ TYPED_TEST(RocprimDeviceReduceTests, 
ReduceSum) ) ); - if (TestFixture::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + } HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -423,11 +418,6 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceArgMinimum) expected = reduce_op(expected, input[i]); } - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; @@ -441,9 +431,6 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceArgMinimum) ) ); - if (TestFixture::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -451,9 +438,12 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceArgMinimum) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + // Run HIP_CHECK( rocprim::reduce( @@ -464,9 +454,12 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceArgMinimum) ) ); - if (TestFixture::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - + } + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -530,11 +523,6 @@ void testLargeIndices() T* d_output = nullptr; HIP_CHECK(test_common_utils::hipMallocHelper(&d_output, sizeof(T))); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (use_graphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes = 0; void* d_temp_storage = nullptr; @@ -544,20 
+532,20 @@ void testLargeIndices() input, d_output, size, - rocprim::plus {}, + rocprim::plus{}, stream, debug_synchronous)); - if (use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - // allocate temporary storage HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - if (use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + // Run HIP_CHECK(rocprim::reduce(d_temp_storage, temp_storage_size_bytes, @@ -568,9 +556,12 @@ void testLargeIndices() stream, debug_synchronous)); - if (use_graphs) + hipGraphExec_t graph_instance; + if(use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); - + } + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -588,13 +579,17 @@ void testLargeIndices() hipFree(d_temp_storage); hipFree(d_output); - if (use_graphs) + if(use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } } } - if (use_graphs) + if(use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } } TEST(RocprimDeviceReduceTests, LargeIndices) @@ -664,13 +659,8 @@ TYPED_TEST(RocprimDeviceReducePrecisionTests, ReduceSumInputEqualExponentFunctio HIP_CHECK(hipDeviceSynchronize()); // Calculate expected results on host mathematically (instead of using reduce on host) - U expected = static_cast(static_cast(size) * static_cast(lowest)); + U expected = static_cast(static_cast(size) * static_cast(lowest)); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes; void* d_temp_storage = nullptr; @@ -685,9 +675,6 @@ TYPED_TEST(RocprimDeviceReducePrecisionTests, ReduceSumInputEqualExponentFunctio stream, debug_synchronous)); - if (TestFixture::use_graphs) - 
graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -695,9 +682,12 @@ TYPED_TEST(RocprimDeviceReducePrecisionTests, ReduceSumInputEqualExponentFunctio HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + // Run HIP_CHECK(rocprim::reduce( d_temp_storage, @@ -709,9 +699,12 @@ TYPED_TEST(RocprimDeviceReducePrecisionTests, ReduceSumInputEqualExponentFunctio stream, debug_synchronous)); - if (TestFixture::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - + } + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -792,11 +785,6 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceMinimum) expected = min_op(expected, input[i]); } - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) - graph = test_utils::createGraphHelper(stream); - // Get size of d_temp_storage size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; @@ -809,9 +797,6 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceMinimum) ) ); - if (TestFixture::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -819,8 +804,11 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceMinimum) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = 
test_utils::createGraphHelper(stream); + } // Run HIP_CHECK( @@ -841,9 +829,12 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceMinimum) ) ); - if (TestFixture::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); - + } + HIP_CHECK(hipDeviceSynchronize()); // Check if output values are as expected diff --git a/test/rocprim/test_device_reduce_by_key.cpp b/test/rocprim/test_device_reduce_by_key.cpp index 9c0f4dd43..877d33472 100644 --- a/test/rocprim/test_device_reduce_by_key.cpp +++ b/test/rocprim/test_device_reduce_by_key.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -30,6 +30,7 @@ #include // required test headers +#include "rocprim/types.hpp" #include "test_utils_custom_test_types.hpp" #include "test_utils_types.hpp" @@ -95,7 +96,11 @@ typedef ::testing::Types< params, 1, 10>, params, 1, 30>, params, 15, 100>, + // half should be supported, but is missing some key operators. + // we should uncomment these, as soon as these are implemented and the tests compile and work as intended. 
+ //params, 15, 100>, params, 15, 100>, + params, 15, 100>, params, 20, 100>, params, 100, 400, long long, custom_key_compare_op1>, params, 200, 600>, @@ -134,12 +139,10 @@ TYPED_TEST(RocprimDeviceReduceByKey, ReduceByKey) typename std::conditional< test_utils::is_valid_for_int_distribution::value, std::uniform_int_distribution, - typename std::conditional::value, - std::uniform_int_distribution, - std::uniform_int_distribution - >::type - >::type - >::type; + typename std::conditional::value, + std::uniform_int_distribution, + std::uniform_int_distribution>::type>::type>:: + type; constexpr bool use_identity_iterator = TestFixture::params::use_identity_iterator; const bool debug_synchronous = false; @@ -183,8 +186,8 @@ TYPED_TEST(RocprimDeviceReduceByKey, ReduceByKey) std::vector values_input = test_utils::get_random_data(size, 0, 100, seed_value); size_t offset = 0; - key_type prev_key = key_distribution_type(0, 100)(gen); - key_type current_key = prev_key + key_delta_dis(gen); + key_type prev_key = static_cast(key_distribution_type(0, 100)(gen)); + key_type current_key = static_cast(prev_key + key_delta_dis(gen)); while(offset < size) { const size_t key_count = key_count_dis(gen); @@ -256,34 +259,31 @@ TYPED_TEST(RocprimDeviceReduceByKey, ReduceByKey) size_t temporary_storage_bytes; - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::params::use_graphs) - graph = test_utils::createGraphHelper(stream); - - HIP_CHECK( - rocprim::reduce_by_key( - nullptr, temporary_storage_bytes, - d_keys_input, d_values_input, size, - test_utils::wrap_in_identity_iterator(d_unique_output), - test_utils::wrap_in_identity_iterator(d_aggregates_output), - d_unique_count_output, - reduce_op, key_compare_op, - stream, debug_synchronous - ) - ); + HIP_CHECK(rocprim::reduce_by_key( + nullptr, + temporary_storage_bytes, + d_keys_input, + d_values_input, + size, + test_utils::wrap_in_identity_iterator(d_unique_output), + 
test_utils::wrap_in_identity_iterator(d_aggregates_output), + d_unique_count_output, + reduce_op, + key_compare_op, + stream, + debug_synchronous)); - if (TestFixture::params::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - ASSERT_GT(temporary_storage_bytes, 0); void * d_temporary_storage; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - if (TestFixture::params::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(TestFixture::params::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + HIP_CHECK( rocprim::reduce_by_key( d_temporary_storage, temporary_storage_bytes, @@ -295,9 +295,12 @@ TYPED_TEST(RocprimDeviceReduceByKey, ReduceByKey) ) ); - if (TestFixture::params::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::params::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - + } + HIP_CHECK(hipFree(d_temporary_storage)); std::vector unique_output(unique_count_expected); @@ -401,11 +404,6 @@ void large_indices_reduce_by_key() size_t temporary_storage_bytes; - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (use_graphs) - graph = test_utils::createGraphHelper(stream); - HIP_CHECK(rocprim::reduce_by_key(nullptr, temporary_storage_bytes, d_keys_input, @@ -419,17 +417,17 @@ void large_indices_reduce_by_key() stream, debug_synchronous)); - if (use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - ASSERT_GT(temporary_storage_bytes, 0); void* d_temporary_storage; HIP_CHECK( test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - if (use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); + hipGraph_t graph; + if(use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } HIP_CHECK(rocprim::reduce_by_key(d_temporary_storage, temporary_storage_bytes, @@ 
-444,8 +442,11 @@ void large_indices_reduce_by_key() stream, debug_synchronous)); - if (use_graphs) + hipGraphExec_t graph_instance; + if(use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + } HIP_CHECK(hipFree(d_temporary_storage)); @@ -469,9 +470,11 @@ void large_indices_reduce_by_key() HIP_CHECK(hipFree(d_aggregates_output)); HIP_CHECK(hipFree(d_unique_count_output)); - if (use_graphs) + if(use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); - + } + ASSERT_EQ(unique_count_output[0], unique_count_expected); size_t total_size = 0; @@ -488,8 +491,10 @@ void large_indices_reduce_by_key() ASSERT_EQ(value_type(size - total_size), aggregates_output[last_idx]); } - if (use_graphs) + if(use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } } TEST(RocprimDeviceReduceByKey, LargeIndicesReduceByKeySmallValueType) @@ -548,11 +553,6 @@ void large_segment_count_reduce_by_key() HIP_CHECK(test_common_utils::hipMallocHelper(&d_unique_count_output, sizeof(*d_unique_count_output))); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (use_graphs) - graph = test_utils::createGraphHelper(stream); - size_t temporary_storage_bytes; HIP_CHECK(rocprim::reduce_by_key(nullptr, temporary_storage_bytes, @@ -566,18 +566,19 @@ void large_segment_count_reduce_by_key() key_compare_op, stream, debug_synchronous)); - if (use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - + ASSERT_GT(temporary_storage_bytes, 0); void* d_temporary_storage; HIP_CHECK( test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - if (use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + HIP_CHECK(rocprim::reduce_by_key(d_temporary_storage, temporary_storage_bytes, d_keys_input, @@ -590,9 +591,13 @@ void large_segment_count_reduce_by_key() key_compare_op, stream, 
debug_synchronous)); - if (use_graphs) + + hipGraphExec_t graph_instance; + if(use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - + } + HIP_CHECK(hipFree(d_temporary_storage)); size_t unique_count_output; diff --git a/test/rocprim/test_device_run_length_encode.cpp b/test/rocprim/test_device_run_length_encode.cpp index 46e112b0d..f5eafd610 100644 --- a/test/rocprim/test_device_run_length_encode.cpp +++ b/test/rocprim/test_device_run_length_encode.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -26,6 +26,7 @@ #include // required test headers +#include "rocprim/types.hpp" #include "test_utils_types.hpp" template< @@ -68,15 +69,17 @@ typedef ::testing::Types< params, params, params, - params, + // half should be supported, but is missing some key operators. + // we should uncomment these, as soon as these are implemented and the tests compile and work as intended. 
+ //params, params, params, - params, + params, params, params, params, - params -> Params; + params> + Params; TYPED_TEST_SUITE(RocprimDeviceRunLengthEncode, Params); diff --git a/test/rocprim/test_device_scan.cpp b/test/rocprim/test_device_scan.cpp index 9f2bdac22..9870e2093 100644 --- a/test/rocprim/test_device_scan.cpp +++ b/test/rocprim/test_device_scan.cpp @@ -186,32 +186,28 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanEmptyInput) = rocprim::make_transform_iterator(rocprim::make_constant_iterator(T(345)), [](T in) { return static_cast(in); }); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; // Get size of d_temp_storage - HIP_CHECK( - rocprim::inclusive_scan( - d_temp_storage, temp_storage_size_bytes, - input_iterator, d_checking_output, - 0, scan_op, stream, debug_synchronous - ) - ); + HIP_CHECK(rocprim::inclusive_scan(d_temp_storage, + temp_storage_size_bytes, + input_iterator, + d_checking_output, + 0, + scan_op, + stream, + debug_synchronous)); - if (TestFixture::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - // allocate temporary storage HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + // Run HIP_CHECK( rocprim::inclusive_scan( @@ -221,9 +217,12 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanEmptyInput) ) ); - if (TestFixture::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); - + } + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -321,11 +320,6 @@ 
TYPED_TEST(RocprimDeviceScanTests, InclusiveScan) = rocprim::make_transform_iterator(d_input, [](T in) { return static_cast(in); }); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; @@ -340,9 +334,6 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScan) stream, TestFixture::debug_synchronous)); - if (TestFixture::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -350,9 +341,12 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScan) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + // Run HIP_CHECK(rocprim::inclusive_scan( d_temp_storage, @@ -364,9 +358,12 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScan) stream, TestFixture::debug_synchronous)); - if (TestFixture::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); - + } + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -483,11 +480,6 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScan) = rocprim::make_transform_iterator(d_input, [](T in) { return static_cast(in); }); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; @@ -503,9 +495,6 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScan) stream, debug_synchronous)); - if (TestFixture::use_graphs) - graph_instance = 
test_utils::endCaptureGraphHelper(graph, stream, true, true); - // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -513,9 +502,12 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScan) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + // Run HIP_CHECK(rocprim::exclusive_scan( d_temp_storage, @@ -528,8 +520,11 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScan) stream, debug_synchronous)); - if (TestFixture::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + } HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -662,11 +657,6 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanByKey) = rocprim::make_transform_iterator(d_input, [](T in) { return static_cast(in); }); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; @@ -682,9 +672,6 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanByKey) stream, debug_synchronous)); - if (TestFixture::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -692,9 +679,12 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanByKey) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = 
test_utils::createGraphHelper(stream); + } + // Run HIP_CHECK(rocprim::inclusive_scan_by_key(d_temp_storage, temp_storage_size_bytes, @@ -707,9 +697,12 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanByKey) stream, debug_synchronous)); - if (TestFixture::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); - + } + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -845,11 +838,6 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) = rocprim::make_transform_iterator(d_input, [](T in) { return static_cast(in); }); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; @@ -866,9 +854,6 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) stream, debug_synchronous)); - if (TestFixture::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -876,8 +861,11 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } // Run HIP_CHECK(rocprim::exclusive_scan_by_key(d_temp_storage, @@ -892,8 +880,11 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) stream, debug_synchronous)); - if (TestFixture::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + } HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -1036,13 +1027,8 @@ 
void testLargeIndicesInclusiveScan() // temp storage size_t temp_storage_size_bytes; - void * d_temp_storage = nullptr; + void* d_temp_storage = nullptr; - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (UseGraphs) - graph = test_utils::createGraphHelper(stream); - // Get temporary array size HIP_CHECK( rocprim::inclusive_scan( @@ -1053,9 +1039,6 @@ void testLargeIndicesInclusiveScan() ) ); - if (UseGraphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -1063,8 +1046,11 @@ void testLargeIndicesInclusiveScan() HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - if (UseGraphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); + hipGraph_t graph; + if(UseGraphs) + { + graph = test_utils::createGraphHelper(stream); + } // Run HIP_CHECK( @@ -1076,8 +1062,11 @@ void testLargeIndicesInclusiveScan() ) ); - if (UseGraphs) + hipGraphExec_t graph_instance; + if(UseGraphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + } HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -1098,13 +1087,17 @@ void testLargeIndicesInclusiveScan() hipFree(d_temp_storage); hipFree(d_output); - if (UseGraphs) + if(UseGraphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } } } - if (UseGraphs) + if(UseGraphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } } TEST(RocprimDeviceScanTests, LargeIndicesInclusiveScan) @@ -1164,11 +1157,6 @@ void testLargeIndicesExclusiveScan() size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (UseGraphs) - graph = test_utils::createGraphHelper(stream); - // Get temporary array size HIP_CHECK( rocprim::exclusive_scan( @@ -1180,9 +1168,6 @@ void testLargeIndicesExclusiveScan() ) ); - if (UseGraphs) - graph_instance = 
test_utils::endCaptureGraphHelper(graph, stream, true, true); - // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -1190,8 +1175,11 @@ void testLargeIndicesExclusiveScan() HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - if (UseGraphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); + hipGraph_t graph; + if(UseGraphs) + { + graph = test_utils::createGraphHelper(stream); + } // Run HIP_CHECK( @@ -1204,8 +1192,11 @@ void testLargeIndicesExclusiveScan() ) ); - if (UseGraphs) + hipGraphExec_t graph_instance; + if(UseGraphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + } HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -1228,14 +1219,18 @@ void testLargeIndicesExclusiveScan() hipFree(d_temp_storage); hipFree(d_output); - - if (UseGraphs) + + if(UseGraphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } } } - if (UseGraphs) + if(UseGraphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } } TEST(RocprimDeviceScanTests, LargeIndicesExclusiveScan) @@ -1415,11 +1410,6 @@ void large_indices_scan_by_key_test(ScanByKeyFun scan_by_key_fun) { return value / run_length; }); const auto values_input = rocprim::counting_iterator(0); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (UseGraphs) - graph = test_utils::createGraphHelper(stream); - size_t temp_storage_size_bytes; void* d_temp_storage = nullptr; HIP_CHECK(scan_by_key_fun(d_temp_storage, @@ -1432,14 +1422,15 @@ void large_indices_scan_by_key_test(ScanByKeyFun scan_by_key_fun) stream, debug_synchronous, seed_value)); - if (UseGraphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); ASSERT_GT(temp_storage_size_bytes, 0); HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); - - if (UseGraphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); + + 
hipGraph_t graph; + if(UseGraphs) + { + graph = test_utils::createGraphHelper(stream); + } HIP_CHECK(scan_by_key_fun(d_temp_storage, temp_storage_size_bytes, @@ -1452,8 +1443,11 @@ void large_indices_scan_by_key_test(ScanByKeyFun scan_by_key_fun) debug_synchronous, seed_value)); - if (UseGraphs) + hipGraphExec_t graph_instance; + if(UseGraphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + } HIP_CHECK(hipGetLastError()); @@ -1666,11 +1660,6 @@ TYPED_TEST(RocprimDeviceScanFutureTests, ExclusiveScan) = rocprim::make_transform_iterator(d_input, [](T in) { return static_cast(in); }); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes; char* d_temp_storage = nullptr; @@ -1686,29 +1675,23 @@ TYPED_TEST(RocprimDeviceScanFutureTests, ExclusiveScan) stream, debug_synchronous)); - if (TestFixture::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - size_t temp_storage_reduce = 0; HIP_CHECK(rocprim::reduce( nullptr, temp_storage_reduce, d_future_input, d_initial_value, 2048, rocprim::plus(), stream)); - if (TestFixture::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - // allocate temporary storage HIP_CHECK(test_common_utils::hipMallocHelper( &d_temp_storage, temp_storage_size_bytes + temp_storage_reduce)); HIP_CHECK(hipDeviceSynchronize()); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } // Fill initial value on the device HIP_CHECK(rocprim::reduce(d_temp_storage + temp_storage_size_bytes, @@ -1732,8 
+1715,11 @@ TYPED_TEST(RocprimDeviceScanFutureTests, ExclusiveScan) debug_synchronous)); HIP_CHECK(hipGetLastError()); - if (TestFixture::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + } // Copy output to host HIP_CHECK(hipMemcpy( diff --git a/test/rocprim/test_device_segmented_reduce.cpp b/test/rocprim/test_device_segmented_reduce.cpp index 2f2404435..065f2b1d6 100644 --- a/test/rocprim/test_device_segmented_reduce.cpp +++ b/test/rocprim/test_device_segmented_reduce.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -241,11 +241,6 @@ TYPED_TEST(RocprimDeviceSegmentedReduce, Reduce) HIP_CHECK(test_common_utils::hipMallocHelper(&d_aggregates_output, segments_count * sizeof(output_type))); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::params::use_graphs) - graph = test_utils::createGraphHelper(stream); - size_t temporary_storage_bytes; HIP_CHECK(rocprim::segmented_reduce(nullptr, @@ -260,18 +255,18 @@ TYPED_TEST(RocprimDeviceSegmentedReduce, Reduce) stream, debug_synchronous)); - if (TestFixture::params::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - ASSERT_GT(temporary_storage_bytes, 0); void* d_temporary_storage; HIP_CHECK( test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - if (TestFixture::params::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(TestFixture::params::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + HIP_CHECK(rocprim::segmented_reduce( 
d_temporary_storage, temporary_storage_bytes, @@ -285,8 +280,11 @@ TYPED_TEST(RocprimDeviceSegmentedReduce, Reduce) stream, debug_synchronous)); - if (TestFixture::params::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::params::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + } HIP_CHECK(hipFree(d_temporary_storage)); diff --git a/test/rocprim/test_device_segmented_scan.cpp b/test/rocprim/test_device_segmented_scan.cpp index f8a06c247..5f79a7287 100644 --- a/test/rocprim/test_device_segmented_scan.cpp +++ b/test/rocprim/test_device_segmented_scan.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -70,9 +70,10 @@ using bfloat16 = rocprim::bfloat16; typedef ::testing::Types< params>, params, -100, 0, 10000>, + params, -100, 0, 10000>, params, 1000, 0, 10000>, params, 10, 1000, 10000>, - params, 50, 2, 10>, + params, 50, 2, 10>, params, 123, 100, 200, true>, params, 0, 3, 50, true>, params, 0, 1000, 30000>, @@ -193,34 +194,29 @@ TYPED_TEST(RocprimDeviceSegmentedScan, InclusiveScan) ); HIP_CHECK(hipDeviceSynchronize()); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::params::use_graphs) - graph = test_utils::createGraphHelper(stream); - size_t temporary_storage_bytes; - HIP_CHECK( - rocprim::segmented_inclusive_scan( - nullptr, temporary_storage_bytes, - d_values_input, - test_utils::wrap_in_identity_iterator(d_values_output), - segments_count, - d_offsets, d_offsets + 1, - scan_op, - stream, debug_synchronous - ) - ); + HIP_CHECK(rocprim::segmented_inclusive_scan( + nullptr, + temporary_storage_bytes, + d_values_input, + test_utils::wrap_in_identity_iterator(d_values_output), + 
segments_count, + d_offsets, + d_offsets + 1, + scan_op, + stream, + debug_synchronous)); - if (TestFixture::params::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - ASSERT_GT(temporary_storage_bytes, 0); void * d_temporary_storage; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - if (TestFixture::params::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(TestFixture::params::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + HIP_CHECK( rocprim::segmented_inclusive_scan( d_temporary_storage, temporary_storage_bytes, @@ -233,9 +229,12 @@ TYPED_TEST(RocprimDeviceSegmentedScan, InclusiveScan) ) ); - if (TestFixture::params::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::params::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); - + } + HIP_CHECK(hipDeviceSynchronize()); std::vector values_output(size); @@ -256,13 +255,17 @@ TYPED_TEST(RocprimDeviceSegmentedScan, InclusiveScan) HIP_CHECK(hipFree(d_offsets)); HIP_CHECK(hipFree(d_values_output)); - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } } } - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } } TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScan) @@ -376,11 +379,6 @@ TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScan) ); HIP_CHECK(hipDeviceSynchronize()); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::params::use_graphs) - graph = test_utils::createGraphHelper(stream); - size_t temporary_storage_bytes; HIP_CHECK( rocprim::segmented_exclusive_scan( @@ -394,17 +392,17 @@ TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScan) ) ); - if (TestFixture::params::use_graphs) - graph_instance = 
test_utils::endCaptureGraphHelper(graph, stream, true, true); - HIP_CHECK(hipDeviceSynchronize()); ASSERT_GT(temporary_storage_bytes, 0); void * d_temporary_storage; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); - if (TestFixture::params::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); + hipGraph_t graph; + if(TestFixture::params::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } HIP_CHECK( rocprim::segmented_exclusive_scan( @@ -418,8 +416,11 @@ TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScan) ) ); - if (TestFixture::params::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::params::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + } HIP_CHECK(hipDeviceSynchronize()); @@ -441,13 +442,17 @@ TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScan) HIP_CHECK(hipFree(d_offsets)); HIP_CHECK(hipFree(d_values_output)); - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } } } - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } } TYPED_TEST(RocprimDeviceSegmentedScan, InclusiveScanUsingHeadFlags) @@ -560,11 +565,6 @@ TYPED_TEST(RocprimDeviceSegmentedScan, InclusiveScanUsingHeadFlags) expected.begin(), scan_op); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::params::use_graphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes; // Get size of d_temp_storage @@ -577,9 +577,6 @@ TYPED_TEST(RocprimDeviceSegmentedScan, InclusiveScanUsingHeadFlags) ) ); - if (TestFixture::params::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); - HIP_CHECK(hipDeviceSynchronize()); // temp_storage_size_bytes must be >0 @@ -590,8 +587,11 @@ 
TYPED_TEST(RocprimDeviceSegmentedScan, InclusiveScanUsingHeadFlags) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - if (TestFixture::params::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); + hipGraph_t graph; + if(TestFixture::params::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } // Run HIP_CHECK( @@ -603,8 +603,11 @@ TYPED_TEST(RocprimDeviceSegmentedScan, InclusiveScanUsingHeadFlags) ) ); - if (TestFixture::params::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::params::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + } HIP_CHECK(hipDeviceSynchronize()); @@ -626,13 +629,17 @@ TYPED_TEST(RocprimDeviceSegmentedScan, InclusiveScanUsingHeadFlags) HIP_CHECK(hipFree(d_flags)); HIP_CHECK(hipFree(d_output)); - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } } } - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } } TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScanUsingHeadFlags) @@ -748,11 +755,6 @@ TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScanUsingHeadFlags) scan_op, init); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::params::use_graphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes; // Get size of d_temp_storage @@ -764,9 +766,6 @@ TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScanUsingHeadFlags) ) ); - if (TestFixture::params::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - HIP_CHECK(hipDeviceSynchronize()); // temp_storage_size_bytes must be >0 @@ -777,8 +776,11 @@ TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScanUsingHeadFlags) 
HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - if (TestFixture::params::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); + hipGraph_t graph; + if(TestFixture::params::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } // Run HIP_CHECK( @@ -789,8 +791,11 @@ TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScanUsingHeadFlags) ) ); - if (TestFixture::params::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::params::use_graphs) + { graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + } HIP_CHECK(hipDeviceSynchronize()); @@ -804,9 +809,11 @@ TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScanUsingHeadFlags) ) ); - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); - + } + HIP_CHECK(hipDeviceSynchronize()); ASSERT_NO_FATAL_FAILURE(test_utils::assert_near(output, expected, precision)); @@ -818,6 +825,8 @@ TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScanUsingHeadFlags) } } - if (TestFixture::params::use_graphs) + if(TestFixture::params::use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } } diff --git a/test/rocprim/test_device_select.cpp b/test/rocprim/test_device_select.cpp index 88d5562b4..ffda63d63 100644 --- a/test/rocprim/test_device_select.cpp +++ b/test/rocprim/test_device_select.cpp @@ -61,17 +61,20 @@ class RocprimDeviceSelectTests : public ::testing::Test static constexpr bool use_graphs = Params::use_graphs; }; -typedef ::testing::Types< - DeviceSelectParams, - DeviceSelectParams, - DeviceSelectParams, - DeviceSelectParams, - DeviceSelectParams, - DeviceSelectParams, - DeviceSelectParams, - DeviceSelectParams, test_utils::custom_test_type, int, true>, - DeviceSelectParams -> RocprimDeviceSelectTestsParams; +typedef ::testing::Types, + DeviceSelectParams, + DeviceSelectParams, + DeviceSelectParams, + 
DeviceSelectParams, + DeviceSelectParams, + DeviceSelectParams, + DeviceSelectParams, + DeviceSelectParams, + test_utils::custom_test_type, + int, + true>, + DeviceSelectParams> + RocprimDeviceSelectTestsParams; TYPED_TEST_SUITE(RocprimDeviceSelectTests, RocprimDeviceSelectTestsParams); @@ -141,31 +144,20 @@ TYPED_TEST(RocprimDeviceSelectTests, Flagged) } } - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes; // Get size of d_temp_storage - HIP_CHECK( - rocprim::select( - nullptr, - temp_storage_size_bytes, - d_input, - d_flags, - test_utils::wrap_in_identity_iterator(d_output), - d_selected_count_output, - input.size(), - stream, - TestFixture::debug_synchronous - ) - ); + HIP_CHECK(rocprim::select( + nullptr, + temp_storage_size_bytes, + d_input, + d_flags, + test_utils::wrap_in_identity_iterator(d_output), + d_selected_count_output, + input.size(), + stream, + TestFixture::debug_synchronous)); - if (TestFixture::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); - HIP_CHECK(hipDeviceSynchronize()); // temp_storage_size_bytes must be >0 @@ -176,9 +168,12 @@ TYPED_TEST(RocprimDeviceSelectTests, Flagged) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } + // Run HIP_CHECK( rocprim::select( @@ -194,9 +189,12 @@ TYPED_TEST(RocprimDeviceSelectTests, Flagged) ) ); - if (TestFixture::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); - + } + HIP_CHECK(hipDeviceSynchronize()); // Check if number of selected 
value is as expected @@ -229,13 +227,17 @@ TYPED_TEST(RocprimDeviceSelectTests, Flagged) hipFree(d_selected_count_output); hipFree(d_temp_storage); - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } } } - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } } @@ -305,31 +307,20 @@ TYPED_TEST(RocprimDeviceSelectTests, SelectOp) } } - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes; // Get size of d_temp_storage - HIP_CHECK( - rocprim::select( - nullptr, - temp_storage_size_bytes, - d_input, - test_utils::wrap_in_identity_iterator(d_output), - d_selected_count_output, - input.size(), - select_op(), - stream, - debug_synchronous - ) - ); + HIP_CHECK(rocprim::select( + nullptr, + temp_storage_size_bytes, + d_input, + test_utils::wrap_in_identity_iterator(d_output), + d_selected_count_output, + input.size(), + select_op(), + stream, + debug_synchronous)); - if (TestFixture::use_graphs) - graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); - HIP_CHECK(hipDeviceSynchronize()); // temp_storage_size_bytes must be >0 @@ -340,8 +331,11 @@ TYPED_TEST(RocprimDeviceSelectTests, SelectOp) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } // Run HIP_CHECK( @@ -358,8 +352,11 @@ TYPED_TEST(RocprimDeviceSelectTests, SelectOp) ) ); - if (TestFixture::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, 
false); + } HIP_CHECK(hipDeviceSynchronize()); @@ -392,13 +389,17 @@ TYPED_TEST(RocprimDeviceSelectTests, SelectOp) hipFree(d_selected_count_output); hipFree(d_temp_storage); - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } } } - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } } std::vector get_discontinuity_probabilities() @@ -484,31 +485,20 @@ TYPED_TEST(RocprimDeviceSelectTests, Unique) } } - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes; // Get size of d_temp_storage - HIP_CHECK( - rocprim::unique( - nullptr, - temp_storage_size_bytes, - d_input, - test_utils::wrap_in_identity_iterator(d_output), - d_selected_count_output, - input.size(), - op_type(), - stream, - debug_synchronous - ) - ); + HIP_CHECK(rocprim::unique( + nullptr, + temp_storage_size_bytes, + d_input, + test_utils::wrap_in_identity_iterator(d_output), + d_selected_count_output, + input.size(), + op_type(), + stream, + debug_synchronous)); - if (TestFixture::use_graphs) - graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); - HIP_CHECK(hipDeviceSynchronize()); // temp_storage_size_bytes must be >0 @@ -519,8 +509,11 @@ TYPED_TEST(RocprimDeviceSelectTests, Unique) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } // Run HIP_CHECK( @@ -537,9 +530,12 @@ TYPED_TEST(RocprimDeviceSelectTests, Unique) ) ); - if (TestFixture::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { graph_instance = 
graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); - + } + HIP_CHECK(hipDeviceSynchronize()); // Check if number of selected value is as expected @@ -571,14 +567,18 @@ TYPED_TEST(RocprimDeviceSelectTests, Unique) hipFree(d_selected_count_output); hipFree(d_temp_storage); - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } } } } - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } } // The operator must be only called, when we have valid element in a block @@ -690,11 +690,6 @@ void testUniqueGuardedOperator() } } - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (UseGraphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes; // Get size of d_temp_storage @@ -712,9 +707,6 @@ void testUniqueGuardedOperator() ) ); - if (UseGraphs) - graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); - HIP_CHECK(hipDeviceSynchronize()); // temp_storage_size_bytes must be >0 @@ -725,8 +717,11 @@ void testUniqueGuardedOperator() HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - if (UseGraphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); + hipGraph_t graph; + if(UseGraphs) + { + graph = test_utils::createGraphHelper(stream); + } // Run HIP_CHECK( @@ -743,8 +738,11 @@ void testUniqueGuardedOperator() ) ); - if (UseGraphs) + hipGraphExec_t graph_instance; + if(UseGraphs) + { graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + } HIP_CHECK(hipDeviceSynchronize()); @@ -778,14 +776,18 @@ void testUniqueGuardedOperator() hipFree(d_selected_count_output); hipFree(d_temp_storage); - if (UseGraphs) + if(UseGraphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } } } } - if (UseGraphs) + if(UseGraphs) + { 
HIP_CHECK(hipStreamDestroy(stream)); + } } TEST(RocprimDeviceSelectTests, UniqueGuardedOperator) @@ -938,11 +940,6 @@ TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKey) } } - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes; // Get size of d_temp_storage @@ -962,9 +959,6 @@ TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKey) ) ); - if (TestFixture::use_graphs) - graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); - HIP_CHECK(hipDeviceSynchronize()); // temp_storage_size_bytes must be >0 @@ -975,8 +969,11 @@ TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKey) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); + hipGraph_t graph; + if(TestFixture::use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } // Run HIP_CHECK( @@ -995,8 +992,11 @@ TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKey) ) ); - if (TestFixture::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + } HIP_CHECK(hipDeviceSynchronize()); @@ -1040,14 +1040,18 @@ TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKey) hipFree(d_selected_count_output); hipFree(d_temp_storage); - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } } } } - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } } TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKeyAlias) @@ -1144,11 +1148,6 @@ TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKeyAlias) } } - hipGraph_t graph; - hipGraphExec_t graph_instance; - 
if(TestFixture::use_graphs) - graph = test_utils::createGraphHelper(stream); - // temp storage size_t temp_storage_size_bytes; // Get size of d_temp_storage @@ -1165,10 +1164,6 @@ TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKeyAlias) stream, debug_synchronous)); - if(TestFixture::use_graphs) - graph_instance = graph_instance - = test_utils::endCaptureGraphHelper(graph, stream, true, false); - HIP_CHECK(hipDeviceSynchronize()); // temp_storage_size_bytes must be >0 @@ -1180,8 +1175,11 @@ TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKeyAlias) test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); + hipGraph_t graph; if(TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); + { + graph = test_utils::createGraphHelper(stream); + } // Run HIP_CHECK(rocprim::unique_by_key( @@ -1197,9 +1195,12 @@ TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKeyAlias) stream, debug_synchronous)); + hipGraphExec_t graph_instance; if(TestFixture::use_graphs) + { graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + } HIP_CHECK(hipDeviceSynchronize()); @@ -1235,13 +1236,17 @@ TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKeyAlias) hipFree(d_temp_storage); if(TestFixture::use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } } } } if(TestFixture::use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } } class RocprimDeviceSelectLargeInputTests : public ::testing::TestWithParam> { @@ -1317,11 +1322,6 @@ TEST_P(RocprimDeviceSelectLargeInputTests, LargeInputFlagged) size_t temp_storage_size_bytes; void *d_temp_storage = nullptr; - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (use_graphs) - graph = test_utils::createGraphHelper(stream); - // Get size of d_temp_storage HIP_CHECK( rocprim::select( @@ -1337,9 +1337,6 @@ TEST_P(RocprimDeviceSelectLargeInputTests, LargeInputFlagged) ) ); - if (use_graphs) - 
graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); - HIP_CHECK(hipDeviceSynchronize()); // temp_storage_size_bytes must be >0 @@ -1349,8 +1346,11 @@ TEST_P(RocprimDeviceSelectLargeInputTests, LargeInputFlagged) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); - if (use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); + hipGraph_t graph; + if(use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } // Run HIP_CHECK( @@ -1367,8 +1367,11 @@ TEST_P(RocprimDeviceSelectLargeInputTests, LargeInputFlagged) ) ); - if (use_graphs) + hipGraphExec_t graph_instance; + if(use_graphs) + { graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + } HIP_CHECK(hipDeviceSynchronize()); @@ -1399,8 +1402,10 @@ TEST_P(RocprimDeviceSelectLargeInputTests, LargeInputFlagged) hipFree(d_selected_count_output); hipFree(d_temp_storage); - if (use_graphs) + if(use_graphs) + { test_utils::cleanupGraphHelper(graph, graph_instance); + } } if (use_graphs) @@ -1448,11 +1453,6 @@ TEST_P(RocprimDeviceSelectLargeInputTests, LargeInputUnique) HIP_CHECK(test_common_utils::hipMallocHelper(&d_unique_count_output, sizeof(*d_unique_count_output))); - hipGraph_t graph; - hipGraphExec_t graph_instance; - if (use_graphs) - graph = test_utils::createGraphHelper(stream); - size_t temp_storage_size_bytes{}; void* d_temp_storage{}; HIP_CHECK(rocprim::unique(d_temp_storage, @@ -1465,14 +1465,14 @@ TEST_P(RocprimDeviceSelectLargeInputTests, LargeInputUnique) stream, debug_synchronous)); - if (use_graphs) - graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - ASSERT_GT(temp_storage_size_bytes, 0); HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); - if (use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); + hipGraph_t 
graph; + if(use_graphs) + { + graph = test_utils::createGraphHelper(stream); + } HIP_CHECK(rocprim::unique(d_temp_storage, temp_storage_size_bytes, @@ -1484,8 +1484,11 @@ TEST_P(RocprimDeviceSelectLargeInputTests, LargeInputUnique) stream, debug_synchronous)); - if (use_graphs) + hipGraphExec_t graph_instance; + if(use_graphs) + { graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + } size_t unique_count_output{}; HIP_CHECK(hipMemcpyWithStream(&unique_count_output, diff --git a/test/rocprim/test_device_transform.cpp b/test/rocprim/test_device_transform.cpp index 7fee5604e..5d400927d 100644 --- a/test/rocprim/test_device_transform.cpp +++ b/test/rocprim/test_device_transform.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -67,24 +67,23 @@ using custom_short2 = test_utils::custom_test_type; using custom_int2 = test_utils::custom_test_type; using custom_double2 = test_utils::custom_test_type; -typedef ::testing::Types< - DeviceTransformParams, - DeviceTransformParams, - DeviceTransformParams, - DeviceTransformParams, - DeviceTransformParams, - DeviceTransformParams, - DeviceTransformParams, - DeviceTransformParams, - DeviceTransformParams, - DeviceTransformParams, - DeviceTransformParams, - DeviceTransformParams, - DeviceTransformParams, - DeviceTransformParams, - DeviceTransformParams, - DeviceTransformParams -> RocprimDeviceTransformTestsParams; +typedef ::testing::Types, + DeviceTransformParams, + DeviceTransformParams, + DeviceTransformParams, + DeviceTransformParams, + DeviceTransformParams, + DeviceTransformParams, + DeviceTransformParams, + DeviceTransformParams, + DeviceTransformParams, + 
DeviceTransformParams, + DeviceTransformParams, + DeviceTransformParams, + DeviceTransformParams, + DeviceTransformParams, + DeviceTransformParams> + RocprimDeviceTransformTestsParams; template struct size_limit_config { @@ -160,10 +159,11 @@ TYPED_TEST(RocprimDeviceTransformTests, Transform) std::transform(input.begin(), input.end(), expected.begin(), transform()); hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) + { graph = test_utils::createGraphHelper(stream); - + } + // Run HIP_CHECK( rocprim::transform( @@ -173,9 +173,12 @@ TYPED_TEST(RocprimDeviceTransformTests, Transform) ) ); - if (TestFixture::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); - + } + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -279,9 +282,10 @@ TYPED_TEST(RocprimDeviceTransformTests, BinaryTransform) ); hipGraph_t graph; - hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) + { graph = test_utils::createGraphHelper(stream); + } // Run HIP_CHECK( @@ -292,9 +296,12 @@ TYPED_TEST(RocprimDeviceTransformTests, BinaryTransform) ) ); - if (TestFixture::use_graphs) + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + { graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); - + } + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -384,16 +391,20 @@ void testLargeIndices() }; hipGraph_t graph; - hipGraphExec_t graph_instance; - if (UseGraphs) + if(UseGraphs) + { graph = test_utils::createGraphHelper(stream); + } // Run HIP_CHECK( rocprim::transform(input, output, size, flag_expected, stream, debug_synchronous)); - if (UseGraphs) + hipGraphExec_t graph_instance; + if(UseGraphs) + { graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + 
} HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -412,8 +423,10 @@ void testLargeIndices() } } - if (UseGraphs) - HIP_CHECK(hipStreamDestroy(stream)); + if(UseGraphs) + { + HIP_CHECK(hipStreamDestroy(stream)); + } } TEST(RocprimDeviceTransformTests, LargeIndices) diff --git a/test/rocprim/test_predicate_iterator.cpp b/test/rocprim/test_predicate_iterator.cpp new file mode 100644 index 000000000..ff0b70128 --- /dev/null +++ b/test/rocprim/test_predicate_iterator.cpp @@ -0,0 +1,288 @@ +// MIT License +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +#include "rocprim/iterator/transform_iterator.hpp" +#include "test_utils_data_generation.hpp" + +#include + +#include +#include +#include +#include + +#include + +#include + +#include +#include +#include + +struct is_odd +{ + // While this can be "constexpr T(const T&) const", we want to verify that + // it compiles without the constness. + template + __device__ __host__ bool operator()(T& a) + { + return a % 2; + } +}; + +template +struct set_to +{ + template + __device__ __host__ constexpr T operator()(const T&) const + { + return V; + } +}; + +template +struct increment_by +{ + template + __device__ __host__ T constexpr operator()(const T& a) const + { + return a + V; + } +}; + +struct identity +{ + template + __device__ __host__ constexpr T operator()(const T& a) const + { + return a; + } +}; + +TEST(RocprimPredicateIteratorTests, TypeTraits) +{ + using value_type = int; + + value_type* data{}; + bool* mask{}; + + auto m_it = rocprim::make_mask_iterator(data, mask); + + using m_it_t = decltype(m_it); + using proxy_t = m_it_t::proxy; + + static_assert(std::is_assignable::value, + "discard type is not assignable with underlying type, even though it should be!"); + static_assert(std::is_assignable::value, + "iterator is not assignable with underlying type via dereference, even though it " + "should be!"); + static_assert(std::is_assignable::value, + "iterator is not assignablle with underlying type via array index, even though " + "is should be!"); + + // Check if we can apply predicate iterator on a constant iterator + auto c_it = rocprim::make_constant_iterator(0); + auto p_it = rocprim::make_predicate_iterator(c_it, is_odd{}); + + static_assert( + std::is_convertible::value, + "predicate iterator is not convertible to underlying type, even though it should be!"); +} + +// Test that we are only writing if predicate holds +TEST(RocprimPredicateIteratorTests, HostWrite) +{ + using T = int; + static constexpr size_t size = 100; + + std::vector data(size); + 
std::iota(data.begin(), data.end(), 0); + + // Make iterator that only writes to odd values + auto odd_it = rocprim::make_predicate_iterator(data.begin(), is_odd{}); + + // Increment all values in that iterator + std::transform(data.begin(), data.end(), odd_it, [](auto v) { return v + 1; }); + + // Such that none of data is odd + ASSERT_TRUE(std::none_of(data.begin(), data.end(), is_odd{})); +} + +// Test that we are only reading if predicate holds, excluding the required read for the predicate +TEST(RocprimPredicateIteratorTests, HostRead) +{ + using T = int; + static constexpr size_t size = 100; + + auto is_odd_or_default = [](T v) { return v % 2 || v == T{}; }; + + std::vector data(size); + std::iota(data.begin(), data.end(), 0); + + // Make iterator that only reads odd values + auto odd_it = rocprim::make_predicate_iterator(data.begin(), is_odd{}); + + // Read all values from that iterator + for(size_t i = 0; i < size; ++i) + { + data[i] = odd_it[i]; + } + + // Such that all of data is odd or default + ASSERT_TRUE(std::all_of(data.begin(), data.end(), is_odd_or_default)); +} + +// Test that we are only writing if predicate holds +TEST(RocprimPredicateIteratorTests, HostMaskWrite) +{ + using T = int; + static constexpr size_t size = 100; + + std::vector data(size); + std::vector mask = test_utils::get_random_data(size, false, true, 0); + std::iota(data.begin(), data.end(), 0); + test_utils::get_random_data(size, false, true, 0); + + auto masked_it = rocprim::make_predicate_iterator(data.begin(), mask.begin(), identity{}); + std::transform(data.begin(), data.end(), masked_it, set_to<-1>{}); + + for(size_t i = 0; i < size; ++i) + { + if(mask[i]) + { + ASSERT_EQ(data[i], -1); + } + else + { + ASSERT_EQ(data[i], i); + } + } +} + +// Test that we are only reading if predicate holds, excluding the required read for the predicate +TEST(RocprimPredicateIteratorTests, HostMaskRead) +{ + using T = int; + static constexpr size_t size = 100; + + std::vector data(size); + 
std::vector mask = test_utils::get_random_data(size, false, true, 0); + std::iota(data.begin(), data.end(), 0); + + auto masked_it = rocprim::make_mask_iterator(data.begin(), mask.begin()); + + for(size_t i = 0; i < size; ++i) + { + data[i] = masked_it[i]; + } + + for(size_t i = 0; i < size; ++i) + { + if(mask[i]) + { + ASSERT_EQ(data[i], i); + } + else + { + ASSERT_EQ(data[i], T{}); + } + } +} + +// Test if predicate iterator can be used on device +TEST(RocprimPredicateIteratorTests, DeviceInplace) +{ + using T = int; + using predicate = is_odd; + using transform = increment_by<5>; + + constexpr size_t size = 100; + constexpr size_t data_size = sizeof(T) * size; + + std::vector h_data(size); + std::iota(h_data.begin(), h_data.end(), 0); + + T* d_data; + HIP_CHECK(hipMalloc(&d_data, data_size)); + HIP_CHECK(hipMemcpy(d_data, h_data.data(), data_size, hipMemcpyHostToDevice)); + + auto w_it = rocprim::make_predicate_iterator(d_data, predicate{}); + + HIP_CHECK(rocprim::transform(d_data, w_it, size, transform{})); + + HIP_CHECK(hipMemcpy(h_data.data(), d_data, data_size, hipMemcpyDeviceToHost)); + HIP_CHECK(hipFree(d_data)); + + for(T i = 0; i < T{size}; ++i) + { + if(predicate{}(i)) + { + ASSERT_EQ(h_data[i], transform{}(i)); + } + else + { + ASSERT_EQ(h_data[i], i); + } + } +} + +// Test if predicate iterator can be used on device +TEST(RocprimPredicateIteratorTests, DeviceRead) +{ + using T = int; + using predicate = is_odd; + using transform = increment_by<5>; + + constexpr size_t size = 100; + constexpr size_t data_size = sizeof(T) * size; + + std::vector h_data(size); + std::iota(h_data.begin(), h_data.end(), 0); + + T* d_input; + T* d_output; + HIP_CHECK(hipMalloc(&d_input, data_size)); + HIP_CHECK(hipMalloc(&d_output, data_size)); + HIP_CHECK(hipMemcpy(d_input, h_data.data(), data_size, hipMemcpyHostToDevice)); + + auto t_it = rocprim::make_transform_iterator(d_input, transform{}); + auto r_it = rocprim::make_predicate_iterator(t_it, d_input, predicate{}); + + 
HIP_CHECK(rocprim::transform(r_it, d_output, size, identity{})); + + HIP_CHECK(hipMemcpy(h_data.data(), d_output, data_size, hipMemcpyDeviceToHost)); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + + for(T i = 0; i < T{size}; ++i) + { + if(predicate{}(i)) + { + ASSERT_EQ(h_data[i], transform{}(i)); + } + else + { + ASSERT_EQ(h_data[i], T{}); + } + } + std::cout << std::endl; +} diff --git a/test/rocprim/test_radix_key_codec.cpp b/test/rocprim/test_radix_key_codec.cpp new file mode 100644 index 000000000..58b4b4c54 --- /dev/null +++ b/test/rocprim/test_radix_key_codec.cpp @@ -0,0 +1,451 @@ +// MIT License +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +#include "../common_test_header.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +struct extract_digit_params +{ + unsigned int start; + unsigned int radix_bits; + unsigned int expected_result; +}; + +std::ostream& operator<<(std::ostream& os, const extract_digit_params& params) +{ + std::stringstream sstream; + sstream << "{ start: " << params.start << ", radix_bits: " << params.radix_bits + << ", expected_result: 0x" << std::hex << params.expected_result << " }"; + return os << sstream.str(); +} + +class RadixKeyCodecTest : public ::testing::TestWithParam +{}; + +INSTANTIATE_TEST_SUITE_P(RocprimBlockRadixSort, + RadixKeyCodecTest, + ::testing::Values(extract_digit_params{0, 8, 0x01}, + extract_digit_params{8, 16, 0xcdef}, + extract_digit_params{24, 8, 0xab}, + extract_digit_params{7, 11, 0b01'1110'1111'0}, + extract_digit_params{0, 1, 1}, + extract_digit_params{1, 1, 0}, + extract_digit_params{0, 32, 0xabcdef01}, + extract_digit_params{1, 31, 0xabcdef01 >> 1}, + extract_digit_params{8, 12, 0xdef}, + extract_digit_params{12, 12, 0xcde}, + extract_digit_params{12, 13, 0x1cde}, + extract_digit_params{12, 20, 0xabcde})); + +struct custom_key +{ + uint8_t a; + uint16_t b; + uint8_t c; +}; + +struct custom_key_decomposer +{ + auto operator()(custom_key& value) const + { + return ::rocprim::tuple{value.a, value.b, value.c}; + } +}; + +TEST_P(RadixKeyCodecTest, ExtractDigit) +{ + using codec = rocprim::detail::radix_key_codec; + + const custom_key key{0xab, 0xcdef, 0x01}; + const auto digit = codec::extract_digit(key, + GetParam().start, + GetParam().radix_bits, + custom_key_decomposer{}); + + ASSERT_EQ(digit, GetParam().expected_result); +} + +class RadixKeyCodecUnusedTest : public ::testing::TestWithParam +{}; + +INSTANTIATE_TEST_SUITE_P(RocprimBlockRadixSort, + RadixKeyCodecUnusedTest, + ::testing::Values(extract_digit_params{0, 16, 0xab01}, + 
extract_digit_params{0, 8, 0x01}, + extract_digit_params{8, 8, 0xab}, + extract_digit_params{1, 14, 0b010'1011'0000'000}, + extract_digit_params{14, 2, 0b10})); + +struct custom_key_decomposer_with_unused +{ + auto operator()(custom_key& value) const + { + return ::rocprim::tuple{value.a, value.c}; + } +}; + +TEST_P(RadixKeyCodecUnusedTest, ExtractDigitUnused) +{ + using codec = rocprim::detail::radix_key_codec; + + const custom_key key{0xab, 0xcdef, 0x01}; + const auto digit = codec::extract_digit(key, + GetParam().start, + GetParam().radix_bits, + custom_key_decomposer_with_unused{}); + + ASSERT_EQ(digit, GetParam().expected_result); +} + +TEST(RadixKeyCodecTest, ExtractCustomTestType) +{ + using T = test_utils::custom_test_type; + using codec_t = rocprim::detail::radix_key_codec; + + T value{12, 34}; + + test_utils::custom_test_type_decomposer decomposer; + codec_t::encode_inplace(value, decomposer); + + ASSERT_EQ(0x7FFFFFDD, codec_t::extract_digit(value, 0, 32, decomposer)); + ASSERT_EQ(0x7FFFFFF3, codec_t::extract_digit(value, 32, 32, decomposer)); +} + +template +struct RadixMergeCompareTest : public ::testing::Test +{ + using params = Params; +}; + +template +struct RadixMergeCompareTestParams +{ + static constexpr bool descending = Descending; +}; + +using RadixMergeCompareTestTypes + = ::testing::Types, RadixMergeCompareTestParams>; +TYPED_TEST_SUITE(RadixMergeCompareTest, RadixMergeCompareTestTypes); + +struct custom_large_key +{ + uint16_t a; + int64_t b; + uint8_t c; + double d; + + static constexpr size_t bits = 8 * (sizeof(a) + sizeof(b) + sizeof(c) + sizeof(d)); +}; + +struct custom_large_key_decomposer +{ + auto operator()(custom_large_key& value) const + { + return ::rocprim::tuple{value.a, + value.b, + value.c, + value.d}; + } +}; + +TYPED_TEST(RadixMergeCompareTest, FullRange) +{ + using params = typename TestFixture::params; + constexpr bool descending = params::descending; + using merge_compare = rocprim::detail:: + radix_merge_compare; + + 
const merge_compare comparator(0, custom_large_key::bits, custom_large_key_decomposer{}); + + { + const custom_large_key lhs{1, 2, 3, 4}; + const custom_large_key rhs{3, 2, 1, 11}; + EXPECT_TRUE(descending != comparator(lhs, rhs)); + } + { + const custom_large_key lhs{1, 3, 3, 4}; + const custom_large_key rhs{1, 2, 1, 11}; + EXPECT_FALSE(descending != comparator(lhs, rhs)); + } + { + const custom_large_key lhs{1, 2, 3, 4}; + const custom_large_key rhs{1, 2, 1, 11}; + EXPECT_FALSE(descending != comparator(lhs, rhs)); + } + { + const custom_large_key lhs{1, 2, 3, 4}; + const custom_large_key rhs{1, 2, 3, 11}; + EXPECT_TRUE(descending != comparator(lhs, rhs)); + } + { + const custom_large_key lhs{1, 2, 3, 11}; + const custom_large_key rhs{1, 2, 3, 11}; + EXPECT_FALSE(comparator(lhs, rhs)); + } +} + +TYPED_TEST(RadixMergeCompareTest, NotNullStartBit) +{ + using params = typename TestFixture::params; + constexpr bool descending = params::descending; + using merge_compare = rocprim::detail:: + radix_merge_compare; + + constexpr unsigned int start_bit = 64; + const merge_compare comparator(start_bit, + custom_large_key::bits - start_bit, + custom_large_key_decomposer{}); + + { + const custom_large_key lhs{3, 2, 3, 4}; + const custom_large_key rhs{3, 2, 1, 11}; + EXPECT_FALSE(descending != comparator(lhs, rhs)); + } + { + const custom_large_key lhs{3, 2, 1, 4}; + const custom_large_key rhs{3, 2, 3, 11}; + EXPECT_TRUE(descending != comparator(lhs, rhs)); + } + { + const custom_large_key lhs{3, 2, 1, 4}; + const custom_large_key rhs{3, 2, 1, 11}; + EXPECT_FALSE(comparator(lhs, rhs)); + } +} + +TYPED_TEST(RadixMergeCompareTest, MidRange) +{ + using params = typename TestFixture::params; + constexpr bool descending = params::descending; + using merge_compare = rocprim::detail:: + radix_merge_compare; + + constexpr unsigned int start_bit = 64; + constexpr unsigned int excluded_bits = 16; + const merge_compare comparator(start_bit, + custom_large_key::bits - start_bit - 
excluded_bits, + custom_large_key_decomposer{}); + + { + const custom_large_key lhs{3, 2, 3, 4}; + const custom_large_key rhs{4, 2, 1, 11}; + EXPECT_FALSE(descending != comparator(lhs, rhs)); + } + { + const custom_large_key lhs{3, 2, 1, 4}; + const custom_large_key rhs{4, 2, 3, 11}; + EXPECT_TRUE(descending != comparator(lhs, rhs)); + } + { + const custom_large_key lhs{3, 2, 3, 4}; + const custom_large_key rhs{4, 2, 3, 11}; + EXPECT_FALSE(comparator(lhs, rhs)); + } +} + +template +struct TypedRadixKeyCodecTest : public ::testing::Test +{ + using params = Params; +}; + +template +struct TypedRadixKeyCodecTestParams +{ + using Key = KeyType; + static constexpr unsigned int start_bit = StartBit; + static constexpr unsigned int radix_bits = RadixBits; +}; + +template +struct custom_test_type_decomposer +{ + auto operator()(test_utils::custom_test_type& value) const + { + return ::rocprim::tuple{value.x, value.y}; + } +}; + +using TypedRadixKeyCodecTestTypes + = ::testing::Types, + TypedRadixKeyCodecTestParams, + TypedRadixKeyCodecTestParams, + TypedRadixKeyCodecTestParams, + TypedRadixKeyCodecTestParams, + TypedRadixKeyCodecTestParams, + TypedRadixKeyCodecTestParams, + TypedRadixKeyCodecTestParams, + TypedRadixKeyCodecTestParams, + TypedRadixKeyCodecTestParams, + TypedRadixKeyCodecTestParams, + TypedRadixKeyCodecTestParams, + TypedRadixKeyCodecTestParams, + TypedRadixKeyCodecTestParams, + TypedRadixKeyCodecTestParams>; + +TYPED_TEST_SUITE(TypedRadixKeyCodecTest, TypedRadixKeyCodecTestTypes); + +template +void encode_then_decode_test(Key key, Decomposer decomposer) +{ + using codec_t = ::rocprim::radix_key_codec; + using BitKey = typename codec_t::bit_key_type; + + BitKey bit_key = codec_t::encode(key, decomposer); + codec_t::encode_inplace(key, decomposer); + + Key decoded_key = codec_t::decode(bit_key, decomposer); + codec_t::decode_inplace(key, decomposer); + + test_utils::assert_eq(decoded_key, key); +} + +template +void encode_then_decode_test(Key key, Decomposer 
decomposer = {}) +{ + encode_then_decode_test(key, decomposer); /*decreasing sort*/ + encode_then_decode_test(key, decomposer); /*increasing sort*/ +} + +template +void encode_then_extract_test(Key key, + const unsigned int start_bit, + const unsigned int radix_bits, + Decomposer decomposer) +{ + using codec_t = ::rocprim::radix_key_codec; + using BitKey = typename codec_t::bit_key_type; + + BitKey bit_key = codec_t::encode(key, decomposer); + codec_t::encode_inplace(key, decomposer); + + const unsigned int bits = codec_t::extract_digit(bit_key, start_bit, radix_bits); + const unsigned int inplace_bits + = codec_t::extract_digit(key, start_bit, radix_bits, decomposer); + + test_utils::assert_eq(bits, inplace_bits); +} + +template +void encode_then_extract_test(Key key, + const unsigned int start_bit, + const unsigned int radix_bits, + Decomposer decomposer = {}) +{ + encode_then_extract_test(key, start_bit, radix_bits, decomposer); /*decreasing sort*/ + encode_then_extract_test(key, start_bit, radix_bits, decomposer); /*increasing sort*/ +} + +template +void encode_then_extract_test_custom(Key key, + const unsigned int start_bit, + const unsigned int radix_bits, + Decomposer decomposer) +{ + using codec_t = ::rocprim::radix_key_codec; + using BitKey = typename codec_t::bit_key_type; + + BitKey bit_key = codec_t::encode(key, decomposer); + codec_t::encode_inplace(key, decomposer); + + const unsigned int bits = codec_t::extract_digit(bit_key, start_bit, radix_bits, decomposer); + const unsigned int inplace_bits + = codec_t::extract_digit(key, start_bit, radix_bits, decomposer); + + test_utils::assert_eq(bits, inplace_bits); +} + +template +void encode_then_extract_test_custom(Key key, + const unsigned int start_bit, + const unsigned int radix_bits, + Decomposer decomposer = {}) +{ + encode_then_extract_test_custom(key, + start_bit, + radix_bits, + decomposer); /*decreasing sort*/ + encode_then_extract_test_custom(key, + start_bit, + radix_bits, + decomposer); 
/*increasing sort*/ +} + +TYPED_TEST(TypedRadixKeyCodecTest, EncodeDecodeExtract) +{ + using params = typename TestFixture::params; + using Key = typename params::Key; + using CustomKey = typename test_utils::custom_test_type; + using CustomDecomposer = custom_test_type_decomposer; + constexpr unsigned int start_bit = params::start_bit; + constexpr unsigned int radix_bits = params::radix_bits; + + CustomDecomposer custom_decomposer{}; + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + + const size_t size = (1 << 20) + 123; + std::vector input_keys + = test_utils::get_random_data(size, + test_utils::numeric_limits::min(), + test_utils::numeric_limits::max(), + seed_value); + + for(size_t i = 0; i < size; ++i) + { + SCOPED_TRACE(testing::Message() << "with index = " << i); + + encode_then_decode_test(input_keys[i]); + + encode_then_extract_test(input_keys[i], start_bit, radix_bits); + + // With custom types + encode_then_decode_test(CustomKey(input_keys[i]), custom_decomposer); + + encode_then_extract_test_custom(CustomKey(input_keys[i]), + start_bit, + radix_bits, + custom_decomposer); + } + } +} diff --git a/test/rocprim/test_temporary_storage_partitioning.cpp b/test/rocprim/test_temporary_storage_partitioning.cpp index a9f558da1..92c07042a 100644 --- a/test/rocprim/test_temporary_storage_partitioning.cpp +++ b/test/rocprim/test_temporary_storage_partitioning.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -23,7 +23,6 @@ #include "common_test_header.hpp" #include -#include #include "test_utils_types.hpp" diff --git a/test/rocprim/test_utils.hpp b/test/rocprim/test_utils.hpp index f5010e8eb..ccd87ee6d 100644 --- a/test/rocprim/test_utils.hpp +++ b/test/rocprim/test_utils.hpp @@ -140,27 +140,21 @@ template struct select_plus_operator_host { typedef ::rocprim::plus type; - typedef T acc_type; - // #156 temporarily disable half test due to known issue with converting from double to half - // cast_type is the type that should be used to cast acc_type to T. - // overload needed temporarily due to compiler bug in half conversions - typedef T cast_type; + typedef T acc_type; }; template<> struct select_plus_operator_host<::rocprim::half> { typedef ::rocprim::plus type; - typedef double acc_type; - typedef float cast_type; + typedef double acc_type; }; template<> struct select_plus_operator_host<::rocprim::bfloat16> { typedef ::rocprim::plus type; - typedef double acc_type; - typedef ::rocprim::bfloat16 cast_type; + typedef double acc_type; }; template::value_type, + rocprim::half>::value, + bool>::type + = false> +void iota_modulo(ForwardIt first, ForwardIt last, T lbound, const size_t ubound) +{ + const T value_mod = static_cast(lbound) < ubound ? lbound : 0; + using value_type = typename std::iterator_traits::value_type; + + for(T value = value_mod; first != last; value++, *first++) + { + if(static_cast(value) >= ubound) + { + value = value_mod; + } + *first = static_cast(value); + } +} + +// Necessary because for rocprim::half even though lbound < ubound it gets cast as a greater +// value, as precision is bigger for values closer to the maximum. 
+template::value_type, + rocprim::half>::value, + bool>::type + = true> +void iota_modulo(ForwardIt first, ForwardIt last, T lbound, const size_t ubound) +{ + const T value_mod = static_cast(lbound) < ubound ? lbound : 0; + using value_type = rocprim::half; + + for(T value = value_mod; first != last; value++, *first++) + { + if(static_cast(static_cast(value)) >= ubound) + { + value = value_mod; + } + *first = static_cast(value); + } +} + #define SKIP_IF_UNSUPPORTED_WARP_SIZE(test_warp_size, device_id) \ { \ unsigned int host_warp_size; \ diff --git a/test/rocprim/test_utils_assertions.hpp b/test/rocprim/test_utils_assertions.hpp index 073b783c9..794f8fc11 100644 --- a/test/rocprim/test_utils_assertions.hpp +++ b/test/rocprim/test_utils_assertions.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2021-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -31,6 +31,10 @@ // Std::memcpy and std::memcmp #include +#include +#include +#include +#include #include namespace test_utils { @@ -60,6 +64,23 @@ void assert_eq(const std::vector& result, const std::vector& expected, con } } +template +void assert_eq(const std::vector>& result, + const std::vector>& expected, + const size_t max_length = SIZE_MAX) +{ + if(max_length == SIZE_MAX || max_length > expected.size()) + { + ASSERT_EQ(result.size(), expected.size()); + } + for(size_t i = 0; i < std::min(result.size(), max_length); i++) + { + if(bit_equal(result[i].x, expected[i].x) && bit_equal(result[i].y, expected[i].y)) + continue; // Check bitwise equality for +NaN, -NaN, +0.0, -0.0, +inf, -inf. 
+ ASSERT_EQ(result[i], expected[i]) << "where index = " << i; + } +} + template<> inline void assert_eq(const std::vector& result, const std::vector& expected, const size_t max_length) { @@ -89,6 +110,14 @@ void assert_eq(const T& result, const T& expected) ASSERT_EQ(result, expected); } +template +void assert_eq(const custom_test_type& result, const custom_test_type& expected) +{ + if(bit_equal(result.x, expected.x) && bit_equal(result.y, expected.y)) + return; // Check bitwise equality for +NaN, -NaN, +0.0, -0.0, +inf, -inf. + ASSERT_EQ(result, expected); +} + template<> inline void assert_eq(const rocprim::half& result, const rocprim::half& expected) { @@ -102,6 +131,22 @@ inline void assert_eq(const rocprim::bfloat16& result, const if(bit_equal(result, expected)) return; // Check bitwise equality for +NaN, -NaN, +0.0, -0.0, +inf, -inf. ASSERT_EQ(bfloat16_to_native(result), bfloat16_to_native(expected)); } + +template +void assert_eq(ResultIt result_begin, + ResultIt result_end, + ExpectedIt expected_begin, + ExpectedIt expected_end) +{ + ASSERT_EQ(std::distance(result_begin, result_end), std::distance(expected_begin, expected_end)); + auto result_it = result_begin; + auto expected_it = expected_begin; + for(; result_it != result_end; ++result_it, ++expected_it) + { + assert_eq(static_cast::value_type>(*result_it), + static_cast::value_type>(*expected_it)); + } +} // end assert_eq // begin assert_near @@ -230,96 +275,65 @@ auto assert_near(const custom_test_type& result, const custom_test_type& e // End assert_near -template -void assert_bit_eq(const std::vector& result, const std::vector& expected) -{ - ASSERT_EQ(result.size(), expected.size()); - for(size_t i = 0; i < result.size(); i++) - { - if(!bit_equal(result[i], expected[i])) - { - FAIL() << "Expected strict/bitwise equality of these values: " << std::endl - << " result[i]: " << result[i] << std::endl - << " expected[i]: " << expected[i] << std::endl - << "where index = " << i; - } - } -} #if 
ROCPRIM_HAS_INT128_SUPPORT -inline void assert_bit_eq(const std::vector<__int128_t>& result, - const std::vector<__int128_t>& expected) +template +auto operator<<(std::ostream& os, const T& value) + -> std::enable_if_t::value || std::is_same::value, + std::ostream&> { - ASSERT_EQ(result.size(), expected.size()); - - auto to_string = [](__int128_t value) - { - static const char* charmap = "0123456789"; - - std::string result; - result.reserve(41); // max. 40 digits possible ( uint64_t has 20) plus sign - __uint128_t helper = (value < 0) ? -value : value; + static const char* charmap = "0123456789"; - do - { - result += charmap[helper % 10]; - helper /= 10; - } - while(helper); - if(value < 0) - { - result += "-"; - } - std::reverse(result.begin(), result.end()); - return result; - }; + std::string result; + result.reserve(41); // max. 40 digits possible ( uint64_t has 20) plus sign + __uint128_t helper = (value < 0) ? -value : value; - for(size_t i = 0; i < result.size(); i++) + do { - if(!bit_equal(result[i], expected[i])) - { - FAIL() << "Expected strict/bitwise equality of these values: " << std::endl - << " result[i]: " << to_string(result[i]) << std::endl - << " expected[i]: " << to_string(expected[i]) << std::endl - << "where index = " << i; - } + result += charmap[helper % 10]; + helper /= 10; + } + while(helper); + if(value < 0) + { + result += "-"; } + std::reverse(result.begin(), result.end()); + + os << result; + return os; } -inline void assert_bit_eq(const std::vector<__uint128_t>& result, - const std::vector<__uint128_t>& expected) +#endif + +template +void assert_bit_eq(IterA result_begin, IterA result_end, IterB expected_begin, IterB expected_end) { - ASSERT_EQ(result.size(), expected.size()); + using value_a_t = typename std::iterator_traits::value_type; + using value_b_t = typename std::iterator_traits::value_type; - auto to_string = [](__uint128_t value) + ASSERT_EQ(std::distance(result_begin, result_end), std::distance(expected_begin, 
expected_end)); + auto result_it = result_begin; + auto expected_it = expected_begin; + for(size_t index = 0; result_it != result_end; ++result_it, ++expected_it, ++index) { - static const char* charmap = "0123456789"; + // The cast is needed, because the argument can be an std::vector iterator, which's operator* + // returns a proxy object that must be converted to bool + const auto result = static_cast(*result_it); + const auto expected = static_cast(*expected_it); - std::string result; - result.reserve(40); // max. 40 digits possible ( uint64_t has 20) - __uint128_t helper = value; - - do - { - result += charmap[helper % 10]; - helper /= 10; - } - while(helper); - std::reverse(result.begin(), result.end()); - return result; - }; - - for(size_t i = 0; i < result.size(); i++) - { - if(!bit_equal(result[i], expected[i])) + if(!bit_equal(result, expected)) { + std::stringstream result_str; + std::stringstream expected_str; + result_str << result; + expected_str << expected; FAIL() << "Expected strict/bitwise equality of these values: " << std::endl - << " result[i]: " << to_string(result[i]) << std::endl - << " expected[i]: " << to_string(expected[i]) << std::endl - << "where index = " << i; + << " result[i]: " << result_str.str() << std::endl + << " expected[i]: " << expected_str.str() << std::endl + << "where index = " << index; } } } -#endif } #endif //ROCPRIM_TEST_UTILS_ASSERTIONS_HPP diff --git a/test/rocprim/test_utils_custom_float_type.hpp b/test/rocprim/test_utils_custom_float_type.hpp index 46a856391..76601fd42 100644 --- a/test/rocprim/test_utils_custom_float_type.hpp +++ b/test/rocprim/test_utils_custom_float_type.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -24,7 +24,7 @@ #include "test_utils_custom_test_types.hpp" // For radix_key_codec -#include +#include #include #include diff --git a/test/rocprim/test_utils_data_generation.hpp b/test/rocprim/test_utils_data_generation.hpp index 315e92d23..5c33dbb1b 100644 --- a/test/rocprim/test_utils_data_generation.hpp +++ b/test/rocprim/test_utils_data_generation.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2021-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -140,20 +140,26 @@ inline auto convert_to_native(const T& value) // Helper class to generate a vector of special values for any type template -struct special_values { +struct special_values +{ private: // sign_bit_flip needed because host-side operators for __half are missing. (e.g. 
-__half unary operator or (-1*) __half*__half binary operator - static T sign_bit_flip(T value){ + static T sign_bit_flip(T value) + { uint8_t* data = reinterpret_cast(&value); - data[sizeof(T)-1] ^= 0x80; + data[sizeof(T) - 1] ^= 0x80; return value; } public: - static std::vector vector(){ - if(std::is_integral::value){ + static std::vector vector() + { + if(std::is_integral::value) + { return std::vector(); - }else { + } + else + { std::vector r = {test_utils::numeric_limits::quiet_NaN(), sign_bit_flip(test_utils::numeric_limits::quiet_NaN()), // TODO: switch on when signaling_NaN will be supported on NVIDIA @@ -169,26 +175,26 @@ struct special_values { }; // end of special_values helpers +template +using it_value_t = typename std::iterator_traits::value_type; + /// Insert special values of type T at a random place in the source vector /// \tparam T /// \param source The source vector to modify -template -void add_special_values(std::vector& source, seed_type seed_value) +template +void add_special_values(OutputIter it, const size_t size, Generator&& gen) { - engine_type gen{seed_value}; + using T = it_value_t; std::vector special_values = test_utils::special_values::vector(); - if(source.size() > special_values.size()) + if(size > special_values.size()) { - unsigned int start = gen() % (source.size() - special_values.size()); - std::copy(special_values.begin(), special_values.end(), source.begin() + start); + unsigned int start = gen() % (size - special_values.size()); + std::copy(special_values.begin(), special_values.end(), it + start); } } -template -using it_value_t = typename std::iterator_traits::value_type; - template -inline OutputIter segmented_generate_n(OutputIter it, size_t size, Generator gen) +inline OutputIter segmented_generate_n(OutputIter it, size_t size, Generator&& gen) { const size_t segment_size = size / random_data_generation_segments; if(segment_size == 0) @@ -215,7 +221,7 @@ inline OutputIter segmented_generate_n(OutputIter it, size_t 
size, Generator gen } template -inline auto generate_random_data_n(OutputIter it, size_t size, U min, V max, Generator& gen) +inline auto generate_random_data_n(OutputIter it, size_t size, U min, V max, Generator&& gen) -> std::enable_if_t, __int128_t>::value, OutputIter> { using T = it_value_t; @@ -231,7 +237,7 @@ inline auto generate_random_data_n(OutputIter it, size_t size, U min, V max, Gen } template -inline auto generate_random_data_n(OutputIter it, size_t size, U min, V max, Generator& gen) +inline auto generate_random_data_n(OutputIter it, size_t size, U min, V max, Generator&& gen) -> std::enable_if_t, __uint128_t>::value, OutputIter> { using T = it_value_t; @@ -247,7 +253,7 @@ inline auto generate_random_data_n(OutputIter it, size_t size, U min, V max, Gen } template -inline auto generate_random_data_n(OutputIter it, size_t size, U min, V max, Generator& gen) +inline auto generate_random_data_n(OutputIter it, size_t size, U min, V max, Generator&& gen) -> std::enable_if_t>::value, OutputIter> { using T = it_value_t; @@ -266,7 +272,7 @@ inline auto generate_random_data_n(OutputIter it, size_t size, U min, V max, Gen } template -inline auto generate_random_data_n(OutputIter it, size_t size, U min, V max, Generator& gen) +inline auto generate_random_data_n(OutputIter it, size_t size, U min, V max, Generator&& gen) -> std::enable_if_t>::value && !is_custom_test_type>::value, OutputIter> @@ -286,18 +292,29 @@ inline auto generate_random_data_n(OutputIter it, size_t size, it_value_t min, it_value_t max, - Generator& gen) - -> std::enable_if_t>::value - && std::is_integral::value_type>::value, - OutputIter> + Generator&& gen) + -> std::enable_if_t< + is_custom_test_type>::value + && rocprim::is_integral::value_type>::value, + OutputIter> { using T = it_value_t; + using value_t = typename T::value_type; - std::uniform_int_distribution distribution(min.x, max.x); + using distribution_t + = std::conditional_t::value, + value_t, + std::conditional_t::value, int, 
unsigned int>>; + + std::uniform_int_distribution distribution(static_cast(min.x), + static_cast(max.x)); return segmented_generate_n(it, size, - [&]() { return T(distribution(gen), distribution(gen)); }); + [&]() { + return T(static_cast(distribution(gen)), + static_cast(distribution(gen))); + }); } template @@ -305,10 +322,10 @@ inline auto generate_random_data_n(OutputIter it, size_t size, it_value_t min, it_value_t max, - Generator& gen) + Generator&& gen) -> std::enable_if_t< is_custom_test_type>::value - && std::is_floating_point::value_type>::value, + && rocprim::is_floating_point::value_type>::value, OutputIter> { using T = typename std::iterator_traits::value_type; @@ -325,7 +342,7 @@ inline auto generate_random_data_n(OutputIter i size_t size, typename it_value_t::value_type min, typename it_value_t::value_type max, - Generator& gen) + Generator&& gen) -> std::enable_if_t>::value && std::is_integral::value_type>::value, OutputIter> diff --git a/test/rocprim/test_utils_hipgraphs.hpp b/test/rocprim/test_utils_hipgraphs.hpp index a108256ba..9d20b0ed7 100644 --- a/test/rocprim/test_utils_hipgraphs.hpp +++ b/test/rocprim/test_utils_hipgraphs.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2021-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -21,6 +21,10 @@ #ifndef ROCPRIM_TEST_UTILS_HIPGRAPHS_HPP #define ROCPRIM_TEST_UTILS_HIPGRAPHS_HPP +#include "common_test_header.hpp" + +#include + // Helper functions for testing with hipGraph stream capture. // Note: graphs will not work on the default stream. 
namespace test_utils diff --git a/test/rocprim/test_utils_sort_comparator.hpp b/test/rocprim/test_utils_sort_comparator.hpp index 3f9358fe8..bd4d6db00 100644 --- a/test/rocprim/test_utils_sort_comparator.hpp +++ b/test/rocprim/test_utils_sort_comparator.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -25,124 +25,155 @@ #include -#include "test_utils_half.hpp" #include "test_utils_bfloat16.hpp" +#include "test_utils_custom_float_type.hpp" +#include "test_utils_custom_test_types.hpp" +#include "test_utils_half.hpp" + +#include +#include namespace test_utils { +namespace detail +{ -template -constexpr bool is_floating_nan_host(const T& a) +template::value && !std::is_same::value) + || std::is_same::value + || std::is_same::value, + int> + = 0> +Key to_bits(const Key key) { - return (a != a); + static constexpr Key radix_mask_upper + = EndBit == 8 * sizeof(Key) ? ~Key(0) : static_cast((Key(1) << EndBit) - 1); + static constexpr Key radix_mask_bottom = static_cast((Key(1) << StartBit) - 1); + static constexpr Key radix_mask = radix_mask_upper ^ radix_mask_bottom; + + return key & radix_mask; } -template -struct key_comparator -{}; + class Key, + std::enable_if_t::value, int> = 0> +Key to_bits(const Key key) +{ + using unsigned_bits_type = typename rocprim::get_unsigned_bits_type::unsigned_type; + unsigned_bits_type bit_key; + std::memcpy(&bit_key, &key, sizeof(bit_key)); + return to_bits(bit_key); +} -template -struct key_comparator::value>::type> +template::value + // custom_float_type is used in testing a hacky way of + // radix sorting custom types. A part of this workaround + // is to specialize rocprim::is_floating_point + // that we must counter here. 
+ && !std::is_same::value, + int> + = 0> +auto to_bits(const Key key) { - static constexpr Key radix_mask_upper - = EndBit == 8 * sizeof(Key) ? ~Key(0) : (Key(1) << EndBit) - 1; - static constexpr Key radix_mask_bottom = (Key(1) << StartBit) - 1; - static constexpr Key radix_mask = radix_mask_upper ^ radix_mask_bottom; + using unsigned_bits_type = typename rocprim::get_unsigned_bits_type::unsigned_type; + + unsigned_bits_type bit_key; + memcpy(&bit_key, &key, sizeof(Key)); - bool operator()(const Key& lhs, const Key& rhs) const + // Remove signed zero, this case is supposed to be treated the same as + // unsigned zero in rocprim sorting algorithms. + constexpr unsigned_bits_type minus_zero = unsigned_bits_type{1} << (8 * sizeof(Key) - 1); + // Positive and negative zero should compare the same. + if(bit_key == minus_zero) { - Key l = lhs & radix_mask; - Key r = rhs & radix_mask; - return Descending ? (r < l) : (l < r); + bit_key = 0; } -}; - -template -struct key_comparator::value>::type> -{ - static constexpr Key radix_mask_upper - = EndBit == 8 * sizeof(Key) ? ~Key(0) : (Key(1) << EndBit) - 1; - static constexpr Key radix_mask_bottom = (Key(1) << StartBit) - 1; - static constexpr Key radix_mask = radix_mask_upper ^ radix_mask_bottom; - - bool operator()(const Key& lhs, const Key& rhs) const + // Flip bits mantissa and exponent if the key is negative, so as to make + // 'more negative' values compare before 'less negative'. + if(bit_key & minus_zero) { - Key l = lhs & radix_mask; - Key r = rhs & radix_mask; - return Descending ? (r < l) : (l < r); + bit_key ^= ~minus_zero; } -}; + // Make negatives compare before positives. + bit_key ^= minus_zero; -template -struct key_comparator::value>::type> -{ - static constexpr Key radix_mask_upper - = EndBit == 8 * sizeof(Key) ? 
~Key(0) : (Key(1) << EndBit) - 1; - static constexpr Key radix_mask_bottom = (Key(1) << StartBit) - 1; - static constexpr Key radix_mask = radix_mask_upper ^ radix_mask_bottom; + return to_bits(bit_key); +} - bool operator()(const Key& lhs, const Key& rhs) const +template::value + // custom_float_type is used in testing a hacky way of + // radix sorting custom types. A part of this workaround + // is to specialize rocprim::is_custom_test_type + // that we must counter here. + && !std::is_same::value, + int> + = 0> +auto to_bits(const Key& key) +{ + using inner_t = typename inner_type::type; + using unsigned_bits_type = typename ::rocprim::get_unsigned_bits_type::unsigned_type; + // For two doubles, we need uint128, but that is not part of rocprim::get_unsigned_bits_type + using result_bits_type = std::conditional_t< + sizeof(inner_t) == 8, + __uint128_t, + typename rocprim::get_unsigned_bits_type(8), + sizeof(inner_t) * 2)>::unsigned_type>; + + auto bit_key_upper = static_cast(to_bits<0, sizeof(key.x) * 8>(key.x)); + auto bit_key_lower = static_cast(to_bits<0, sizeof(key.y) * 8>(key.y)); + + // Flip sign bit to properly order signed types + if(::rocprim::is_signed::value) { - Key l = lhs & radix_mask; - Key r = rhs & radix_mask; - return Descending ? 
(r < l) : (l < r); + constexpr auto sign_bit = static_cast(1) << (sizeof(inner_t) * 8 - 1); + bit_key_upper ^= sign_bit; + bit_key_lower ^= sign_bit; } -}; -template -struct key_comparator::value>::type> + // Create the result containing both parts + const auto bit_key + = (static_cast(bit_key_upper) << (8 * sizeof(unsigned_bits_type))) + | bit_key_lower; + + // The last call to to_bits mask the result to the specified bit range + return to_bits(bit_key); +} + +template::value, int> = 0> +auto to_bits(const Key key) { - using unsigned_bits_type = typename rocprim::get_unsigned_bits_type::unsigned_type; + return to_bits(key.x); +} - bool operator()(const Key& lhs, const Key& rhs) const - { - return key_comparator()( - this->to_bits(lhs), - this->to_bits(rhs)); - } +} // namespace detail - unsigned_bits_type to_bits(const Key& key) const +template +constexpr bool is_floating_nan_host(const T& a) +{ + return (a != a); +} + +template +struct key_comparator +{ + bool operator()(const Key lhs, const Key rhs) const { - unsigned_bits_type bit_key; - memcpy(&bit_key, &key, sizeof(Key)); - - // Remove signed zero, this case is supposed to be treated the same as - // unsigned zero in rocprim sorting algorithms. - constexpr unsigned_bits_type minus_zero = unsigned_bits_type{1} << (8 * sizeof(Key) - 1); - // Positive and negative zero should compare the same. - if(bit_key == minus_zero) - { - bit_key = 0; - } - // Flip bits mantissa and exponent if the key is negative, so as to make - // 'more negative' values compare before 'less negative'. - if(bit_key & minus_zero) - { - bit_key ^= ~minus_zero; - } - // Make negatives compare before positives. - bit_key ^= minus_zero; - return bit_key; + const auto l = detail::to_bits(lhs); + const auto r = detail::to_bits(rhs); + return Descending ? 
(r < l) : (l < r); } }; @@ -155,25 +186,34 @@ struct key_value_comparator } }; -template -struct key_comparator +template +struct custom_test_type_decomposer { - bool operator()(const rocprim::half& lhs, const rocprim::half& rhs) + static_assert(is_custom_test_type::value, + "custom_test_type_decomposer can only be used with custom_test_type"); + using inner_t = typename inner_type::type; + + __host__ __device__ auto operator()(CustomTestType& key) const { - // HIP's half doesn't have __host__ comparison operators, use floats instead - return key_comparator()(lhs, rhs); + return ::rocprim::tuple{key.x, key.y}; } }; -template -struct key_comparator +template +struct select_decomposer { - bool operator()(const rocprim::bfloat16& lhs, const rocprim::bfloat16& rhs) - { - // HIP's bfloat16 doesn't have __host__ comparison operators, use floats instead - return key_comparator()(lhs, rhs); - } + using type = ::rocprim::identity_decomposer; }; -} +template +struct select_decomposer> +{ + using type = custom_test_type_decomposer>; +}; + +template +using select_decomposer_t = typename select_decomposer::type; + +} // namespace test_utils + #endif // TEST_UTILS_SORT_COMPARATOR_HPP_ diff --git a/test/rocprim/test_utils_types.hpp b/test/rocprim/test_utils_types.hpp index fb901adc9..3a04ed224 100644 --- a/test/rocprim/test_utils_types.hpp +++ b/test/rocprim/test_utils_types.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2019-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2019-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -21,6 +21,7 @@ #ifndef TEST_TEST_UTILS_TYPES_HPP_ #define TEST_TEST_UTILS_TYPES_HPP_ +#include "rocprim/types.hpp" #include "test_utils.hpp" // required rocprim headers @@ -112,11 +113,11 @@ typedef ::testing::Types< warp_param_type(uint8_t) > WarpParamsIntegral; -typedef ::testing::Types< - warp_param_type(float), - warp_param_type(rocprim::half), - warp_param_type(rocprim::bfloat16) -> WarpParamsFloating; +typedef ::testing::Types + WarpParamsFloating; // Separate sort params (only power of two warp sizes) #define warp_sort_param_type(type, items_per_thread) \ @@ -134,10 +135,11 @@ typedef ::testing::Types< warp_sort_param_type(int8_t, 1) > WarpSortParamsIntegral; -typedef ::testing::Types< - warp_sort_param_type(rocprim::half, 1), - warp_sort_param_type(rocprim::bfloat16, 1) -> WarpSortParamsFloating; +typedef ::testing::Types + WarpSortParamsFloating; typedef ::testing::Types< warp_sort_param_type(int, 2), @@ -158,7 +160,8 @@ typedef ::testing::Types< typedef ::testing::Types), block_param_type(uint8_t, short), - block_param_type(int8_t, float) + block_param_type(int8_t, float), + block_param_type(bool, rocprim::half) #if ROCPRIM_HAS_INT128_SUPPORT , block_param_type(__uint128_t, short), @@ -167,12 +170,13 @@ typedef ::testing::Types > BlockParamsIntegralExtended; -typedef ::testing::Types< - block_param_type(float, long), - block_param_type(double, test_utils::custom_test_type), - block_param_type(rocprim::half, rocprim::half), - block_param_type(rocprim::bfloat16, rocprim::bfloat16) -> BlockParamsFloating; +typedef ::testing::Types), + block_param_type(rocprim::half, int), + block_param_type(rocprim::half, rocprim::half), + block_param_type(rocprim::bfloat16, int), + block_param_type(rocprim::bfloat16, rocprim::bfloat16)> + BlockParamsFloating; typedef ::testing::Types< 
block_param_type(test_utils::custom_test_type, int), @@ -180,44 +184,56 @@ typedef ::testing::Types< block_param_type(int8_t, bool) > BlockDiscParamsIntegral; -typedef ::testing::Types< - block_param_type(float, char), - block_param_type(double, unsigned int), - block_param_type(rocprim::half, int), - block_param_type(rocprim::bfloat16, int) -> BlockDiscParamsFloating; - -typedef ::testing::Types< - block_param_type(unsigned int, unsigned int) -> BlockHistAtomicParamsIntegral; - -typedef ::testing::Types< - block_param_type(float, long), - block_param_type(double, test_utils::custom_test_type), - block_param_type(rocprim::half, rocprim::half) -> BlockExchParamsFloating; - -typedef ::testing::Types< - block_param_type(float, float), - block_param_type(float, unsigned int), - block_param_type(float, unsigned long long), - block_param_type(double, float), - block_param_type(double, unsigned long long) -> BlockHistAtomicParamsFloating; - -typedef ::testing::Types< - block_param_type(int, uint8_t), - block_param_type(uint8_t, uint8_t), - block_param_type(short, uint8_t), - block_param_type(int, int8_t) -> BlockHistSortParamsIntegral; - -typedef ::testing::Types< - block_param_type(rocprim::half, unsigned short), - block_param_type(rocprim::half, unsigned int), - block_param_type(rocprim::bfloat16, unsigned short), - block_param_type(rocprim::bfloat16, unsigned int) -> BlockHistSortParamsFloating; +typedef ::testing::Types + BlockDiscParamsFloating; + +typedef ::testing::Types + BlockDiscParamsFloatingHalf; + +typedef ::testing::Types + BlockHistAtomicParamsIntegral; + +typedef ::testing::Types), + block_param_type(double, int8_t), + block_param_type(rocprim::half, rocprim::half), + block_param_type(rocprim::half, int16_t), + block_param_type(rocprim::bfloat16, rocprim::bfloat16)> + BlockExchParamsFloating; + +typedef ::testing::Types + BlockHistAtomicParamsFloating; + +typedef ::testing::Types + BlockHistSortParamsIntegral; + +typedef ::testing::Types + 
BlockHistSortParamsFloating; static constexpr size_t n_items = 7; static constexpr unsigned int items[n_items] = { diff --git a/test/rocprim/test_warp_exchange.cpp b/test/rocprim/test_warp_exchange.cpp index 0906b3681..636a84788 100644 --- a/test/rocprim/test_warp_exchange.cpp +++ b/test/rocprim/test_warp_exchange.cpp @@ -21,6 +21,7 @@ // SOFTWARE. #include "../common_test_header.hpp" +#include "rocprim/types.hpp" #include "test_utils.hpp" #include @@ -49,15 +50,21 @@ class WarpExchangeTest : public ::testing::Test struct BlockedToStripedOp { - template< - class T, - class warp_exchange_type, - unsigned int ItemsPerThread - > - ROCPRIM_DEVICE ROCPRIM_INLINE - void operator()(warp_exchange_type warp_exchange, - T (&thread_data)[ItemsPerThread], - typename warp_exchange_type::storage_type& storage) const + template + ROCPRIM_DEVICE ROCPRIM_INLINE void + operator()(warp_exchange_type warp_exchange, + T (&input_data)[ItemsPerThread], + T (&output_data)[ItemsPerThread], + typename warp_exchange_type::storage_type& storage) const + { + warp_exchange.blocked_to_striped(input_data, output_data, storage); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE void + operator()(warp_exchange_type warp_exchange, + T (&thread_data)[ItemsPerThread], + typename warp_exchange_type::storage_type& storage) const { warp_exchange.blocked_to_striped(thread_data, thread_data, storage); } @@ -65,15 +72,21 @@ struct BlockedToStripedOp struct BlockedToStripedShuffleOp { - template< - class T, - class warp_exchange_type, - unsigned int ItemsPerThread - > - ROCPRIM_DEVICE ROCPRIM_INLINE - void operator()(warp_exchange_type warp_exchange, - T (&thread_data)[ItemsPerThread], - typename warp_exchange_type::storage_type& /*storage*/) const + template + ROCPRIM_DEVICE ROCPRIM_INLINE void + operator()(warp_exchange_type warp_exchange, + T (&input_data)[ItemsPerThread], + T (&output_data)[ItemsPerThread], + typename warp_exchange_type::storage_type& /*storage*/) const + { + 
warp_exchange.blocked_to_striped_shuffle(input_data, output_data); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE void + operator()(warp_exchange_type warp_exchange, + T (&thread_data)[ItemsPerThread], + typename warp_exchange_type::storage_type& /*storage*/) const { warp_exchange.blocked_to_striped_shuffle(thread_data, thread_data); } @@ -81,15 +94,20 @@ struct BlockedToStripedShuffleOp struct StripedToBlockedOp { - template< - class T, - class warp_exchange_type, - unsigned int ItemsPerThread - > - ROCPRIM_DEVICE ROCPRIM_INLINE - void operator()(warp_exchange_type warp_exchange, - T (&thread_data)[ItemsPerThread], - typename warp_exchange_type::storage_type& storage) const + template + ROCPRIM_DEVICE ROCPRIM_INLINE void + operator()(warp_exchange_type warp_exchange, + T (&input_data)[ItemsPerThread], + T (&output_data)[ItemsPerThread], + typename warp_exchange_type::storage_type& storage) const + { + warp_exchange.striped_to_blocked(input_data, output_data, storage); + } + template + ROCPRIM_DEVICE ROCPRIM_INLINE void + operator()(warp_exchange_type warp_exchange, + T (&thread_data)[ItemsPerThread], + typename warp_exchange_type::storage_type& storage) const { warp_exchange.striped_to_blocked(thread_data, thread_data, storage); } @@ -97,15 +115,20 @@ struct StripedToBlockedOp struct StripedToBlockedShuffleOp { - template< - class T, - class warp_exchange_type, - unsigned int ItemsPerThread - > - ROCPRIM_DEVICE ROCPRIM_INLINE - void operator()(warp_exchange_type warp_exchange, - T (&thread_data)[ItemsPerThread], - typename warp_exchange_type::storage_type& /*storage*/) const + template + ROCPRIM_DEVICE ROCPRIM_INLINE void + operator()(warp_exchange_type warp_exchange, + T (&input_data)[ItemsPerThread], + T (&output_data)[ItemsPerThread], + typename warp_exchange_type::storage_type& /*storage*/) const + { + warp_exchange.striped_to_blocked_shuffle(input_data, output_data); + } + template + ROCPRIM_DEVICE ROCPRIM_INLINE void + operator()(warp_exchange_type 
warp_exchange, + T (&thread_data)[ItemsPerThread], + typename warp_exchange_type::storage_type& /*storage*/) const { warp_exchange.striped_to_blocked_shuffle(thread_data, thread_data); } @@ -129,32 +152,62 @@ struct ScatterToStripedOp } }; -using WarpExchangeTestParams - = ::testing::Types, - Params, - Params, - Params, - Params, - Params, - - Params, - Params, - Params, - Params, - Params, - - Params, - Params, - Params, - Params, - Params, - Params, - - Params, - Params, - Params, - Params, - Params>; +using WarpExchangeTestParams = ::testing::Types< + Params, + Params, + Params, + Params, + Params, + Params, + // half should be supported, but is missing some key operators. + // we should uncomment these, as soon as these are implemented and the tests compile and work as intended. + //Params, + Params, + + Params, + Params, + Params, + Params, + Params, + Params, + //Params, + Params, + + Params, + Params, + Params, + Params, + Params, + Params, + //Params, + Params, + + Params, + Params, + Params, + Params, + Params, + Params, + //Params, + Params, + + Params, + Params, + Params, + Params, + Params, + Params, + //Params, + Params, + + Params, + Params, + Params, + Params, + Params, + Params, + //Params, + Params>; template __device__ auto warp_exchange_test(T* d_input, T* d_output) @@ -171,7 +224,7 @@ __device__ auto warp_exchange_test(T* d_input, T* d_output) } const unsigned int warp_id = threadIdx.x / LogicalWarpSize; - Op{}(warp_exchange_type(), thread_data, storage[warp_id]); + Op{}(warp_exchange_type(), thread_data, thread_data, storage[warp_id]); for(unsigned int i = 0; i < ItemsPerThread; i++) { @@ -185,9 +238,46 @@ __device__ auto warp_exchange_test(T* /*d_input*/, T* /*d_output*/) {} template -__global__ void warp_exchange_kernel(T* d_input, T* d_output) +__device__ auto warp_exchange_test_not_inplace(T* d_input, T* d_output) + -> std::enable_if_t> +{ + using warp_exchange_type = ::rocprim::warp_exchange; + constexpr unsigned int num_warps = 
::rocprim::device_warp_size() / LogicalWarpSize; + ROCPRIM_SHARED_MEMORY typename warp_exchange_type::storage_type storage[num_warps]; + + T thread_data[ItemsPerThread]; + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + thread_data[i] = d_input[threadIdx.x * ItemsPerThread + i]; + } + + T output[ItemsPerThread]; + + const unsigned int warp_id = threadIdx.x / LogicalWarpSize; + Op{}(warp_exchange_type(), thread_data, output, storage[warp_id]); + + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + d_output[threadIdx.x * ItemsPerThread + i] = output[i]; + } +} + +template +__device__ auto warp_exchange_test_not_inplace(T* /*d_input*/, T* /*d_output*/) + -> std::enable_if_t> +{} + +template +__global__ void warp_exchange_kernel(T* d_input, T* d_output, bool inplace = true) { - warp_exchange_test(d_input, d_output); + if(inplace) + { + warp_exchange_test(d_input, d_output); + } + else + { + warp_exchange_test_not_inplace(d_input, d_output); + } } template @@ -252,6 +342,59 @@ TYPED_TEST(WarpExchangeTest, WarpExchange) HIP_CHECK(hipFree(d_input)); HIP_CHECK(hipFree(d_output)); + if(std::is_same::value + || std::is_same::value) + { + expected = stripe_vector(expected, warp_size, items_per_thread); + } + + ASSERT_EQ(expected, output); +} + +TYPED_TEST_SUITE(WarpExchangeTest, WarpExchangeTestParams); + +TYPED_TEST(WarpExchangeTest, WarpExchangeNotInplace) +{ + using T = typename TestFixture::params::type; + constexpr unsigned int warp_size = TestFixture::params::warp_size; + constexpr unsigned int items_per_thread = TestFixture::params::items_per_thread; + using exchange_op = typename TestFixture::params::exchange_op; + + const int device_id = test_common_utils::obtain_device_from_ctest(); + SKIP_IF_UNSUPPORTED_WARP_SIZE(warp_size, device_id); + + unsigned int hw_warp_size; + HIP_CHECK(::rocprim::host_warp_size(device_id, hw_warp_size)); + const unsigned int block_size = hw_warp_size; + const unsigned int items_count = items_per_thread * block_size; + + 
std::vector input(items_count); + std::iota(input.begin(), input.end(), static_cast(0)); + auto expected = input; + if(std::is_same::value + || std::is_same::value) + { + input = stripe_vector(input, warp_size, items_per_thread); + } + + T* d_input{}; + HIP_CHECK(hipMalloc(&d_input, items_count * sizeof(T))); + HIP_CHECK(hipMemcpy(d_input, input.data(), items_count * sizeof(T), hipMemcpyHostToDevice)); + T* d_output{}; + HIP_CHECK(hipMalloc(&d_output, items_count * sizeof(T))); + HIP_CHECK(hipMemset(d_output, 0, items_count * sizeof(T))); + + warp_exchange_kernel + <<>>(d_input, d_output, false); + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipDeviceSynchronize()); + + std::vector output(items_count); + HIP_CHECK(hipMemcpy(output.data(), d_output, items_count * sizeof(T), hipMemcpyDeviceToHost)); + + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + if(std::is_same::value || std::is_same::value) { diff --git a/test/rocprim/test_warp_load.cpp b/test/rocprim/test_warp_load.cpp index c00e0501d..2adc6ffa9 100644 --- a/test/rocprim/test_warp_load.cpp +++ b/test/rocprim/test_warp_load.cpp @@ -21,6 +21,7 @@ // SOFTWARE. #include "../common_test_header.hpp" +#include "rocprim/types.hpp" #include "test_utils.hpp" #include @@ -48,11 +49,6 @@ class WarpLoadTest : public ::testing::Test }; using WarpLoadTestParams = ::testing::Types< - Params, - Params, - Params, - Params, - Params, Params, Params, @@ -76,8 +72,39 @@ using WarpLoadTestParams = ::testing::Types< Params, Params, Params, - Params ->; + Params, + + Params, + Params, + Params, + Params, + + Params, + Params, + Params, + Params, + + Params, + Params, + Params, + Params, + + Params, + Params, + Params, + Params, + + // half should be supported, but is missing some key operators. + // we should uncomment these, as soon as these are implemented and the tests compile and work as intended. 
+ //Params, + //Params, + //Params, + //Params, + + Params, + Params, + Params, + Params>; template::type; binary_op_type_host binary_op_host; using acc_type = typename test_utils::select_plus_operator_host::acc_type; - using cast_type = typename test_utils::select_plus_operator_host::cast_type; static constexpr size_t logical_warp_size = TestFixture::params::warp_size; @@ -91,7 +90,7 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, ReduceSum) auto idx = i * logical_warp_size + j; value = binary_op_host(input[idx], value); } - expected[i] = static_cast(value); + expected[i] = static_cast(value); } T* device_input; @@ -156,7 +155,6 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, AllReduceSum) using binary_op_type_host = typename test_utils::select_plus_operator_host::type; binary_op_type_host binary_op_host; using acc_type = typename test_utils::select_plus_operator_host::acc_type; - using cast_type = typename test_utils::select_plus_operator_host::cast_type; static constexpr size_t logical_warp_size = TestFixture::params::warp_size; @@ -214,7 +212,7 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, AllReduceSum) for (size_t j = 0; j < logical_warp_size; j++) { auto idx = i * logical_warp_size + j; - expected[idx] = static_cast(value); + expected[idx] = static_cast(value); } } @@ -280,7 +278,6 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, ReduceSumValid) using binary_op_type_host = typename test_utils::select_plus_operator_host::type; binary_op_type_host binary_op_host; using acc_type = typename test_utils::select_plus_operator_host::acc_type; - using cast_type = typename test_utils::select_plus_operator_host::cast_type; static constexpr size_t logical_warp_size = TestFixture::params::warp_size; @@ -336,7 +333,7 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, ReduceSumValid) auto idx = i * logical_warp_size + j; value = binary_op_host(input[idx], value); } - expected[i] = static_cast(value); + expected[i] = static_cast(value); } T* 
device_input; @@ -402,7 +399,6 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, AllReduceSumValid) using binary_op_type_host = typename test_utils::select_plus_operator_host::type; binary_op_type_host binary_op_host; using acc_type = typename test_utils::select_plus_operator_host::acc_type; - using cast_type = typename test_utils::select_plus_operator_host::cast_type; static constexpr size_t logical_warp_size = TestFixture::params::warp_size; @@ -461,7 +457,7 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, AllReduceSumValid) for (size_t j = 0; j < logical_warp_size; j++) { auto idx = i * logical_warp_size + j; - expected[idx] = static_cast(value); + expected[idx] = static_cast(value); } } @@ -524,8 +520,7 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, ReduceCustomStruct) using base_type = typename TestFixture::params::type; using T = test_utils::custom_test_type; - using acc_type = typename test_utils::select_plus_operator_host::acc_type; - using cast_type = typename test_utils::select_plus_operator_host::cast_type; + using acc_type = typename test_utils::select_plus_operator_host::acc_type; // logical warp side for warp primitive, execution warp size is always rocprim::warp_size() static constexpr size_t logical_warp_size = TestFixture::params::warp_size; @@ -590,7 +585,7 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, ReduceCustomStruct) auto idx = i * logical_warp_size + j; value = value + static_cast>(input[idx]); } - expected[i] = static_cast>(value); + expected[i] = static_cast(value); } T* device_input; @@ -658,7 +653,6 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, HeadSegmentedReduceSum) using binary_op_type_host = typename test_utils::select_plus_operator_host::type; binary_op_type_host binary_op_host; using acc_type = typename test_utils::select_plus_operator_host::acc_type; - using cast_type = typename test_utils::select_plus_operator_host::cast_type; using flag_type = unsigned char; static constexpr size_t 
logical_warp_size = TestFixture::params::warp_size; @@ -739,7 +733,7 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, HeadSegmentedReduceSum) { if(i%logical_warp_size == 0 || flags[i]) { - expected[segment_head_index] = static_cast(reduction); + expected[segment_head_index] = static_cast(reduction); segment_head_index = i; reduction = input[i]; } @@ -748,7 +742,7 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, HeadSegmentedReduceSum) reduction = binary_op_host(input[i], reduction); } } - expected[segment_head_index] = static_cast(reduction); + expected[segment_head_index] = static_cast(reduction); // Launching kernel if (current_device_warp_size == ws32) @@ -814,7 +808,6 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, TailSegmentedReduceSum) using binary_op_type_host = typename test_utils::select_plus_operator_host::type; binary_op_type_host binary_op_host; using acc_type = typename test_utils::select_plus_operator_host::acc_type; - using cast_type = typename test_utils::select_plus_operator_host::cast_type; using flag_type = unsigned char; static constexpr size_t logical_warp_size = TestFixture::params::warp_size; @@ -912,8 +905,7 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, TailSegmentedReduceSum) next++; } i++; - expected[segment_index] - = static_cast(binary_op_host(reduction, input[i])); + expected[segment_index] = static_cast(binary_op_host(reduction, input[i])); segment_indexes.push_back(segment_index); } } diff --git a/test/rocprim/test_warp_scan.hpp b/test/rocprim/test_warp_scan.hpp index 99d82f33c..69d45506a 100644 --- a/test/rocprim/test_warp_scan.hpp +++ b/test/rocprim/test_warp_scan.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -35,7 +35,6 @@ typed_test_def(RocprimWarpScanTests, name_suffix, InclusiveScan) using binary_op_type_host = typename test_utils::select_plus_operator_host::type; binary_op_type_host binary_op_host; using acc_type = typename test_utils::select_plus_operator_host::acc_type; - using cast_type = typename test_utils::select_plus_operator_host::cast_type; // logical warp side for warp primitive, execution warp size is always rocprim::warp_size() static constexpr size_t logical_warp_size = TestFixture::params::warp_size; @@ -90,7 +89,7 @@ typed_test_def(RocprimWarpScanTests, name_suffix, InclusiveScan) { auto idx = i * logical_warp_size + j; accumulator = binary_op_host(input[idx], accumulator); - expected[idx] = static_cast(accumulator); + expected[idx] = static_cast(accumulator); } } @@ -158,7 +157,6 @@ typed_test_def(RocprimWarpScanTests, name_suffix, InclusiveScanReduce) using binary_op_type_host = typename test_utils::select_plus_operator_host::type; binary_op_type_host binary_op_host; using acc_type = typename test_utils::select_plus_operator_host::acc_type; - using cast_type = typename test_utils::select_plus_operator_host::cast_type; // logical warp side for warp primitive, execution warp size is always rocprim::warp_size() static constexpr size_t logical_warp_size = TestFixture::params::warp_size; @@ -215,7 +213,7 @@ typed_test_def(RocprimWarpScanTests, name_suffix, InclusiveScanReduce) { auto idx = i * logical_warp_size + j; accumulator = binary_op_host(input[idx],accumulator); - expected[idx] = static_cast(accumulator); + expected[idx] = static_cast(accumulator); } expected_reductions[i] = expected[(i+1) * logical_warp_size - 1]; } @@ -303,7 +301,6 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveScan) using binary_op_type_host = typename test_utils::select_plus_operator_host::type; 
binary_op_type_host binary_op_host; using acc_type = typename test_utils::select_plus_operator_host::acc_type; - using cast_type = typename test_utils::select_plus_operator_host::cast_type; // logical warp side for warp primitive, execution warp size is always rocprim::warp_size() static constexpr size_t logical_warp_size = TestFixture::params::warp_size; @@ -360,7 +357,7 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveScan) { auto idx = i * logical_warp_size + j; accumulator = binary_op_host(input[idx-1], accumulator); - expected[idx] = static_cast(accumulator); + expected[idx] = static_cast(accumulator); } } @@ -427,8 +424,7 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveScanWoInit) // for bfloat16 and half we use double for host-side accumulation using binary_op_type_host = typename test_utils::select_plus_operator_host::type; binary_op_type_host binary_op_host; - using acc_type = typename test_utils::select_plus_operator_host::acc_type; - using cast_type = typename test_utils::select_plus_operator_host::cast_type; + using acc_type = typename test_utils::select_plus_operator_host::acc_type; // logical warp side for warp primitive, execution warp size is always rocprim::warp_size() static constexpr size_t logical_warp_size = TestFixture::params::warp_size; @@ -487,13 +483,13 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveScanWoInit) acc_type accumulator(input[i * logical_warp_size]); static_assert(logical_warp_size > 2, "logical_warp_size assumed to be at least 2."); - expected[i * logical_warp_size + 1] = static_cast(accumulator); + expected[i * logical_warp_size + 1] = static_cast(accumulator); for(size_t j = 2; j < logical_warp_size; j++) { auto idx = i * logical_warp_size + j; accumulator = binary_op_host(input[idx - 1], accumulator); - expected[idx] = static_cast(accumulator); + expected[idx] = static_cast(accumulator); } } @@ -572,7 +568,6 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveReduceScan) 
using binary_op_type_host = typename test_utils::select_plus_operator_host::type; binary_op_type_host binary_op_host; using acc_type = typename test_utils::select_plus_operator_host::acc_type; - using cast_type = typename test_utils::select_plus_operator_host::cast_type; // logical warp side for warp primitive, execution warp size is always rocprim::warp_size() static constexpr size_t logical_warp_size = TestFixture::params::warp_size; @@ -631,7 +626,7 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveReduceScan) { auto idx = i * logical_warp_size + j; accumulator = binary_op_host(input[idx-1], accumulator); - expected[idx] = static_cast(accumulator); + expected[idx] = static_cast(accumulator); } acc_type accumulator_reductions(0); @@ -639,7 +634,7 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveReduceScan) { auto idx = i * logical_warp_size + j; accumulator_reductions = binary_op_host(input[idx], accumulator_reductions); - expected_reductions[i] = static_cast(accumulator_reductions); + expected_reductions[i] = static_cast(accumulator_reductions); } } @@ -724,8 +719,7 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveReduceScanWoInit) // for bfloat16 and half we use double for host-side accumulation using binary_op_type_host = typename test_utils::select_plus_operator_host::type; binary_op_type_host binary_op_host; - using acc_type = typename test_utils::select_plus_operator_host::acc_type; - using cast_type = typename test_utils::select_plus_operator_host::cast_type; + using acc_type = typename test_utils::select_plus_operator_host::acc_type; // logical warp side for warp primitive, execution warp size is always rocprim::warp_size() static constexpr size_t logical_warp_size = TestFixture::params::warp_size; @@ -786,13 +780,13 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveReduceScanWoInit) acc_type accumulator(input[i * logical_warp_size]); static_assert(logical_warp_size > 2, "logical_warp_size assumed to be at 
least 2."); - expected[i * logical_warp_size + 1] = static_cast(accumulator); + expected[i * logical_warp_size + 1] = static_cast(accumulator); for(size_t j = 2; j < logical_warp_size; j++) { auto idx = i * logical_warp_size + j; accumulator = binary_op_host(input[idx - 1], accumulator); - expected[idx] = static_cast(accumulator); + expected[idx] = static_cast(accumulator); } acc_type accumulator_reductions(0); @@ -800,7 +794,7 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveReduceScanWoInit) { auto idx = i * logical_warp_size + j; accumulator_reductions = binary_op_host(input[idx], accumulator_reductions); - expected_reductions[i] = static_cast(accumulator_reductions); + expected_reductions[i] = static_cast(accumulator_reductions); } } @@ -894,7 +888,6 @@ typed_test_def(RocprimWarpScanTests, name_suffix, Scan) using binary_op_type_host = typename test_utils::select_plus_operator_host::type; binary_op_type_host binary_op_host; using acc_type = typename test_utils::select_plus_operator_host::acc_type; - using cast_type = typename test_utils::select_plus_operator_host::cast_type; // logical warp side for warp primitive, execution warp size is always rocprim::warp_size() static constexpr size_t logical_warp_size = TestFixture::params::warp_size; @@ -954,11 +947,11 @@ typed_test_def(RocprimWarpScanTests, name_suffix, Scan) { auto idx = i * logical_warp_size + j; accumulator_inclusive = binary_op_host(input[idx], accumulator_inclusive); - expected_inclusive[idx] = static_cast(accumulator_inclusive); + expected_inclusive[idx] = static_cast(accumulator_inclusive); if(j > 0) { accumulator_exclusive = binary_op_host(input[idx-1], accumulator_exclusive); - expected_exclusive[idx] = static_cast(accumulator_exclusive); + expected_exclusive[idx] = static_cast(accumulator_exclusive); } } } @@ -1053,7 +1046,6 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ScanReduce) using binary_op_type_host = typename test_utils::select_plus_operator_host::type; 
binary_op_type_host binary_op_host; using acc_type = typename test_utils::select_plus_operator_host::acc_type; - using cast_type = typename test_utils::select_plus_operator_host::cast_type; // logical warp side for warp primitive, execution warp size is always rocprim::warp_size() static constexpr size_t logical_warp_size = TestFixture::params::warp_size; @@ -1115,11 +1107,11 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ScanReduce) { auto idx = i * logical_warp_size + j; accumulator_inclusive = binary_op_host(input[idx], accumulator_inclusive); - expected_inclusive[idx] = static_cast(accumulator_inclusive); + expected_inclusive[idx] = static_cast(accumulator_inclusive); if(j > 0) { accumulator_exclusive = binary_op_host(input[idx-1], accumulator_exclusive); - expected_exclusive[idx] = static_cast(accumulator_exclusive); + expected_exclusive[idx] = static_cast(accumulator_exclusive); } } expected_reductions[i] = expected_inclusive[(i+1) * logical_warp_size - 1]; @@ -1232,8 +1224,7 @@ typed_test_def(RocprimWarpScanTests, name_suffix, InclusiveScanCustomType) using base_type = typename TestFixture::params::type; using T = test_utils::custom_test_type; - using acc_type = typename test_utils::select_plus_operator_host::acc_type; - using cast_type = typename test_utils::select_plus_operator_host::cast_type; + using acc_type = typename test_utils::select_plus_operator_host::acc_type; // logical warp side for warp primitive, execution warp size is always rocprim::warp_size() static constexpr size_t logical_warp_size = TestFixture::params::warp_size; @@ -1298,7 +1289,7 @@ typed_test_def(RocprimWarpScanTests, name_suffix, InclusiveScanCustomType) { auto idx = i * logical_warp_size + j; accumulator = static_cast>(input[idx]) + accumulator; - expected[idx] = static_cast>(accumulator); + expected[idx] = static_cast(accumulator); } } diff --git a/test/rocprim/test_warp_store.cpp b/test/rocprim/test_warp_store.cpp index 63aab5355..987177f04 100644 --- 
a/test/rocprim/test_warp_store.cpp +++ b/test/rocprim/test_warp_store.cpp @@ -21,6 +21,7 @@ // SOFTWARE. #include "../common_test_header.hpp" +#include "rocprim/types.hpp" #include "test_utils.hpp" #include @@ -48,11 +49,6 @@ class WarpStoreTest : public ::testing::Test }; using WarpStoreTestParams = ::testing::Types< - Params, - Params, - Params, - Params, - Params, Params, Params, @@ -76,8 +72,39 @@ using WarpStoreTestParams = ::testing::Types< Params, Params, Params, - Params ->; + Params, + + Params, + Params, + Params, + Params, + + Params, + Params, + Params, + Params, + + Params, + Params, + Params, + Params, + + Params, + Params, + Params, + Params, + + // half should be supported, but is missing some key operators. + // we should uncomment these, as soon as these are implemented and the tests compile and work as intended. + //Params, + //Params, + //Params, + //Params, + + Params, + Params, + Params, + Params>; template