Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Fix review notes
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed Aug 23, 2021
1 parent 73e8648 commit 53ffc37
Show file tree
Hide file tree
Showing 8 changed files with 1,695 additions and 771 deletions.
68 changes: 26 additions & 42 deletions cub/agent/agent_adjacent_difference.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,18 @@ struct AgentAdjacentDifferencePolicy
template <typename Policy,
typename InputIteratorT,
typename OutputIteratorT,
typename FlagOpT,
typename DifferenceOpT,
typename OffsetT,
typename InputT,
typename OutputT,
bool InPlace,
bool ReadLeft>
struct AgentDifference
{
// XXX output type must be result of BinaryOp(input_type,input_type);
using OutputT = InputT;

using LoadIt = typename THRUST_NS_QUALIFIER::cuda_cub::core::LoadIterator<Policy, InputIteratorT>::type;

using BlockLoad = typename cub::BlockLoadType<Policy, LoadIt>::type;
using BlockStore = typename cub::BlockStoreType<Policy, OutputIteratorT, InputT>::type;
using BlockStore = typename cub::BlockStoreType<Policy, OutputIteratorT, OutputT>::type;

using BlockAdjacentDifferenceT =
cub::BlockAdjacentDifference<InputT, Policy::BLOCK_THREADS>;
Expand All @@ -98,21 +96,23 @@ struct AgentDifference
LoadIt load_it;
InputT *first_tile_previous;
OutputIteratorT result;
FlagOpT flag_op;
DifferenceOpT difference_op;
OffsetT num_items;

__device__ __forceinline__ AgentDifference(TempStorage &temp_storage,
InputIteratorT input_it,
InputT *first_tile_previous,
OutputIteratorT result,
FlagOpT flag_op,
DifferenceOpT difference_op,
OffsetT num_items)
: temp_storage(temp_storage.Alias())
, input_it(input_it)
, load_it(input_it)
, load_it(
THRUST_NS_QUALIFIER::cuda_cub::core::make_load_iterator(Policy(),
input_it))
, first_tile_previous(first_tile_previous)
, result(result)
, flag_op(flag_op)
, difference_op(difference_op)
, num_items(num_items)
{}

Expand Down Expand Up @@ -141,36 +141,34 @@ struct AgentDifference

if (ReadLeft)
{
InputT input_prev[ITEMS_PER_THREAD];

if (IS_FIRST_TILE)
{
BlockAdjacentDifferenceT(temp_storage.adjacent_difference)
.SubtractLeft(output, input, input_prev, flag_op);
.SubtractLeft(output, input, difference_op);
}
else
{
InputT tile_prev_input = InPlace ? first_tile_previous[tile_idx]
: *(input_it + tile_base - 1);

BlockAdjacentDifferenceT(temp_storage.adjacent_difference)
.SubtractLeft(output, input, input_prev, flag_op, tile_prev_input);
.SubtractLeft(output, input, difference_op, tile_prev_input);
}
}
else
{
if (IS_LAST_TILE)
{
BlockAdjacentDifferenceT(temp_storage.adjacent_difference)
.SubtractRightPartialTile(output, input, flag_op, num_remaining);
.SubtractRightPartialTile(output, input, difference_op, num_remaining);
}
else
{
InputT tile_next_input = InPlace ? first_tile_previous[tile_idx]
: *(input_it + tile_base + ITEMS_PER_TILE);

BlockAdjacentDifferenceT(temp_storage.adjacent_difference)
.SubtractRight(output, input, flag_op, tile_next_input);
.SubtractRight(output, input, difference_op, tile_next_input);
}
}

Expand Down Expand Up @@ -223,45 +221,31 @@ struct AgentDifference
};

template <typename InputIteratorT,
typename OutputIteratorT,
typename OffsetT>
struct AgentDifferenceInitLeft
typename InputT,
typename OffsetT,
bool ReadLeft>
struct AgentDifferenceInit
{
static constexpr int BLOCK_THREADS = 128;

static __device__ __forceinline__ void Process(int tile_idx,
InputIteratorT first,
OutputIteratorT result,
InputT *result,
OffsetT num_tiles,
int items_per_tile)
{
OffsetT tile_base = static_cast<OffsetT>(tile_idx) * items_per_tile;

if (tile_base > 0 && tile_idx < num_tiles)
{
result[tile_idx] = first[tile_base - 1];
}
}
};

template <typename InputIteratorT,
typename OutputIteratorT,
typename OffsetT>
struct AgentDifferenceInitRight
{
static constexpr int BLOCK_THREADS = 128;

static __device__ __forceinline__ void Process(int tile_idx,
InputIteratorT first,
OutputIteratorT result,
OffsetT num_tiles,
int items_per_tile)
{
OffsetT tile_base = static_cast<OffsetT>(tile_idx) * items_per_tile;

if (tile_base > 0 && tile_idx < num_tiles)
{
result[tile_idx - 1] = first[tile_base];
if (ReadLeft)
{
result[tile_idx] = first[tile_base - 1];
}
else
{
result[tile_idx - 1] = first[tile_base];
}
}
}
};
Expand Down
Loading

0 comments on commit 53ffc37

Please sign in to comment.