Skip to content

Commit

Permalink
Extract transform iterator. (#8498)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Dec 5, 2022
1 parent d8544e4 commit e3bf556
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 71 deletions.
68 changes: 0 additions & 68 deletions src/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,74 +164,6 @@ class Range {
Iterator end_;
};

/**
* \brief Transform iterator that takes an index and calls transform operator.
*
* This is CPU-only right now as taking host device function as operator complicates the
* code. For device side one can use `thrust::transform_iterator` instead.
*/
template <typename Fn>
class IndexTransformIter {
size_t iter_{0};
Fn fn_;

public:
using iterator_category = std::random_access_iterator_tag; // NOLINT
using value_type = std::result_of_t<Fn(size_t)>; // NOLINT
using difference_type = detail::ptrdiff_t; // NOLINT
using reference = std::add_lvalue_reference_t<value_type>; // NOLINT
using pointer = std::add_pointer_t<value_type>; // NOLINT

public:
/**
* \param op Transform operator, takes a size_t index as input.
*/
explicit IndexTransformIter(Fn &&op) : fn_{op} {}
IndexTransformIter(IndexTransformIter const &) = default;
IndexTransformIter& operator=(IndexTransformIter&&) = default;
IndexTransformIter& operator=(IndexTransformIter const& that) {
iter_ = that.iter_;
return *this;
}

value_type operator*() const { return fn_(iter_); }

auto operator-(IndexTransformIter const &that) const { return iter_ - that.iter_; }
bool operator==(IndexTransformIter const &that) const { return iter_ == that.iter_; }
bool operator!=(IndexTransformIter const &that) const { return !(*this == that); }

IndexTransformIter &operator++() {
iter_++;
return *this;
}
IndexTransformIter operator++(int) {
auto ret = *this;
++(*this);
return ret;
}
IndexTransformIter &operator+=(difference_type n) {
iter_ += n;
return *this;
}
IndexTransformIter &operator-=(difference_type n) {
(*this) += -n;
return *this;
}
IndexTransformIter operator+(difference_type n) const {
auto ret = *this;
return ret += n;
}
IndexTransformIter operator-(difference_type n) const {
auto ret = *this;
return ret -= n;
}
};

template <typename Fn>
auto MakeIndexTransformIter(Fn&& fn) {
return IndexTransformIter<Fn>(std::forward<Fn>(fn));
}

int AllVisibleGPUs();

inline void AssertGPUSupport() {
Expand Down
1 change: 1 addition & 0 deletions src/common/linalg_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "common.h"
#include "threading_utils.h"
#include "transform_iterator.h" // MakeIndexTransformIter
#include "xgboost/generic_parameters.h"
#include "xgboost/linalg.h"

Expand Down
3 changes: 2 additions & 1 deletion src/common/quantile.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
*/
#include <thrust/binary_search.h>
#include <thrust/execution_policy.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/transform_scan.h>
#include <thrust/unique.h>

Expand All @@ -20,6 +20,7 @@
#include "hist_util.h"
#include "quantile.cuh"
#include "quantile.h"
#include "transform_iterator.h" // MakeIndexTransformIter
#include "xgboost/span.h"

namespace xgboost {
Expand Down
3 changes: 2 additions & 1 deletion src/common/stats.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
#include <limits>
#include <vector>

#include "common.h" // AssertGPUSupport
#include "common.h" // AssertGPUSupport
#include "transform_iterator.h" // MakeIndexTransformIter
#include "xgboost/generic_parameters.h"
#include "xgboost/linalg.h"

Expand Down
89 changes: 89 additions & 0 deletions src/common/transform_iterator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/**
* Copyright 2022 by XGBoost Contributors
*/
#ifndef XGBOOST_COMMON_TRANSFORM_ITERATOR_H_
#define XGBOOST_COMMON_TRANSFORM_ITERATOR_H_

#include <cstddef> // std::size_t
#include <iterator> // std::random_access_iterator_tag
#include <type_traits> // std::result_of_t, std::add_pointer_t, std::add_lvalue_reference_t
#include <utility> // std::forward

#include "xgboost/span.h" // ptrdiff_t

namespace xgboost {
namespace common {
/**
* \brief Transform iterator that takes an index and calls transform operator.
*
* This is CPU-only right now as taking host device function as operator complicates the
* code. For device side one can use `thrust::transform_iterator` instead.
*/
template <typename Fn>
class IndexTransformIter {
std::size_t iter_{0};
Fn fn_;

public:
using iterator_category = std::random_access_iterator_tag; // NOLINT
using value_type = std::result_of_t<Fn(std::size_t)>; // NOLINT
using difference_type = detail::ptrdiff_t; // NOLINT
using reference = std::add_lvalue_reference_t<value_type>; // NOLINT
using pointer = std::add_pointer_t<value_type>; // NOLINT

public:
/**
* \param op Transform operator, takes a size_t index as input.
*/
explicit IndexTransformIter(Fn &&op) : fn_{op} {}
IndexTransformIter(IndexTransformIter const &) = default;
IndexTransformIter &operator=(IndexTransformIter &&) = default;
IndexTransformIter &operator=(IndexTransformIter const &that) {
iter_ = that.iter_;
return *this;
}

value_type operator*() const { return fn_(iter_); }
value_type operator[](std::size_t i) const {
auto iter = *this + i;
return *iter;
}

auto operator-(IndexTransformIter const &that) const { return iter_ - that.iter_; }
bool operator==(IndexTransformIter const &that) const { return iter_ == that.iter_; }
bool operator!=(IndexTransformIter const &that) const { return !(*this == that); }

IndexTransformIter &operator++() {
iter_++;
return *this;
}
IndexTransformIter operator++(int) {
auto ret = *this;
++(*this);
return ret;
}
IndexTransformIter &operator+=(difference_type n) {
iter_ += n;
return *this;
}
IndexTransformIter &operator-=(difference_type n) {
(*this) += -n;
return *this;
}
IndexTransformIter operator+(difference_type n) const {
auto ret = *this;
return ret += n;
}
IndexTransformIter operator-(difference_type n) const {
auto ret = *this;
return ret -= n;
}
};

template <typename Fn>
auto MakeIndexTransformIter(Fn &&fn) {
return IndexTransformIter<Fn>(std::forward<Fn>(fn));
}
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_TRANSFORM_ITERATOR_H_
1 change: 1 addition & 0 deletions src/data/ellpack_page.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "../common/categorical.h"
#include "../common/hist_util.cuh"
#include "../common/random.h"
#include "../common/transform_iterator.h" // MakeIndexTransformIter
#include "./ellpack_page.cuh"
#include "device_adapter.cuh"
#include "gradient_index.h"
Expand Down
3 changes: 2 additions & 1 deletion src/data/gradient_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "../common/hist_util.h"
#include "../common/numeric.h"
#include "../common/threading_utils.h"
#include "../common/transform_iterator.h" // MakeIndexTransformIter

namespace xgboost {

Expand Down Expand Up @@ -78,7 +79,7 @@ GHistIndexMatrix::~GHistIndexMatrix() = default;
void GHistIndexMatrix::PushBatch(SparsePage const &batch, common::Span<FeatureType const> ft,
int32_t n_threads) {
auto page = batch.GetView();
auto it = common::MakeIndexTransformIter([&](size_t ridx) { return page[ridx].size(); });
auto it = common::MakeIndexTransformIter([&](std::size_t ridx) { return page[ridx].size(); });
common::PartialSum(n_threads, it, it + page.Size(), static_cast<size_t>(0), row_ptr.begin());
data::SparsePageAdapterBatch adapter_batch{page};
auto is_valid = [](auto) { return true; }; // SparsePage always contains valid entries
Expand Down
1 change: 1 addition & 0 deletions src/data/gradient_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "../common/hist_util.h"
#include "../common/numeric.h"
#include "../common/threading_utils.h"
#include "../common/transform_iterator.h" // MakeIndexTransformIter
#include "adapter.h"
#include "proxy_dmatrix.h"
#include "xgboost/base.h"
Expand Down
20 changes: 20 additions & 0 deletions tests/cpp/common/test_transform_iterator.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/**
* Copyright 2022 by XGBoost Contributors
*/
#include <gtest/gtest.h>

#include <cstddef> // std::size_t

#include "../../../src/common/transform_iterator.h"

namespace xgboost {
namespace common {
TEST(IndexTransformIter, Basic) {
auto sqr = [](std::size_t i) { return i * i; };
auto iter = MakeIndexTransformIter(sqr);
for (std::size_t i = 0; i < 4; ++i) {
ASSERT_EQ(iter[i], sqr(i));
}
}
} // namespace common
} // namespace xgboost

0 comments on commit e3bf556

Please sign in to comment.