Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to CUB/Thrust backend for cuda executor argmax #772

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions docs_input/api/manipulation/selecting/reduce.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,3 @@ Examples
:start-after: example-begin reduce-1
:end-before: example-end reduce-1
:dedent:

Examples
~~~~~~~~

.. doxygenfunction:: reduce(OutType dest, const InType &in, ReduceOp op, cudaStream_t stream = 0, bool init = true)

.. literalinclude:: ../../../../include/matx/transforms/reduce.h
:language: cpp
:start-after: example-begin reduce-2
:end-before: example-end reduce-2
:dedent:
146 changes: 144 additions & 2 deletions include/matx/core/iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ struct RandomOperatorIterator {
using reference = value_type&;
using iterator_category = std::random_access_iterator_tag;
using difference_type = index_t;
using OperatorBaseType = typename detail::base_type_t<OperatorType>;

__MATX_INLINE__ RandomOperatorIterator(const RandomOperatorIterator &) = default;
__MATX_INLINE__ RandomOperatorIterator(RandomOperatorIterator &&) = default;
Expand All @@ -63,6 +64,12 @@ struct RandomOperatorIterator {
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorIterator(const OperatorType &t, stride_type offset) : t_(t), offset_(offset) {}
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorIterator(OperatorType &&t, stride_type offset) : t_(t), offset_(offset) {}

template<typename T = OperatorType, std::enable_if_t<!std::is_same<T, OperatorBaseType>::value, bool> = true>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorIterator(const OperatorBaseType &t, stride_type offset) : t_(t), offset_(offset) {}

template<typename T = OperatorType, std::enable_if_t<!std::is_same<T, OperatorBaseType>::value, bool> = true>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorIterator(OperatorBaseType &&t, stride_type offset) : t_(t), offset_(offset) {}

/**
* @brief Dereference value at a pre-computed offset
*
Expand Down Expand Up @@ -160,7 +167,7 @@ struct RandomOperatorIterator {
return t_.Size(dim);
}

typename detail::base_type_t<OperatorType> t_;
OperatorBaseType t_;
stride_type offset_;
};

Expand Down Expand Up @@ -191,6 +198,7 @@ struct RandomOperatorOutputIterator {
using reference = value_type&;
using iterator_category = std::random_access_iterator_tag;
using difference_type = index_t;
using OperatorBaseType = typename detail::base_type_t<OperatorType>;

__MATX_INLINE__ RandomOperatorOutputIterator(RandomOperatorOutputIterator &&) = default;
__MATX_INLINE__ RandomOperatorOutputIterator(const RandomOperatorOutputIterator &) = default;
Expand All @@ -199,6 +207,12 @@ struct RandomOperatorOutputIterator {
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorOutputIterator(const OperatorType &t, stride_type offset) : t_(t), offset_(offset) {}
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorOutputIterator(OperatorType &&t, stride_type offset) : t_(t), offset_(offset) {}

template<typename T = OperatorType, std::enable_if_t<!std::is_same<T, OperatorBaseType>::value, bool> = true>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorOutputIterator(const OperatorBaseType &t, stride_type offset) : t_(t), offset_(offset) {}

template<typename T = OperatorType, std::enable_if_t<!std::is_same<T, OperatorBaseType>::value, bool> = true>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorOutputIterator(OperatorBaseType &&t, stride_type offset) : t_(t), offset_(offset) {}

[[nodiscard]] __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ reference operator*()
{
if constexpr (OperatorType::Rank() == 0) {
Expand Down Expand Up @@ -288,7 +302,135 @@ struct RandomOperatorOutputIterator {
return t_.Size(dim);
}

typename detail::base_type_t<OperatorType> t_;
OperatorBaseType t_;
stride_type offset_;
};

/**
* @brief Iterator around operators for libraries that can take iterators as input/output (Thrust).
*
* @tparam T Data type
* @tparam RANK Rank of tensor
* @tparam Desc Descriptor for tensor
*
*/
template <typename OperatorType, bool ConvertType = true>
struct RandomOperatorThrustIterator {
using self_type = RandomOperatorThrustIterator<OperatorType, ConvertType>;
using value_type = typename std::conditional_t<ConvertType, detail::convert_matx_type_t<typename OperatorType::value_type>, typename OperatorType::value_type>;
// using stride_type = std::conditional_t<is_tensor_view_v<OperatorType>, typename OperatorType::desc_type::stride_type,
// index_t>;
using stride_type = index_t;
using pointer = value_type*;
using reference = value_type&;
using const_reference = value_type&;
using iterator_category = std::random_access_iterator_tag;
using difference_type = index_t;
using OperatorBaseType = typename detail::base_type_t<OperatorType>;

__MATX_INLINE__ RandomOperatorThrustIterator(RandomOperatorThrustIterator &&) = default;
__MATX_INLINE__ RandomOperatorThrustIterator(const RandomOperatorThrustIterator &) = default;
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorThrustIterator(OperatorType &&t) : t_(t), offset_(0) { }
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorThrustIterator(const OperatorType &t) : t_(t), offset_(0) { }
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorThrustIterator(const OperatorType &t, stride_type offset) : t_(t), offset_(offset) {}
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorThrustIterator(OperatorType &&t, stride_type offset) : t_(t), offset_(offset) {}

template<typename T = OperatorType, std::enable_if_t<!std::is_same<T, OperatorBaseType>::value, bool> = true>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorThrustIterator(const OperatorBaseType &t, stride_type offset) : t_(t), offset_(offset) {}

template<typename T = OperatorType, std::enable_if_t<!std::is_same<T, OperatorBaseType>::value, bool> = true>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorThrustIterator(OperatorBaseType &&t, stride_type offset) : t_(t), offset_(offset) {}

[[nodiscard]] __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ reference operator*() const
{
if constexpr (OperatorType::Rank() == 0) {
auto &tmp = t_.operator()();
return tmp;
}
else {
auto arrs = detail::GetIdxFromAbs(t_, offset_);

return cuda::std::apply([&](auto &&...args) -> reference {
auto &tmp = t_.operator()(args...);
return tmp;
}, arrs);
}
}

[[nodiscard]] __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ self_type operator+(difference_type offset) const
{
return self_type{t_, offset_ + offset};
}


[[nodiscard]] __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ reference operator[](difference_type offset)
{
return *self_type{t_, offset_ + offset};
}

__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ self_type operator++(int)
{
self_type retval = *this;
offset_++;
return retval;
}

__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ self_type operator++()
{
offset_++;
return *this;
}

__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ self_type& operator+=(difference_type offset)
{
offset_ += offset;
return *this;
}

__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ self_type& operator=(const self_type &rhs)
{
t_.copy(rhs.t_);
offset_ = rhs.offset_;
return *this;
}

__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ self_type operator-(difference_type offset) const
{
return self_type{t_, offset_ - offset};
}


__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ self_type& operator-=(difference_type offset)
{
offset_ -= offset;
return *this;
}

__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ self_type& operator--() {
--offset_;
return *this;
}

__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ friend bool operator!=(const self_type &a, const self_type &b)
{
return a.offset_ != b.offset_;
}

__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ friend bool operator==(const self_type &a, const self_type &b)
{
return a.offset_ == b.offset_;
}

static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() {
return OperatorType::Rank();
}

constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(int dim) const
{
return t_.Size(dim);
}

OperatorBaseType t_;
stride_type offset_;
};

Expand Down
Loading