Skip to content

Commit

Permalink
Remove unnecessary fetch operations in external memory. (#10342)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored May 31, 2024
1 parent c2e3d4f commit d2d01d9
Show file tree
Hide file tree
Showing 12 changed files with 180 additions and 71 deletions.
11 changes: 4 additions & 7 deletions include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -468,10 +468,7 @@ class BatchIterator {
return *(*impl_);
}

bool operator!=(const BatchIterator&) const {
CHECK(impl_ != nullptr);
return !impl_->AtEnd();
}
[[nodiscard]] bool operator!=(const BatchIterator&) const { return !this->AtEnd(); }

[[nodiscard]] bool AtEnd() const {
CHECK(impl_ != nullptr);
Expand Down Expand Up @@ -506,13 +503,13 @@ class DMatrix {
public:
/*! \brief default constructor */
DMatrix() = default;
/*! \brief meta information of the dataset */
virtual MetaInfo& Info() = 0;
/** @brief meta information of the dataset */
[[nodiscard]] virtual MetaInfo& Info() = 0;
virtual void SetInfo(const char* key, std::string const& interface_str) {
auto const& ctx = *this->Ctx();
this->Info().SetInfo(ctx, key, StringView{interface_str});
}
/*! \brief meta information of the dataset */
/** @brief meta information of the dataset */
[[nodiscard]] virtual const MetaInfo& Info() const = 0;

/*! \brief Get thread local memory for returning data from DMatrix. */
Expand Down
3 changes: 1 addition & 2 deletions src/data/ellpack_page_source.cu
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
/**
* Copyright 2019-2023, XGBoost contributors
* Copyright 2019-2024, XGBoost contributors
*/
#include <memory>
#include <utility>

#include "ellpack_page.cuh"
#include "ellpack_page.h" // for EllpackPage
Expand Down
1 change: 0 additions & 1 deletion src/data/proxy_dmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include <any> // for any, any_cast
#include <memory>
#include <string>
#include <type_traits> // for invoke_result_t
#include <utility>

Expand Down
29 changes: 12 additions & 17 deletions src/data/sparse_page_dmatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
auto iter = DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{
iter_, reset_, next_};

uint32_t n_batches = 0;
size_t n_features = 0;
size_t n_samples = 0;
size_t nnz = 0;
std::uint32_t n_batches = 0;
bst_feature_t n_features = 0;
bst_idx_t n_samples = 0;
bst_idx_t nnz = 0;

auto num_rows = [&]() {
bool type_error {false};
Expand All @@ -72,7 +72,7 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
};
auto num_cols = [&]() {
bool type_error {false};
size_t n_features = HostAdapterDispatch(
bst_feature_t n_features = HostAdapterDispatch(
proxy, [](auto const &value) { return value.NumCols(); }, &type_error);
if (type_error) {
n_features = detail::NFeaturesDevice(proxy);
Expand Down Expand Up @@ -121,10 +121,9 @@ void SparsePageDMatrix::InitializeSparsePage(Context const *ctx) {
this->n_batches_, cache_info_.at(id));
}

BatchSet<SparsePage> SparsePageDMatrix::GetRowBatchesImpl(Context const* ctx) {
BatchSet<SparsePage> SparsePageDMatrix::GetRowBatchesImpl(Context const *ctx) {
this->InitializeSparsePage(ctx);
auto begin_iter = BatchIterator<SparsePage>(sparse_page_source_);
return BatchSet<SparsePage>(BatchIterator<SparsePage>(begin_iter));
return BatchSet{BatchIterator<SparsePage>{this->sparse_page_source_}};
}

BatchSet<SparsePage> SparsePageDMatrix::GetRowBatches() {
Expand All @@ -143,8 +142,7 @@ BatchSet<CSCPage> SparsePageDMatrix::GetColumnBatches(Context const *ctx) {
} else {
column_source_->Reset();
}
auto begin_iter = BatchIterator<CSCPage>(column_source_);
return BatchSet<CSCPage>(BatchIterator<CSCPage>(begin_iter));
return BatchSet{BatchIterator<CSCPage>{this->column_source_}};
}

BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches(Context const *ctx) {
Expand All @@ -158,8 +156,7 @@ BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches(Context const
} else {
sorted_column_source_->Reset();
}
auto begin_iter = BatchIterator<SortedCSCPage>(sorted_column_source_);
return BatchSet<SortedCSCPage>(BatchIterator<SortedCSCPage>(begin_iter));
return BatchSet{BatchIterator<SortedCSCPage>{this->sorted_column_source_}};
}

BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(Context const *ctx,
Expand All @@ -169,8 +166,8 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(Context const *ct
}
detail::CheckEmpty(batch_param_, param);
auto id = MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
this->InitializeSparsePage(ctx);
if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) {
this->InitializeSparsePage(ctx);
cache_info_.erase(id);
MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
LOG(INFO) << "Generating new Gradient Index.";
Expand All @@ -190,15 +187,13 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(Context const *ct
CHECK(ghist_index_source_);
ghist_index_source_->Reset();
}
auto begin_iter = BatchIterator<GHistIndexMatrix>(ghist_index_source_);
return BatchSet<GHistIndexMatrix>(BatchIterator<GHistIndexMatrix>(begin_iter));
return BatchSet{BatchIterator<GHistIndexMatrix>{this->ghist_index_source_}};
}

#if !defined(XGBOOST_USE_CUDA)
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const *, const BatchParam &) {
common::AssertGPUSupport();
auto begin_iter = BatchIterator<EllpackPage>(ellpack_page_source_);
return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(begin_iter));
return BatchSet{BatchIterator<EllpackPage>{this->ellpack_page_source_}};
}
#endif // !defined(XGBOOST_USE_CUDA)
} // namespace xgboost::data
7 changes: 3 additions & 4 deletions src/data/sparse_page_dmatrix.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2021-2023 by XGBoost contributors
* Copyright 2021-2024, XGBoost contributors
*/
#include <memory> // for unique_ptr

Expand All @@ -21,8 +21,8 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
detail::CheckEmpty(batch_param_, param);
auto id = MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
size_t row_stride = 0;
this->InitializeSparsePage(ctx);
if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) {
this->InitializeSparsePage(ctx);
// reinitialize the cache
cache_info_.erase(id);
MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
Expand Down Expand Up @@ -52,7 +52,6 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
ellpack_page_source_->Reset();
}

auto begin_iter = BatchIterator<EllpackPage>(ellpack_page_source_);
return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(begin_iter));
return BatchSet{BatchIterator<EllpackPage>{this->ellpack_page_source_}};
}
} // namespace xgboost::data
49 changes: 33 additions & 16 deletions src/data/sparse_page_dmatrix.h
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
/**
* Copyright 2015-2023, XGBoost Contributors
* Copyright 2015-2024, XGBoost Contributors
* \file sparse_page_dmatrix.h
* \brief External-memory version of DMatrix.
* \author Tianqi Chen
*/
#ifndef XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_
#define XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_

#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "ellpack_page_source.h"
#include "gradient_index_page_source.h"
Expand All @@ -22,7 +20,7 @@

namespace xgboost::data {
/**
* \brief DMatrix used for external memory.
* @brief DMatrix used for external memory.
*
* The external memory is created for controlling memory usage by splitting up data into
* multiple batches. However that doesn't mean we will actually process exactly 1 batch
Expand Down Expand Up @@ -51,8 +49,13 @@ namespace xgboost::data {
* want to change the generated page like Ellpack, pass parameter into `GetBatches` to
* re-generate them instead of trying to modify the pages in-place.
*
* A possible optimization is dropping the sparse page once dependent pages like ellpack
* are constructed and cached.
* The overall chain of responsibility of external memory DMatrix:
*
* User defined iterator (in Python/C/R) -> Proxy DMatrix -> Sparse page Source ->
* Other sources (Like Ellpack) -> Sparse Page DMatrix -> Caller
*
* A possible optimization is skipping the sparse page source for `hist` based algorithms
* similar to the Quantile DMatrix.
*/
class SparsePageDMatrix : public DMatrix {
MetaInfo info_;
Expand All @@ -67,7 +70,7 @@ class SparsePageDMatrix : public DMatrix {
float missing_;
Context fmat_ctx_;
std::string cache_prefix_;
uint32_t n_batches_{0};
std::uint32_t n_batches_{0};
// sparse page is the source to other page types, we make a special member function.
void InitializeSparsePage(Context const *ctx);
// Non-virtual version that can be used in constructor
Expand All @@ -93,11 +96,11 @@ class SparsePageDMatrix : public DMatrix {
}
}

MetaInfo &Info() override;
const MetaInfo &Info() const override;
Context const *Ctx() const override { return &fmat_ctx_; }
[[nodiscard]] MetaInfo &Info() override;
[[nodiscard]] const MetaInfo &Info() const override;
[[nodiscard]] Context const *Ctx() const override { return &fmat_ctx_; }
// The only DMatrix implementation that returns false.
bool SingleColBlock() const override { return false; }
[[nodiscard]] bool SingleColBlock() const override { return false; }
DMatrix *Slice(common::Span<int32_t const>) override {
LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
return nullptr;
Expand All @@ -107,6 +110,20 @@ class SparsePageDMatrix : public DMatrix {
return nullptr;
}

[[nodiscard]] bool EllpackExists() const override {
return static_cast<bool>(ellpack_page_source_);
}
[[nodiscard]] bool GHistIndexExists() const override {
return static_cast<bool>(ghist_index_source_);
}
[[nodiscard]] bool SparsePageExists() const override {
return static_cast<bool>(sparse_page_source_);
}
// For testing, getter for the number of fetches for sparse page source.
[[nodiscard]] auto SparsePageFetchCount() const {
return this->sparse_page_source_->FetchCount();
}

private:
BatchSet<SparsePage> GetRowBatches() override;
BatchSet<CSCPage> GetColumnBatches(Context const *ctx) override;
Expand All @@ -118,24 +135,24 @@ class SparsePageDMatrix : public DMatrix {
return BatchSet<ExtSparsePage>(BatchIterator<ExtSparsePage>(nullptr));
}

private:
// source data pointers.
std::shared_ptr<SparsePageSource> sparse_page_source_;
std::shared_ptr<EllpackPageSource> ellpack_page_source_;
std::shared_ptr<CSCPageSource> column_source_;
std::shared_ptr<SortedCSCPageSource> sorted_column_source_;
std::shared_ptr<GradientIndexPageSource> ghist_index_source_;

bool EllpackExists() const override { return static_cast<bool>(ellpack_page_source_); }
bool GHistIndexExists() const override { return static_cast<bool>(ghist_index_source_); }
bool SparsePageExists() const override { return static_cast<bool>(sparse_page_source_); }
};

inline std::string MakeId(std::string prefix, SparsePageDMatrix *ptr) {
[[nodiscard]] inline std::string MakeId(std::string prefix, SparsePageDMatrix *ptr) {
std::stringstream ss;
ss << ptr;
return prefix + "-" + ss.str();
}

/**
* @brief Make cache if it doesn't exist yet.
*/
inline std::string MakeCache(SparsePageDMatrix *ptr, std::string format, std::string prefix,
std::map<std::string, std::shared_ptr<Cache>> *out) {
auto &cache_info = *out;
Expand Down
Loading

0 comments on commit d2d01d9

Please sign in to comment.