Skip to content

Commit

Permalink
Review adaptations
Browse files Browse the repository at this point in the history
  • Loading branch information
NB4444 authored and Naraenda committed Jul 18, 2024
1 parent 711624a commit 608a043
Show file tree
Hide file tree
Showing 5 changed files with 393 additions and 188 deletions.
7 changes: 3 additions & 4 deletions benchmark/benchmark_device_partial_sort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,9 @@ struct device_partial_sort_benchmark : public config_autotune_interface
size * sizeof(*d_keys_input),
hipMemcpyHostToDevice));

::rocprim::less<key_type> lesser_op;

void* d_temporary_storage = nullptr;
size_t temporary_storage_bytes = 0;
rocprim::less<key_type> lesser_op;
void* d_temporary_storage = nullptr;
size_t temporary_storage_bytes = 0;
HIP_CHECK(rocprim::partial_sort_copy(d_temporary_storage,
temporary_storage_bytes,
d_keys_input,
Expand Down
2 changes: 1 addition & 1 deletion docs/reference/ops_summary.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Rearrangement
================

* ``sort`` rearranges the sequence by sorting it. It could be according to a comparison operator or a value using a radix approach
* ``partial_sort`` rearranges the sequence by sorting it up to and including the middle index, according to a comparison operator.
* ``partial_sort`` rearranges the sequence by sorting it up to and including a given index, according to a comparison operator.
* ``nth_element`` places the nth element in its sorted position, with elements less-than before, and greater after, according to a comparison operator.
* ``exchange`` rearranges the elements according to a different stride configuration which is equivalent to a tensor axis transposition
* ``shuffle`` rotates the elements
Expand Down
80 changes: 40 additions & 40 deletions rocprim/include/rocprim/device/device_nth_element.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,28 +43,29 @@ namespace detail
{

template<class Config, class KeysIterator, class BinaryFunction>
ROCPRIM_INLINE hipError_t nth_element_impl(
void* temporary_storage,
size_t& storage_size,
KeysIterator keys,
size_t nth,
size_t size,
BinaryFunction compare_function,
hipStream_t stream,
bool debug_synchronous,
typename std::iterator_traits<KeysIterator>::value_type* keys_double_buffer = nullptr)
ROCPRIM_INLINE
hipError_t
nth_element_impl(void* temporary_storage,
size_t& storage_size,
KeysIterator keys,
size_t nth,
size_t size,
BinaryFunction compare_function,
hipStream_t stream,
bool debug_synchronous,
typename std::iterator_traits<KeysIterator>::value_type* keys_double_buffer
= nullptr)
{
using key_type = typename std::iterator_traits<KeysIterator>::value_type;
using config = wrapped_nth_element_config<Config, key_type>;

detail::target_arch target_arch;
hipError_t result = host_target_arch(stream, target_arch);
target_arch target_arch;
hipError_t result = host_target_arch(stream, target_arch);
if(result != hipSuccess)
{
return result;
}
const detail::nth_element_config_params params
= detail::dispatch_target_arch<config>(target_arch);
const nth_element_config_params params = dispatch_target_arch<config>(target_arch);

constexpr unsigned int num_partitions = 3;
const unsigned int num_buckets = params.number_of_buckets;
Expand All @@ -73,19 +74,18 @@ ROCPRIM_INLINE hipError_t nth_element_impl(
const unsigned int num_items_per_threads = params.kernel_config.items_per_thread;
const unsigned int num_threads_per_block = params.kernel_config.block_size;
const unsigned int num_items_per_block = num_threads_per_block * num_items_per_threads;
const unsigned int num_blocks = detail::ceiling_div(size, num_items_per_block);
const unsigned int num_blocks = ceiling_div(size, num_items_per_block);

// oracles stores the bucket that correlates with the index
key_type* tree = nullptr;
size_t* buckets = nullptr;
detail::n_th_element_iteration_data* nth_element_data = nullptr;
bool* equality_buckets = nullptr;
detail::nth_element_onesweep_lookback_state* lookback_states = nullptr;
key_type* tree = nullptr;
size_t* buckets = nullptr;
n_th_element_iteration_data* nth_element_data = nullptr;
bool* equality_buckets = nullptr;
nth_element_onesweep_lookback_state* lookback_states = nullptr;

key_type* keys_buffer = nullptr;

{
using namespace detail::temp_storage;
using namespace temp_storage;

hipError_t partition_result;
if(keys_double_buffer == nullptr)
Expand Down Expand Up @@ -141,22 +141,22 @@ ROCPRIM_INLINE hipError_t nth_element_impl(
std::cout << "storage_size: " << storage_size << '\n';
}

return detail::nth_element_keys_impl<config, num_partitions>(keys,
keys_buffer,
tree,
nth,
size,
buckets,
equality_buckets,
lookback_states,
num_buckets,
stop_recursion_size,
num_threads_per_block,
num_items_per_threads,
nth_element_data,
compare_function,
stream,
debug_synchronous);
return nth_element_keys_impl<config, num_partitions>(keys,
keys_buffer,
tree,
nth,
size,
buckets,
equality_buckets,
lookback_states,
num_buckets,
stop_recursion_size,
num_threads_per_block,
num_items_per_threads,
nth_element_data,
compare_function,
stream,
debug_synchronous);
}

} // namespace detail
Expand Down Expand Up @@ -195,7 +195,7 @@ ROCPRIM_INLINE hipError_t nth_element_impl(
/// The signature of the function should be equivalent to the following:
/// <tt>bool f(const T &a, const T &b);</tt>. The signature does not need to have
/// <tt>const &</tt>, but function object must not modify the objects passed to it.
/// The comperator must meet the C++ named requirement Compare.
/// The comparator must meet the C++ named requirement Compare.
/// The default value is `BinaryFunction()`.
/// \param [in] stream [optional] HIP stream object. Default is `0` (default stream).
/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel
Expand Down Expand Up @@ -295,7 +295,7 @@ ROCPRIM_INLINE hipError_t nth_element(void* temporary_storage,
/// The signature of the function should be equivalent to the following:
/// <tt>bool f(const T &a, const T &b);</tt>. The signature does not need to have
/// <tt>const &</tt>, but function object must not modify the objects passed to it.
/// The comperator must meet the C++ named requirement Compare.
/// The comparator must meet the C++ named requirement Compare.
/// The default value is `BinaryFunction()`.
/// \param [in] stream [optional] HIP stream object. Default is `0` (default stream).
/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel
Expand Down
Loading

0 comments on commit 608a043

Please sign in to comment.