Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
P2322R6 accumulator types for reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed Jul 28, 2022
1 parent 728a2a2 commit 921885f
Show file tree
Hide file tree
Showing 9 changed files with 2,349 additions and 1,422 deletions.
661 changes: 361 additions & 300 deletions cub/agent/agent_reduce.cuh

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion cub/block/specializations/block_reduce_warp_reductions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ struct BlockReduceWarpReductions
// Share lane aggregates
if (lane_id == 0)
{
temp_storage.warp_aggregates[warp_id] = warp_aggregate;
new (temp_storage.warp_aggregates + warp_id) T(warp_aggregate);
}

CTA_SYNC();
Expand Down
10 changes: 7 additions & 3 deletions cub/detail/type_traits.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
#include <cub/util_cpp_dialect.cuh>
#include <cub/util_namespace.cuh>

#include <type_traits>
#include <cuda/std/type_traits>


CUB_NAMESPACE_BEGIN
Expand All @@ -44,11 +44,15 @@ namespace detail {
template <typename Invokable, typename... Args>
using invoke_result_t =
#if CUB_CPP_DIALECT < 2017
typename std::result_of<Invokable(Args...)>::type;
typename cuda::std::result_of<Invokable(Args...)>::type;
#else // 2017+
std::invoke_result_t<Invokable, Args...>;
cuda::std::invoke_result_t<Invokable, Args...>;
#endif

/// The type of intermediate accumulator (according to P2322R6)
template <typename Invokable, typename InitT, typename InputT>
using accumulator_t =
typename cuda::std::decay<invoke_result_t<Invokable, InitT, InputT>>::type;

} // namespace detail
CUB_NAMESPACE_END
155 changes: 88 additions & 67 deletions cub/device/device_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -204,14 +204,15 @@ struct DeviceReduce
return DispatchReduce<InputIteratorT,
OutputIteratorT,
OffsetT,
ReductionOpT>::Dispatch(d_temp_storage,
temp_storage_bytes,
d_in,
d_out,
num_items,
reduction_op,
init,
stream);
ReductionOpT,
T>::Dispatch(d_temp_storage,
temp_storage_bytes,
d_in,
d_out,
num_items,
reduction_op,
init,
stream);
}

template <typename InputIteratorT,
Expand Down Expand Up @@ -339,15 +340,20 @@ struct DeviceReduce
cub::detail::non_void_value_t<OutputIteratorT,
cub::detail::value_t<InputIteratorT>>;

return DispatchReduce<InputIteratorT, OutputIteratorT, OffsetT, cub::Sum>::
Dispatch(d_temp_storage,
temp_storage_bytes,
d_in,
d_out,
num_items,
cub::Sum(),
OutputT(), // zero-initialize
stream);
using InitT = OutputT;

return DispatchReduce<InputIteratorT,
OutputIteratorT,
OffsetT,
cub::Sum,
InitT>::Dispatch(d_temp_storage,
temp_storage_bytes,
d_in,
d_out,
num_items,
cub::Sum(),
InitT{}, // zero-initialize
stream);
}

template <typename InputIteratorT, typename OutputIteratorT>
Expand Down Expand Up @@ -458,17 +464,23 @@ struct DeviceReduce
// The input value type
using InputT = cub::detail::value_t<InputIteratorT>;

return DispatchReduce<InputIteratorT, OutputIteratorT, OffsetT, cub::Min>::
Dispatch(d_temp_storage,
temp_storage_bytes,
d_in,
d_out,
num_items,
cub::Min(),
Traits<InputT>::Max(), // replace with
// std::numeric_limits<T>::max() when
// C++11 support is more prevalent
stream);
using InitT = InputT;

return DispatchReduce<InputIteratorT,
OutputIteratorT,
OffsetT,
cub::Min,
InitT>::Dispatch(d_temp_storage,
temp_storage_bytes,
d_in,
d_out,
num_items,
cub::Min(),
// replace with
// std::numeric_limits<T>::max() when
// C++11 support is more prevalent
Traits<InitT>::Max(),
stream);
}

template <typename InputIteratorT, typename OutputIteratorT>
Expand Down Expand Up @@ -590,6 +602,8 @@ struct DeviceReduce
cub::detail::non_void_value_t<OutputIteratorT,
KeyValuePair<OffsetT, InputValueT>>;

using InitT = OutputTupleT;

// The output value type
using OutputValueT = typename OutputTupleT::Value;

Expand All @@ -600,23 +614,23 @@ struct DeviceReduce
ArgIndexInputIteratorT d_indexed_in(d_in);

// Initial value
OutputTupleT initial_value(1, Traits<InputValueT>::Max()); // replace with
// std::numeric_limits<T>::max()
// when C++11
// support is
// more prevalent

// replace with std::numeric_limits<T>::max() when C++11 support is
// more prevalent
InitT initial_value(1, Traits<InputValueT>::Max());

return DispatchReduce<ArgIndexInputIteratorT,
OutputIteratorT,
OffsetT,
cub::ArgMin>::Dispatch(d_temp_storage,
temp_storage_bytes,
d_indexed_in,
d_out,
num_items,
cub::ArgMin(),
initial_value,
stream);
cub::ArgMin,
InitT>::Dispatch(d_temp_storage,
temp_storage_bytes,
d_indexed_in,
d_out,
num_items,
cub::ArgMin(),
initial_value,
stream);
}

template <typename InputIteratorT, typename OutputIteratorT>
Expand Down Expand Up @@ -728,17 +742,24 @@ struct DeviceReduce
// The input value type
using InputT = cub::detail::value_t<InputIteratorT>;

return DispatchReduce<InputIteratorT, OutputIteratorT, OffsetT, cub::Max>::
Dispatch(d_temp_storage,
temp_storage_bytes,
d_in,
d_out,
num_items,
cub::Max(),
Traits<InputT>::Lowest(), // replace with
// std::numeric_limits<T>::lowest()
// when C++11 support is more prevalent
stream);
using InitT = InputT;

return DispatchReduce<InputIteratorT,
OutputIteratorT,
OffsetT,
cub::Max,
InitT>::Dispatch(d_temp_storage,
temp_storage_bytes,
d_in,
d_out,
num_items,
cub::Max(),
// replace with
// std::numeric_limits<T>::lowest()
// when C++11 support is more
// prevalent
Traits<InitT>::Lowest(),
stream);
}

template <typename InputIteratorT, typename OutputIteratorT>
Expand Down Expand Up @@ -864,32 +885,32 @@ struct DeviceReduce
// The output value type
using OutputValueT = typename OutputTupleT::Value;

using InitT = OutputTupleT;

// Wrapped input iterator to produce index-value <OffsetT, InputT> tuples
using ArgIndexInputIteratorT =
ArgIndexInputIterator<InputIteratorT, OffsetT, OutputValueT>;

ArgIndexInputIteratorT d_indexed_in(d_in);

// Initial value
OutputTupleT initial_value(1, Traits<InputValueT>::Lowest()); // replace
// with
// std::numeric_limits<T>::lowest()
// when C++11
// support is
// more
// prevalent

// replace with std::numeric_limits<T>::lowest() when C++11 support is
// more prevalent
InitT initial_value(1, Traits<InputValueT>::Lowest());

return DispatchReduce<ArgIndexInputIteratorT,
OutputIteratorT,
OffsetT,
cub::ArgMax>::Dispatch(d_temp_storage,
temp_storage_bytes,
d_indexed_in,
d_out,
num_items,
cub::ArgMax(),
initial_value,
stream);
cub::ArgMax,
InitT>::Dispatch(d_temp_storage,
temp_storage_bytes,
d_indexed_in,
d_out,
num_items,
cub::ArgMax(),
initial_value,
stream);
}

template <typename InputIteratorT, typename OutputIteratorT>
Expand Down
Loading

0 comments on commit 921885f

Please sign in to comment.