Skip to content

Commit

Permalink
Switch to CUB/Thrust backend for cuda executor argmax
Browse files Browse the repository at this point in the history
  • Loading branch information
tmartin-gh committed Oct 17, 2024
1 parent 4528b94 commit f8d9d8a
Show file tree
Hide file tree
Showing 3 changed files with 358 additions and 9 deletions.
20 changes: 17 additions & 3 deletions include/matx/core/iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ struct RandomOperatorIterator {
using reference = value_type&;
using iterator_category = std::random_access_iterator_tag;
using difference_type = index_t;
using OperatorBaseType = typename detail::base_type_t<OperatorType>;

__MATX_INLINE__ RandomOperatorIterator(const RandomOperatorIterator &) = default;
__MATX_INLINE__ RandomOperatorIterator(RandomOperatorIterator &&) = default;
Expand All @@ -63,6 +64,12 @@ struct RandomOperatorIterator {
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorIterator(const OperatorType &t, stride_type offset) : t_(t), offset_(offset) {}
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorIterator(OperatorType &&t, stride_type offset) : t_(t), offset_(offset) {}

template<typename T = OperatorType, std::enable_if_t<!std::is_same<T, OperatorBaseType>::value, bool> = true>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorIterator(const OperatorBaseType &t, stride_type offset) : t_(t), offset_(offset) {}

template<typename T = OperatorType, std::enable_if_t<!std::is_same<T, OperatorBaseType>::value, bool> = true>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorIterator(OperatorBaseType &&t, stride_type offset) : t_(t), offset_(offset) {}

/**
* @brief Dereference value at a pre-computed offset
*
Expand Down Expand Up @@ -160,7 +167,7 @@ struct RandomOperatorIterator {
return t_.Size(dim);
}

typename detail::base_type_t<OperatorType> t_;
OperatorBaseType t_;
stride_type offset_;
};

Expand Down Expand Up @@ -191,6 +198,7 @@ struct RandomOperatorOutputIterator {
using reference = value_type&;
using iterator_category = std::random_access_iterator_tag;
using difference_type = index_t;
using OperatorBaseType = typename detail::base_type_t<OperatorType>;

__MATX_INLINE__ RandomOperatorOutputIterator(RandomOperatorOutputIterator &&) = default;
__MATX_INLINE__ RandomOperatorOutputIterator(const RandomOperatorOutputIterator &) = default;
Expand All @@ -199,7 +207,13 @@ struct RandomOperatorOutputIterator {
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorOutputIterator(const OperatorType &t, stride_type offset) : t_(t), offset_(offset) {}
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorOutputIterator(OperatorType &&t, stride_type offset) : t_(t), offset_(offset) {}

[[nodiscard]] __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ reference operator*()
template<typename T = OperatorType, std::enable_if_t<!std::is_same<T, OperatorBaseType>::value, bool> = true>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorOutputIterator(const OperatorBaseType &t, stride_type offset) : t_(t), offset_(offset) {}

template<typename T = OperatorType, std::enable_if_t<!std::is_same<T, OperatorBaseType>::value, bool> = true>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorOutputIterator(OperatorBaseType &&t, stride_type offset) : t_(t), offset_(offset) {}

[[nodiscard]] __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ reference operator*() const
{
if constexpr (OperatorType::Rank() == 0) {
auto &tmp = t_.operator()();
Expand Down Expand Up @@ -288,7 +302,7 @@ struct RandomOperatorOutputIterator {
return t_.Size(dim);
}

typename detail::base_type_t<OperatorType> t_;
OperatorBaseType t_;
stride_type offset_;
};

Expand Down
Loading

0 comments on commit f8d9d8a

Please sign in to comment.