Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass obj info by reference instead of by value. #8889

Merged
merged 1 commit into from
Mar 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions include/xgboost/tree_updater.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright 2014-2022 by XGBoost Contributors
/**
* Copyright 2014-2023 by XGBoost Contributors
* \file tree_updater.h
* \brief General primitive for tree learning,
* Updating a collection of trees given the information.
Expand All @@ -9,19 +9,17 @@
#define XGBOOST_TREE_UPDATER_H_

#include <dmlc/registry.h>
#include <xgboost/base.h>
#include <xgboost/context.h>
#include <xgboost/data.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/linalg.h>
#include <xgboost/model.h>
#include <xgboost/task.h>
#include <xgboost/tree_model.h>
#include <xgboost/base.h> // for Args, GradientPair
#include <xgboost/data.h> // DMatrix
#include <xgboost/host_device_vector.h> // for HostDeviceVector
#include <xgboost/linalg.h> // for VectorView
#include <xgboost/model.h> // for Configurable
#include <xgboost/span.h> // for Span
#include <xgboost/tree_model.h> // for RegTree

#include <functional>
#include <string>
#include <utility>
#include <vector>
#include <functional> // for function
#include <string> // for string
#include <vector> // for vector

namespace xgboost {
namespace tree {
Expand All @@ -30,8 +28,9 @@ struct TrainParam;

class Json;
struct Context;
struct ObjInfo;

/*!
/**
* \brief interface of tree update module, that performs update of a tree.
*/
class TreeUpdater : public Configurable {
Expand All @@ -53,12 +52,12 @@ class TreeUpdater : public Configurable {
* used for modifying existing trees (like `prune`). Return true if it can modify
* existing trees.
*/
virtual bool CanModifyTree() const { return false; }
[[nodiscard]] virtual bool CanModifyTree() const { return false; }
/*!
* \brief Whether the out_position in `Update` is valid. This determines whether adaptive
* tree can be used.
*/
virtual bool HasNodePosition() const { return false; }
[[nodiscard]] virtual bool HasNodePosition() const { return false; }
/**
* \brief perform update to the tree models
*
Expand Down Expand Up @@ -91,22 +90,23 @@ class TreeUpdater : public Configurable {
return false;
}

virtual char const* Name() const = 0;
[[nodiscard]] virtual char const* Name() const = 0;

/*!
/**
* \brief Create a tree updater given name
* \param name Name of the tree updater.
* \param ctx A global runtime parameter
* \param task Information about the objective.
*/
static TreeUpdater* Create(const std::string& name, Context const* ctx, ObjInfo task);
static TreeUpdater* Create(const std::string& name, Context const* ctx, ObjInfo const* task);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
static TreeUpdater* Create(const std::string& name, Context const* ctx, ObjInfo const* task);
static TreeUpdater* Create(const std::string& name, Context const* ctx, ObjInfo const& task);

Have you considered using a constant reference instead? I don't see a reason for allowing a nullptr here.

Copy link
Member Author

@trivialfis trivialfis Mar 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using a pointer lets the compiler check whether I'm accidentally passing it by value. If we changed it to a reference, we would need to disable the default copy constructor and assignment operator and then define an explicit copy method to get the same result, which seems a little too much for a simple struct.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, then pointer is fine.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hcho3 Could you please approve the PR if there are no other change requests?

};

/*!
* \brief Registry entry for tree updater.
*/
struct TreeUpdaterReg
: public dmlc::FunctionRegEntryBase<
TreeUpdaterReg, std::function<TreeUpdater*(Context const* ctx, ObjInfo task)>> {};
TreeUpdaterReg, std::function<TreeUpdater*(Context const* ctx, ObjInfo const* task)>> {};

/*!
* \brief Macro to register tree updater.
Expand Down
4 changes: 2 additions & 2 deletions src/gbm/gbtree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ void GBTree::InitUpdater(Args const& cfg) {
// create new updaters
for (const std::string& pstr : ups) {
std::unique_ptr<TreeUpdater> up(
TreeUpdater::Create(pstr.c_str(), ctx_, model_.learner_model_param->task));
TreeUpdater::Create(pstr.c_str(), ctx_, &model_.learner_model_param->task));
up->Configure(cfg);
updaters_.push_back(std::move(up));
}
Expand Down Expand Up @@ -448,7 +448,7 @@ void GBTree::LoadConfig(Json const& in) {
LOG(WARNING) << "Changing updater from `grow_gpu_hist` to `grow_quantile_histmaker`.";
}
std::unique_ptr<TreeUpdater> up{
TreeUpdater::Create(name, ctx_, model_.learner_model_param->task)};
TreeUpdater::Create(name, ctx_, &model_.learner_model_param->task)};
up->LoadConfig(kv.second);
updaters_.push_back(std::move(up));
}
Expand Down
19 changes: 8 additions & 11 deletions src/tree/tree_updater.cc
Original file line number Diff line number Diff line change
@@ -1,32 +1,30 @@
/*!
* Copyright 2015-2022 by XGBoost Contributors
/**
* Copyright 2015-2023 by XGBoost Contributors
* \file tree_updater.cc
* \brief Registry of tree updaters.
*/
#include "xgboost/tree_updater.h"

#include <dmlc/registry.h>

#include "xgboost/tree_updater.h"
#include "xgboost/host_device_vector.h"
#include <string> // for string

namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::TreeUpdaterReg);
} // namespace dmlc

namespace xgboost {

TreeUpdater* TreeUpdater::Create(const std::string& name, Context const* ctx, ObjInfo task) {
TreeUpdater* TreeUpdater::Create(const std::string& name, Context const* ctx, ObjInfo const* task) {
auto* e = ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->Find(name);
if (e == nullptr) {
LOG(FATAL) << "Unknown tree updater " << name;
}
auto p_updater = (e->body)(ctx, task);
return p_updater;
}

} // namespace xgboost

namespace xgboost {
namespace tree {
namespace xgboost::tree {
// List of files that will be force linked in static links.
DMLC_REGISTRY_LINK_TAG(updater_colmaker);
DMLC_REGISTRY_LINK_TAG(updater_refresh);
Expand All @@ -37,5 +35,4 @@ DMLC_REGISTRY_LINK_TAG(updater_sync);
#ifdef XGBOOST_USE_CUDA
DMLC_REGISTRY_LINK_TAG(updater_gpu_hist);
#endif // XGBOOST_USE_CUDA
} // namespace tree
} // namespace xgboost
} // namespace xgboost::tree
27 changes: 16 additions & 11 deletions src/tree/updater_approx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@
#include "driver.h"
#include "hist/evaluate_splits.h"
#include "hist/histogram.h"
#include "hist/sampler.h" // SampleGradient
#include "hist/sampler.h" // for SampleGradient
#include "param.h"
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "xgboost/json.h"
#include "xgboost/linalg.h"
#include "xgboost/task.h" // for ObjInfo
#include "xgboost/tree_model.h"
#include "xgboost/tree_updater.h"
#include "xgboost/tree_updater.h" // for TreeUpdater

namespace xgboost::tree {

Expand All @@ -40,12 +41,12 @@ auto BatchSpec(TrainParam const &p, common::Span<float> hess) {

class GloablApproxBuilder {
protected:
TrainParam const* param_;
TrainParam const *param_;
std::shared_ptr<common::ColumnSampler> col_sampler_;
HistEvaluator<CPUExpandEntry> evaluator_;
HistogramBuilder<CPUExpandEntry> histogram_builder_;
Context const *ctx_;
ObjInfo const task_;
ObjInfo const *const task_;

std::vector<CommonRowPartitioner> partitioner_;
// Pointer to last updated tree, used for update prediction cache.
Expand All @@ -63,7 +64,8 @@ class GloablApproxBuilder {
bst_bin_t n_total_bins = 0;
partitioner_.clear();
// Generating the GHistIndexMatrix is quite slow, is there a way to speed it up?
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(*param_, hess, task_))) {
for (auto const &page :
p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(*param_, hess, *task_))) {
if (n_total_bins == 0) {
n_total_bins = page.cut.TotalBins();
feature_values_ = page.cut;
Expand Down Expand Up @@ -157,7 +159,7 @@ class GloablApproxBuilder {
void LeafPartition(RegTree const &tree, common::Span<float const> hess,
std::vector<bst_node_t> *p_out_position) {
monitor_->Start(__func__);
if (!task_.UpdateTreeLeaf()) {
if (!task_->UpdateTreeLeaf()) {
return;
}
for (auto const &part : partitioner_) {
Expand All @@ -168,8 +170,8 @@ class GloablApproxBuilder {

public:
explicit GloablApproxBuilder(TrainParam const *param, MetaInfo const &info, Context const *ctx,
std::shared_ptr<common::ColumnSampler> column_sampler, ObjInfo task,
common::Monitor *monitor)
std::shared_ptr<common::ColumnSampler> column_sampler,
ObjInfo const *task, common::Monitor *monitor)
: param_{param},
col_sampler_{std::move(column_sampler)},
evaluator_{ctx, param_, info, col_sampler_},
Expand Down Expand Up @@ -256,10 +258,11 @@ class GlobalApproxUpdater : public TreeUpdater {
DMatrix *cached_{nullptr};
std::shared_ptr<common::ColumnSampler> column_sampler_ =
std::make_shared<common::ColumnSampler>();
ObjInfo task_;
ObjInfo const *task_;

public:
explicit GlobalApproxUpdater(Context const *ctx, ObjInfo task) : TreeUpdater(ctx), task_{task} {
explicit GlobalApproxUpdater(Context const *ctx, ObjInfo const *task)
: TreeUpdater(ctx), task_{task} {
monitor_.Init(__func__);
}

Expand Down Expand Up @@ -317,5 +320,7 @@ XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_histmaker")
.describe(
"Tree constructor that uses approximate histogram construction "
"for each node.")
.set_body([](Context const *ctx, ObjInfo task) { return new GlobalApproxUpdater(ctx, task); });
.set_body([](Context const *ctx, ObjInfo const *task) {
return new GlobalApproxUpdater(ctx, task);
});
} // namespace xgboost::tree
2 changes: 1 addition & 1 deletion src/tree/updater_colmaker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -603,5 +603,5 @@ class ColMaker: public TreeUpdater {

XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker")
.describe("Grow tree with parallelization over columns.")
.set_body([](Context const *ctx, ObjInfo) { return new ColMaker(ctx); });
.set_body([](Context const *ctx, auto) { return new ColMaker(ctx); });
} // namespace xgboost::tree
36 changes: 18 additions & 18 deletions src/tree/updater_gpu_hist.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
#include "../collective/device_communicator.cuh"
#include "../common/bitfield.h"
#include "../common/categorical.h"
#include "../common/cuda_context.cuh" // CUDAContext
#include "../common/device_helpers.cuh"
#include "../common/hist_util.h"
#include "../common/io.h"
#include "../common/timer.h"
#include "../data/ellpack_page.cuh"
#include "../common/cuda_context.cuh" // CUDAContext
#include "constraints.cuh"
#include "driver.h"
#include "gpu_hist/evaluate_splits.cuh"
Expand All @@ -39,11 +39,10 @@
#include "xgboost/json.h"
#include "xgboost/parameter.h"
#include "xgboost/span.h"
#include "xgboost/task.h"
#include "xgboost/task.h" // for ObjInfo
#include "xgboost/tree_model.h"

namespace xgboost {
namespace tree {
namespace xgboost::tree {
#if !defined(GTEST_TEST)
DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
#endif // !defined(GTEST_TEST)
Expand Down Expand Up @@ -106,12 +105,12 @@ class DeviceHistogramStorage {
nidx_map_.clear();
overflow_nidx_map_.clear();
}
bool HistogramExists(int nidx) const {
[[nodiscard]] bool HistogramExists(int nidx) const {
return nidx_map_.find(nidx) != nidx_map_.cend() ||
overflow_nidx_map_.find(nidx) != overflow_nidx_map_.cend();
}
int Bins() const { return n_bins_; }
size_t HistogramSize() const { return n_bins_ * kNumItemsInGradientSum; }
[[nodiscard]] int Bins() const { return n_bins_; }
[[nodiscard]] size_t HistogramSize() const { return n_bins_ * kNumItemsInGradientSum; }
dh::device_vector<typename GradientSumT::ValueT>& Data() { return data_; }

void AllocateHistograms(const std::vector<int>& new_nidxs) {
Expand Down Expand Up @@ -690,8 +689,9 @@ struct GPUHistMakerDevice {
return root_entry;
}

void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, ObjInfo task,
RegTree* p_tree, collective::DeviceCommunicator* communicator,
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
ObjInfo const* task, RegTree* p_tree,
collective::DeviceCommunicator* communicator,
HostDeviceVector<bst_node_t>* p_out_position) {
auto& tree = *p_tree;
// Process maximum 32 nodes at a time
Expand Down Expand Up @@ -741,7 +741,7 @@ struct GPUHistMakerDevice {
}

monitor.Start("FinalisePosition");
this->FinalisePosition(p_tree, p_fmat, task, p_out_position);
this->FinalisePosition(p_tree, p_fmat, *task, p_out_position);
monitor.Stop("FinalisePosition");
}
};
Expand All @@ -750,7 +750,7 @@ class GPUHistMaker : public TreeUpdater {
using GradientSumT = GradientPairPrecise;

public:
explicit GPUHistMaker(Context const* ctx, ObjInfo task)
explicit GPUHistMaker(Context const* ctx, ObjInfo const* task)
: TreeUpdater(ctx), task_{task} {};
void Configure(const Args& args) override {
// Used in test to count how many configurations are performed
Expand Down Expand Up @@ -872,8 +872,8 @@ class GPUHistMaker : public TreeUpdater {

std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker; // NOLINT

char const* Name() const override { return "grow_gpu_hist"; }
bool HasNodePosition() const override { return true; }
[[nodiscard]] char const* Name() const override { return "grow_gpu_hist"; }
[[nodiscard]] bool HasNodePosition() const override { return true; }

private:
bool initialised_{false};
Expand All @@ -882,16 +882,16 @@ class GPUHistMaker : public TreeUpdater {

DMatrix* p_last_fmat_{nullptr};
RegTree const* p_last_tree_{nullptr};
ObjInfo task_;
ObjInfo const* task_{nullptr};

common::Monitor monitor_;
};

#if !defined(GTEST_TEST)
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
.describe("Grow tree with GPU.")
.set_body([](Context const* ctx, ObjInfo task) { return new GPUHistMaker(ctx, task); });
.set_body([](Context const* ctx, ObjInfo const* task) {
return new GPUHistMaker(ctx, task);
});
#endif // !defined(GTEST_TEST)

} // namespace tree
} // namespace xgboost
} // namespace xgboost::tree
6 changes: 4 additions & 2 deletions src/tree/updater_prune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ DMLC_REGISTRY_FILE_TAG(updater_prune);
/*! \brief pruner that prunes a tree after growing finishes */
class TreePruner : public TreeUpdater {
public:
explicit TreePruner(Context const* ctx, ObjInfo task) : TreeUpdater(ctx) {
explicit TreePruner(Context const* ctx, ObjInfo const* task) : TreeUpdater(ctx) {
syncher_.reset(TreeUpdater::Create("sync", ctx_, task));
pruner_monitor_.Init("TreePruner");
}
Expand Down Expand Up @@ -90,5 +90,7 @@ class TreePruner : public TreeUpdater {

XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune")
.describe("Pruner that prune the tree according to statistics.")
.set_body([](Context const* ctx, ObjInfo task) { return new TreePruner(ctx, task); });
.set_body([](Context const* ctx, ObjInfo const* task) {
return new TreePruner{ctx, task};
});
} // namespace xgboost::tree
6 changes: 4 additions & 2 deletions src/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void QuantileHistMaker::Update(TrainParam const *param, HostDeviceVector<Gradien
// build tree
const size_t n_trees = trees.size();
if (!pimpl_) {
pimpl_.reset(new Builder(n_trees, param, dmat, task_, ctx_));
pimpl_.reset(new Builder(n_trees, param, dmat, *task_, ctx_));
}

size_t t_idx{0};
Expand Down Expand Up @@ -287,6 +287,8 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree,

XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker")
.describe("Grow tree using quantized histogram.")
.set_body([](Context const *ctx, ObjInfo task) { return new QuantileHistMaker(ctx, task); });
.set_body([](Context const *ctx, ObjInfo const *task) {
return new QuantileHistMaker(ctx, task);
});
} // namespace tree
} // namespace xgboost
Loading