Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
In-place guarantees for scan algorithms
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed May 23, 2022
1 parent f80aa78 commit b99db91
Show file tree
Hide file tree
Showing 6 changed files with 2,624 additions and 1,148 deletions.
37 changes: 24 additions & 13 deletions cub/agent/agent_scan_by_key.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ struct AgentScanByKey

TempStorage & storage;
WrappedKeysInputIteratorT d_keys_in;
KeyT* d_keys_prev_in;
WrappedValuesInputIteratorT d_values_in;
ValuesOutputIteratorT d_values_out;
InequalityWrapper<EqualityOp> inequality_op;
Expand Down Expand Up @@ -364,19 +365,27 @@ struct AgentScanByKey
}
else
{
KeyT tile_pred_key = (threadIdx.x == 0) ? d_keys_in[tile_base - 1] : KeyT();
BlockDiscontinuityKeysT(storage.scan_storage.discontinuity)
.FlagHeads(segment_flags, keys, inequality_op, tile_pred_key);

// Zip values and segment_flags
ZipValuesAndFlags<IS_LAST_TILE>(num_remaining,
values,
segment_flags,
scan_items);

SizeValuePairT tile_aggregate;
TilePrefixCallbackT prefix_op(tile_state, storage.scan_storage.prefix, pair_scan_op, tile_idx);
ScanTile(scan_items, tile_aggregate, prefix_op, Int2Type<IS_INCLUSIVE>());
// Predecessor key for this tile. Only thread 0 performs the load, reading
// d_keys_prev_in at index tile_idx; every other thread gets a
// default-constructed KeyT.
// NOTE(review): presumably d_keys_prev_in[tile_idx] holds the key that
// immediately precedes this tile, stored separately so this read does not
// touch d_keys_in -- confirm against the code that fills d_keys_prev_in.
KeyT tile_pred_key = (threadIdx.x == 0) ? d_keys_prev_in[tile_idx]
: KeyT();

// Compute segment_flags from keys with inequality_op, passing tile_pred_key
// as the tile predecessor (see cub::BlockDiscontinuity::FlagHeads for the
// exact flagging rule).
BlockDiscontinuityKeysT(storage.scan_storage.discontinuity)
.FlagHeads(segment_flags, keys, inequality_op, tile_pred_key);

// Zip values and segment_flags
ZipValuesAndFlags<IS_LAST_TILE>(num_remaining,
values,
segment_flags,
scan_items);

// Scan this tile's scan_items. prefix_op is built from tile_state, the
// prefix temp storage, pair_scan_op and tile_idx; Int2Type<IS_INCLUSIVE>
// selects which ScanTile overload runs.
SizeValuePairT tile_aggregate;
TilePrefixCallbackT prefix_op(tile_state,
storage.scan_storage.prefix,
pair_scan_op,
tile_idx);
ScanTile(scan_items,
tile_aggregate,
prefix_op,
Int2Type<IS_INCLUSIVE>());
}

CTA_SYNC();
Expand Down Expand Up @@ -408,6 +417,7 @@ struct AgentScanByKey
AgentScanByKey(
TempStorage & storage,
KeysInputIteratorT d_keys_in,
KeyT * d_keys_prev_in,
ValuesInputIteratorT d_values_in,
ValuesOutputIteratorT d_values_out,
EqualityOp equality_op,
Expand All @@ -416,6 +426,7 @@ struct AgentScanByKey
:
storage(storage),
d_keys_in(d_keys_in),
d_keys_prev_in(d_keys_prev_in),
d_values_in(d_values_in),
d_values_out(d_values_out),
inequality_op(equality_op),
Expand Down
Loading

0 comments on commit b99db91

Please sign in to comment.