Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
In-place select
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed Jun 4, 2022
1 parent ac7c06b commit 7896eb4
Show file tree
Hide file tree
Showing 3 changed files with 511 additions and 3 deletions.
286 changes: 285 additions & 1 deletion cub/device/device_select.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ struct DeviceSelect
* `char`, `int`, etc.).
* - Copies of the selected items are compacted into `d_out` and maintain
* their original relative ordering.
* - The range `[d_out, d_out + *d_num_selected_out)` shall not overlap
* `[d_in, d_in + num_items)`, `[d_flags, d_flags + num_items)` nor
* `d_num_selected_out` in any way.
* - @devicestorage
*
* @par Snippet
Expand Down Expand Up @@ -215,6 +218,137 @@ struct DeviceSelect
stream,
debug_synchronous);
}

/**
* @brief Uses the `d_flags` sequence to selectively compact the items in
* `d_data`. The total number of items selected is written to
* `d_num_selected_out`. ![](select_flags_logo.png)
*
* @par
* - The value type of `d_flags` must be castable to `bool` (e.g., `bool`,
* `char`, `int`, etc.).
* - Copies of the selected items are compacted in-place and maintain
* their original relative ordering.
* - The `d_data` may equal `d_flags`. The range
* `[d_data, d_data + num_items)` shall not overlap
* `[d_flags, d_flags + num_items)` in any other way.
* - @devicestorage
*
* @par Snippet
* The code snippet below illustrates the compaction of items selected from
* an `int` device vector.
* @par
* @code
* #include <cub/cub.cuh> // or equivalently <cub/device/device_select.cuh>
*
* // Declare, allocate, and initialize device-accessible pointers for input,
* // flags, and output
* int num_items; // e.g., 8
* int *d_data; // e.g., [1, 2, 3, 4, 5, 6, 7, 8]
* char *d_flags; // e.g., [1, 0, 0, 1, 0, 1, 1, 0]
* int *d_num_selected_out; // e.g., [ ]
* ...
*
* // Determine temporary device storage requirements
* void *d_temp_storage = NULL;
* size_t temp_storage_bytes = 0;
* cub::DeviceSelect::Flagged(
* d_temp_storage, temp_storage_bytes,
* d_in, d_flags, d_num_selected_out, num_items);
*
* // Allocate temporary storage
* cudaMalloc(&d_temp_storage, temp_storage_bytes);
*
* // Run selection
* cub::DeviceSelect::Flagged(
* d_temp_storage, temp_storage_bytes,
* d_in, d_flags, d_num_selected_out, num_items);
*
* // d_data <-- [1, 4, 6, 7]
* // d_num_selected_out <-- [4]
*
* @endcode
*
* @tparam IteratorT
* **[inferred]** Random-access iterator type for reading and writing
* selected items \iterator
*
* @tparam FlagIterator
* **[inferred]** Random-access input iterator type for reading selection
* flags \iterator
*
* @tparam NumSelectedIteratorT
* **[inferred]** Output iterator type for recording the number of items
* selected \iterator
*
* @param[in] d_temp_storage
* Device-accessible allocation of temporary storage. When `nullptr`, the
* required allocation size is written to `temp_storage_bytes` and no work
* is done.
*
* @param[in,out] temp_storage_bytes
* Reference to size in bytes of `d_temp_storage` allocation
*
* @param[in,out] d_data
* Pointer to the sequence of data items
*
* @param[in] d_flags
* Pointer to the input sequence of selection flags
*
* @param[out] d_num_selected_out
* Pointer to the output total number of items selected
*
* @param[in] num_items
* Total number of input items (i.e., length of `d_data`)
*
* @param[in] stream
* **[optional]** CUDA stream to launch kernels within.
* Default is stream<sub>0</sub>.
*
* @param[in] debug_synchronous
* **[optional]** Whether or not to synchronize the stream after every
* kernel launch to check for errors. May cause significant slowdown.
* Default is `false`.
*/
template <typename IteratorT,
typename FlagIterator,
typename NumSelectedIteratorT>
CUB_RUNTIME_FUNCTION __forceinline__ static cudaError_t
Flagged(void *d_temp_storage,
size_t &temp_storage_bytes,
IteratorT d_data,
FlagIterator d_flags,
NumSelectedIteratorT d_num_selected_out,
int num_items,
cudaStream_t stream = 0,
bool debug_synchronous = false)
{
using OffsetT = int; // Signed integer type for global offsets
using SelectOp = NullType; // Selection op (not used)
using EqualityOp = NullType; // Equality operator (not used)

constexpr bool may_alias = true;

return DispatchSelectIf<IteratorT,
FlagIterator,
IteratorT,
NumSelectedIteratorT,
SelectOp,
EqualityOp,
OffsetT,
false,
may_alias>::Dispatch(d_temp_storage,
temp_storage_bytes,
d_data, // in
d_flags,
d_data, // out
d_num_selected_out,
SelectOp(),
EqualityOp(),
num_items,
stream,
debug_synchronous);
}

/**
* @brief Uses the `select_op` functor to selectively copy items from `d_in`
Expand All @@ -224,6 +358,8 @@ struct DeviceSelect
* @par
* - Copies of the selected items are compacted into `d_out` and maintain
* their original relative ordering.
* - The range `[d_out, d_out + *d_num_selected_out)` shall not overlap
* `[d_in, d_in + num_items)` nor `d_num_selected_out` in any way.
* - @devicestorage
*
* @par Performance
Expand Down Expand Up @@ -377,6 +513,145 @@ struct DeviceSelect
debug_synchronous);
}

/**
* @brief Uses the `select_op` functor to selectively compact items in
* `d_data`. The total number of items selected is written to
* `d_num_selected_out`. ![](select_logo.png)
*
* @par
* - Copies of the selected items are compacted in `d_data` and maintain
* their original relative ordering.
* - @devicestorage
*
* @par Snippet
* The code snippet below illustrates the compaction of items selected from
* an `int` device vector.
* @par
* @code
* #include <cub/cub.cuh> // or equivalently <cub/device/device_select.cuh>
*
* // Functor type for selecting values less than some criteria
* struct LessThan
* {
* int compare;
*
* CUB_RUNTIME_FUNCTION __forceinline__
* LessThan(int compare) : compare(compare) {}
*
* CUB_RUNTIME_FUNCTION __forceinline__
* bool operator()(const int &a) const {
* return (a < compare);
* }
* };
*
* // Declare, allocate, and initialize device-accessible pointers
* // for input and output
* int num_items; // e.g., 8
* int *d_data; // e.g., [0, 2, 3, 9, 5, 2, 81, 8]
* int *d_num_selected_out; // e.g., [ ]
* LessThan select_op(7);
* ...
*
* // Determine temporary device storage requirements
* void *d_temp_storage = NULL;
* size_t temp_storage_bytes = 0;
* cub::DeviceSelect::If(
* d_temp_storage, temp_storage_bytes,
* d_data, d_num_selected_out, num_items, select_op);
*
* // Allocate temporary storage
* cudaMalloc(&d_temp_storage, temp_storage_bytes);
*
* // Run selection
* cub::DeviceSelect::If(
* d_temp_storage, temp_storage_bytes,
* d_data, d_num_selected_out, num_items, select_op);
*
* // d_data <-- [0, 2, 3, 5, 2]
* // d_num_selected_out <-- [5]
* @endcode
*
* @tparam IteratorT
* **[inferred]** Random-access input iterator type for reading and
* writing items \iterator
*
* @tparam NumSelectedIteratorT
* **[inferred]** Output iterator type for recording the number of items
* selected \iterator
*
* @tparam SelectOp
* **[inferred]** Selection operator type having member
* `bool operator()(const T &a)`
*
* @param[in] d_temp_storage
* Device-accessible allocation of temporary storage. When `nullptr`, the
* required allocation size is written to `temp_storage_bytes` and no work
* is done.
*
* @param[in,out] temp_storage_bytes
* Reference to size in bytes of `d_temp_storage` allocation
*
* @param[in,out] d_data
* Pointer to the sequence of data items
*
* @param[out] d_num_selected_out
* Pointer to the output total number of items selected
*
* @param[in] num_items
* Total number of input items (i.e., length of `d_data`)
*
* @param[in] select_op
* Unary selection operator
*
* @param[in] stream
* **[optional]** CUDA stream to launch kernels within.
* Default is stream<sub>0</sub>.
*
* @param[in] debug_synchronous
* **[optional]** Whether or not to synchronize the stream after every
* kernel launch to check for errors. May cause significant slowdown.
* Default is `false`.
*/
template <typename IteratorT,
typename NumSelectedIteratorT,
typename SelectOp>
CUB_RUNTIME_FUNCTION __forceinline__ static cudaError_t
If(void *d_temp_storage,
size_t &temp_storage_bytes,
IteratorT d_data,
NumSelectedIteratorT d_num_selected_out,
int num_items,
SelectOp select_op,
cudaStream_t stream = 0,
bool debug_synchronous = false)
{
using OffsetT = int; // Signed integer type for global offsets
using FlagIterator = NullType *; // FlagT iterator type (not used)
using EqualityOp = NullType; // Equality operator (not used)

constexpr bool may_alias = true;

return DispatchSelectIf<IteratorT,
FlagIterator,
IteratorT,
NumSelectedIteratorT,
SelectOp,
EqualityOp,
OffsetT,
false,
may_alias>::Dispatch(d_temp_storage,
temp_storage_bytes,
d_data, // in
NULL,
d_data, // out
d_num_selected_out,
select_op,
EqualityOp(),
num_items,
stream,
debug_synchronous);
}

/**
* @brief Given an input sequence `d_in` having runs of consecutive
* equal-valued keys, only the first key from each run is selectively
Expand All @@ -388,6 +663,8 @@ struct DeviceSelect
* equivalent
* - Copies of the selected items are compacted into `d_out` and maintain
* their original relative ordering.
* - The range `[d_out, d_out + *d_num_selected_out)` shall not overlap
* `[d_in, d_in + num_items)` nor `d_num_selected_out` in any way.
* - @devicestorage
*
* @par Performance
Expand Down Expand Up @@ -526,11 +803,18 @@ struct DeviceSelect
* `d_keys_out` and `d_values_out`. The total number of items selected
* is written to `d_num_selected_out`. ![](unique_logo.png)
*
* \par
* @par
* - The `==` equality operator is used to determine whether keys are
* equivalent
* - Copies of the selected items are compacted into `d_out` and maintain
* their original relative ordering.
* - In-place operations are not supported. There must be no overlap between
* any of the provided ranges:
* - `[d_keys_in, d_keys_in + num_items)`
* - `[d_keys_out, d_keys_out + *d_num_selected_out)`
* - `[d_values_in, d_values_in + num_items)`
* - `[d_values_out, d_values_out + *d_num_selected_out)`
* - `[d_num_selected_out, d_num_selected_out + 1)`
* - @devicestorage
*
* @par Snippet
Expand Down
5 changes: 3 additions & 2 deletions cub/device/dispatch/dispatch_select_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ template <
typename SelectOpT, ///< Selection operator type (NullType if selection flags or discontinuity flagging is to be used for selection)
typename EqualityOpT, ///< Equality operator type (NullType if selection functor or selection flags is to be used for selection)
typename OffsetT, ///< Signed integer type for global offsets
bool KEEP_REJECTS> ///< Whether or not we push rejected items to the back of the output
bool KEEP_REJECTS, ///< Whether or not we push rejected items to the back of the output
bool MayAlias = false>
struct DispatchSelectIf
{
/******************************************************************************
Expand Down Expand Up @@ -161,7 +162,7 @@ struct DispatchSelectIf
128,
ITEMS_PER_THREAD,
BLOCK_LOAD_DIRECT,
LOAD_LDG,
MayAlias ? LOAD_CA : LOAD_LDG,
BLOCK_SCAN_WARP_SCANS>
SelectIfPolicyT;
};
Expand Down
Loading

0 comments on commit 7896eb4

Please sign in to comment.