Skip to content

Commit

Permalink
Fixed bug with implicit casting in partial sort
Browse files Browse the repository at this point in the history
  • Loading branch information
NB4444 authored and Naraenda committed Jul 18, 2024
1 parent b2c60bb commit 88cb742
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 45 deletions.
125 changes: 82 additions & 43 deletions rocprim/include/rocprim/device/device_partial_sort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,22 @@ namespace detail
} \
while(0)

template<class Config, class KeysIterator, class BinaryFunction, bool inplace = true>
hipError_t partial_sort_impl(void* temporary_storage,
size_t& storage_size,
KeysIterator keys_in,
KeysIterator keys_out,
size_t middle,
size_t size,
BinaryFunction compare_function,
hipStream_t stream,
bool debug_synchronous)
template<class Config,
class KeysInputIterator,
class KeysOutputIterator,
class BinaryFunction,
bool inplace = true>
hipError_t partial_sort_impl(void* temporary_storage,
size_t& storage_size,
KeysInputIterator keys_in,
KeysOutputIterator keys_out,
size_t middle,
size_t size,
BinaryFunction compare_function,
hipStream_t stream,
bool debug_synchronous)
{
using key_type = typename std::iterator_traits<KeysIterator>::value_type;
using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;
using config = default_or_custom_config<Config, detail::default_partial_sort_config<key_type>>;
using config_merge_sort = typename config::merge_sort;
using config_nth_element = typename config::nth_element;
Expand Down Expand Up @@ -148,16 +152,30 @@ hipError_t partial_sort_impl(void* temporary_storage,

if(!full_sort)
{
RETURN_ON_ERROR(
nth_element_impl<config_nth_element>(temporary_storage_nth_element,
storage_size_nth_element,
inplace ? keys_in : keys_output_nth_element,
middle,
size,
compare_function,
stream,
debug_synchronous,
keys_buffer));
if(inplace)
{
RETURN_ON_ERROR(nth_element_impl<config_nth_element>(temporary_storage_nth_element,
storage_size_nth_element,
keys_in,
middle,
size,
compare_function,
stream,
debug_synchronous,
keys_buffer));
}
else
{
RETURN_ON_ERROR(nth_element_impl<config_nth_element>(temporary_storage_nth_element,
storage_size_nth_element,
keys_output_nth_element,
middle,
size,
compare_function,
stream,
debug_synchronous,
keys_buffer));
}
}

if(middle == 0)
Expand All @@ -174,18 +192,38 @@ hipError_t partial_sort_impl(void* temporary_storage,
return hipSuccess;
}

return merge_sort_impl<config_merge_sort>(temporary_storage_merge_sort,
storage_size_merge_sort,
inplace ? keys_in : keys_output_nth_element,
keys_out,
static_cast<empty_type*>(nullptr), // values_input
static_cast<empty_type*>(nullptr), // values_output
(!inplace || full_sort) ? middle + 1 : middle,
compare_function,
stream,
debug_synchronous,
keys_buffer, // keys_buffer
static_cast<empty_type*>(nullptr)); // values_buffer
if(inplace)
{
return merge_sort_impl<config_merge_sort>(
temporary_storage_merge_sort,
storage_size_merge_sort,
keys_in,
keys_out,
static_cast<empty_type*>(nullptr), // values_input
static_cast<empty_type*>(nullptr), // values_output
full_sort ? middle + 1 : middle,
compare_function,
stream,
debug_synchronous,
keys_buffer, // keys_buffer
static_cast<empty_type*>(nullptr)); // values_buffer
}
else
{
return merge_sort_impl<config_merge_sort>(
temporary_storage_merge_sort,
storage_size_merge_sort,
keys_output_nth_element,
keys_out,
static_cast<empty_type*>(nullptr), // values_input
static_cast<empty_type*>(nullptr), // values_output
middle + 1,
compare_function,
stream,
debug_synchronous,
keys_buffer, // keys_buffer
static_cast<empty_type*>(nullptr)); // values_buffer
}
}

} // namespace detail
Expand Down Expand Up @@ -283,16 +321,17 @@ hipError_t partial_sort_copy(void* temporary_storage,
typename std::iterator_traits<KeysOutputIterator>::value_type>::value,
"KeysInputIterator and KeysOutputIterator must have the same value_type");

return detail::partial_sort_impl<Config, KeysInputIterator, BinaryFunction, false>(
temporary_storage,
storage_size,
keys_input,
keys_output,
middle,
size,
compare_function,
stream,
debug_synchronous);
return detail::
partial_sort_impl<Config, KeysInputIterator, KeysOutputIterator, BinaryFunction, false>(
temporary_storage,
storage_size,
keys_input,
keys_output,
middle,
size,
compare_function,
stream,
debug_synchronous);
}

/// \brief Rearranges elements such that the range [0, middle) contains the `middle` smallest elements of the range [0, size), in sorted order.
Expand Down
6 changes: 4 additions & 2 deletions test/rocprim/indirect_iterator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,13 @@ class indirect_iterator

using iterator_category = std::random_access_iterator_tag;

ROCPRIM_HOST_DEVICE inline indirect_iterator(T* ptr) : ptr_(ptr) {}
ROCPRIM_HOST_DEVICE inline explicit indirect_iterator(T* ptr) : ptr_(ptr) {}

ROCPRIM_HOST_DEVICE inline ~indirect_iterator() = default;

ROCPRIM_HOST_DEVICE inline indirect_iterator& operator++()
ROCPRIM_HOST_DEVICE
inline indirect_iterator&
operator++()
{
++ptr_;
return *this;
Expand Down

0 comments on commit 88cb742

Please sign in to comment.