Skip to content

Commit

Permalink
Pass obj info instead of model parameter.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Mar 9, 2023
1 parent e986afd commit b06792c
Show file tree
Hide file tree
Showing 14 changed files with 80 additions and 80 deletions.
9 changes: 4 additions & 5 deletions include/xgboost/tree_updater.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include <vector>

namespace xgboost {
struct LearnerModelParam;
struct ObjInfo;
struct Context;
namespace tree {
struct TrainParam;
Expand Down Expand Up @@ -97,18 +97,17 @@ class TreeUpdater : public Configurable {
* \brief Create a tree updater given name
* \param name Name of the tree updater.
* \param ctx A global runtime parameter
* \param task Information about the objective.
*/
static TreeUpdater* Create(const std::string& name, Context const* ctx,
LearnerModelParam const* model);
static TreeUpdater* Create(const std::string& name, Context const* ctx, ObjInfo const* task);
};

/*!
* \brief Registry entry for tree updater.
*/
struct TreeUpdaterReg
: public dmlc::FunctionRegEntryBase<
TreeUpdaterReg,
std::function<TreeUpdater*(Context const* ctx, LearnerModelParam const* model)>> {};
TreeUpdaterReg, std::function<TreeUpdater*(Context const* ctx, ObjInfo const* task)>> {};

/*!
* \brief Macro to register tree updater.
Expand Down
5 changes: 3 additions & 2 deletions src/gbm/gbtree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ void GBTree::InitUpdater(Args const& cfg) {
// create new updaters
for (const std::string& pstr : ups) {
std::unique_ptr<TreeUpdater> up(
TreeUpdater::Create(pstr.c_str(), ctx_, model_.learner_model_param));
TreeUpdater::Create(pstr.c_str(), ctx_, &model_.learner_model_param->task));
up->Configure(cfg);
updaters_.push_back(std::move(up));
}
Expand Down Expand Up @@ -470,7 +470,8 @@ void GBTree::LoadConfig(Json const& in) {
name = "grow_quantile_histmaker";
LOG(WARNING) << "Changing updater from `grow_gpu_hist` to `grow_quantile_histmaker`.";
}
std::unique_ptr<TreeUpdater> up{TreeUpdater::Create(name, ctx_, model_.learner_model_param)};
std::unique_ptr<TreeUpdater> up{
TreeUpdater::Create(name, ctx_, &model_.learner_model_param->task)};
up->LoadConfig(kv.second);
updaters_.push_back(std::move(up));
}
Expand Down
4 changes: 2 additions & 2 deletions src/tree/tree_updater.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ DMLC_REGISTRY_ENABLE(::xgboost::TreeUpdaterReg);

namespace xgboost {
TreeUpdater* TreeUpdater::Create(const std::string& name, Context const* ctx,
LearnerModelParam const* model) {
ObjInfo const* task) {
auto* e = ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->Find(name);
if (e == nullptr) {
LOG(FATAL) << "Unknown tree updater " << name;
}
auto p_updater = (e->body)(ctx, model);
auto p_updater = (e->body)(ctx, task);
return p_updater;
}
} // namespace xgboost
Expand Down
24 changes: 12 additions & 12 deletions src/tree/updater_approx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ auto BatchSpec(TrainParam const &p, common::Span<float> hess) {

class GloablApproxBuilder {
protected:
TrainParam const* param_;
TrainParam const *param_;
std::shared_ptr<common::ColumnSampler> col_sampler_;
HistEvaluator<CPUExpandEntry> evaluator_;
HistogramBuilder<CPUExpandEntry> histogram_builder_;
Context const *ctx_;
LearnerModelParam const* const model_;
ObjInfo const *const task_;

std::vector<CommonRowPartitioner> partitioner_;
// Pointer to last updated tree, used for update prediction cache.
Expand All @@ -64,7 +64,7 @@ class GloablApproxBuilder {
partitioner_.clear();
// Generating the GHistIndexMatrix is quite slow, is there a way to speed it up?
for (auto const &page :
p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(*param_, hess, model_->task))) {
p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(*param_, hess, *task_))) {
if (n_total_bins == 0) {
n_total_bins = page.cut.TotalBins();
feature_values_ = page.cut;
Expand Down Expand Up @@ -158,7 +158,7 @@ class GloablApproxBuilder {
void LeafPartition(RegTree const &tree, common::Span<float const> hess,
std::vector<bst_node_t> *p_out_position) {
monitor_->Start(__func__);
if (!model_->task.UpdateTreeLeaf()) {
if (!task_->UpdateTreeLeaf()) {
return;
}
for (auto const &part : partitioner_) {
Expand All @@ -170,12 +170,12 @@ class GloablApproxBuilder {
public:
explicit GloablApproxBuilder(TrainParam const *param, MetaInfo const &info, Context const *ctx,
std::shared_ptr<common::ColumnSampler> column_sampler,
LearnerModelParam const *model, common::Monitor *monitor)
ObjInfo const *task, common::Monitor *monitor)
: param_{param},
col_sampler_{std::move(column_sampler)},
evaluator_{ctx, param_, info, col_sampler_},
ctx_{ctx},
model_{model},
task_{task},
monitor_{monitor} {}

void UpdateTree(DMatrix *p_fmat, std::vector<GradientPair> const &gpair, common::Span<float> hess,
Expand Down Expand Up @@ -257,11 +257,11 @@ class GlobalApproxUpdater : public TreeUpdater {
DMatrix *cached_{nullptr};
std::shared_ptr<common::ColumnSampler> column_sampler_ =
std::make_shared<common::ColumnSampler>();
LearnerModelParam const *model_;
ObjInfo const *task_;

public:
explicit GlobalApproxUpdater(Context const *ctx, LearnerModelParam const *model)
: TreeUpdater(ctx), model_{model} {
explicit GlobalApproxUpdater(Context const *ctx, ObjInfo const *task)
: TreeUpdater(ctx), task_{task} {
monitor_.Init(__func__);
}

Expand All @@ -282,7 +282,7 @@ class GlobalApproxUpdater : public TreeUpdater {
void Update(TrainParam const *param, HostDeviceVector<GradientPair> *gpair, DMatrix *m,
common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree *> &trees) override {
pimpl_ = std::make_unique<GloablApproxBuilder>(param, m->Info(), ctx_, column_sampler_, model_,
pimpl_ = std::make_unique<GloablApproxBuilder>(param, m->Info(), ctx_, column_sampler_, task_,
&monitor_);

linalg::Matrix<GradientPair> h_gpair;
Expand Down Expand Up @@ -319,7 +319,7 @@ XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_histmaker")
.describe(
"Tree constructor that uses approximate histogram construction "
"for each node.")
.set_body([](Context const *ctx, LearnerModelParam const *model) {
return new GlobalApproxUpdater(ctx, model);
.set_body([](Context const *ctx, ObjInfo const *task) {
return new GlobalApproxUpdater(ctx, task);
});
} // namespace xgboost::tree
16 changes: 8 additions & 8 deletions src/tree/updater_gpu_hist.cu
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ struct GPUHistMakerDevice {
}

void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
LearnerModelParam const* model, RegTree* p_tree,
ObjInfo const* task, RegTree* p_tree,
collective::DeviceCommunicator* communicator,
HostDeviceVector<bst_node_t>* p_out_position) {
auto& tree = *p_tree;
Expand Down Expand Up @@ -742,7 +742,7 @@ struct GPUHistMakerDevice {
}

monitor.Start("FinalisePosition");
this->FinalisePosition(p_tree, p_fmat, model->task, p_out_position);
this->FinalisePosition(p_tree, p_fmat, *task, p_out_position);
monitor.Stop("FinalisePosition");
}
};
Expand All @@ -751,8 +751,8 @@ class GPUHistMaker : public TreeUpdater {
using GradientSumT = GradientPairPrecise;

public:
explicit GPUHistMaker(Context const* ctx, LearnerModelParam const* model)
: TreeUpdater(ctx), model_{model} {};
explicit GPUHistMaker(Context const* ctx, ObjInfo const* task)
: TreeUpdater(ctx), task_{task} {};
void Configure(const Args& args) override {
// Used in test to count how many configurations are performed
LOG(DEBUG) << "[GPU Hist]: Configure";
Expand Down Expand Up @@ -855,7 +855,7 @@ class GPUHistMaker : public TreeUpdater {

gpair->SetDevice(ctx_->gpu_id);
auto* communicator = collective::Communicator::GetDevice(ctx_->gpu_id);
maker->UpdateTree(gpair, p_fmat, model_, p_tree, communicator, p_out_position);
maker->UpdateTree(gpair, p_fmat, task_, p_tree, communicator, p_out_position);
}

bool UpdatePredictionCache(const DMatrix* data,
Expand Down Expand Up @@ -883,16 +883,16 @@ class GPUHistMaker : public TreeUpdater {

DMatrix* p_last_fmat_{nullptr};
RegTree const* p_last_tree_{nullptr};
LearnerModelParam const* model_{nullptr};
ObjInfo const* task_{nullptr};

common::Monitor monitor_;
};

#if !defined(GTEST_TEST)
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
.describe("Grow tree with GPU.")
.set_body([](Context const* ctx, LearnerModelParam const* model) {
return new GPUHistMaker(ctx, model);
.set_body([](Context const* ctx, ObjInfo const* task) {
return new GPUHistMaker(ctx, task);
});
#endif // !defined(GTEST_TEST)
} // namespace xgboost::tree
8 changes: 4 additions & 4 deletions src/tree/updater_prune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ DMLC_REGISTRY_FILE_TAG(updater_prune);
/*! \brief pruner that prunes a tree after growing finishes */
class TreePruner : public TreeUpdater {
public:
explicit TreePruner(Context const* ctx, LearnerModelParam const* model) : TreeUpdater(ctx) {
syncher_.reset(TreeUpdater::Create("sync", ctx_, model));
explicit TreePruner(Context const* ctx, ObjInfo const* task) : TreeUpdater(ctx) {
syncher_.reset(TreeUpdater::Create("sync", ctx_, task));
pruner_monitor_.Init("TreePruner");
}
[[nodiscard]] char const* Name() const override { return "prune"; }
Expand Down Expand Up @@ -90,7 +90,7 @@ class TreePruner : public TreeUpdater {

XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune")
.describe("Pruner that prune the tree according to statistics.")
.set_body([](Context const* ctx, LearnerModelParam const* model) {
return new TreePruner{ctx, model};
.set_body([](Context const* ctx, ObjInfo const* task) {
return new TreePruner{ctx, task};
});
} // namespace xgboost::tree
31 changes: 16 additions & 15 deletions src/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class MultiTargetHistBuilder : public UpdateTreeMixIn<MultiTargetHistBuilder, Mu
// Pointer to last updated tree, used for update prediction cache.
RegTree const *p_last_tree_{nullptr};

LearnerModelParam const *model_;
ObjInfo const *task_;

public:
void UpdatePosition(DMatrix *p_fmat, RegTree const *p_tree,
Expand Down Expand Up @@ -291,7 +291,7 @@ class MultiTargetHistBuilder : public UpdateTreeMixIn<MultiTargetHistBuilder, Mu
void LeafPartition(RegTree const &tree, linalg::MatrixView<GradientPair const> gpair,
std::vector<bst_node_t> *p_out_position) {
monitor_->Start(__func__);
if (!model_->task.UpdateTreeLeaf()) {
if (!task_->UpdateTreeLeaf()) {
return;
}
for (auto const &part : partitioner_) {
Expand All @@ -303,27 +303,27 @@ class MultiTargetHistBuilder : public UpdateTreeMixIn<MultiTargetHistBuilder, Mu
public:
explicit MultiTargetHistBuilder(Context const *ctx, MetaInfo const &info, TrainParam const *param,
std::shared_ptr<common::ColumnSampler> column_sampler,
LearnerModelParam const *model, common::Monitor *monitor)
ObjInfo const *task, common::Monitor *monitor)
: UpdateTreeMixIn<MultiTargetHistBuilder, MultiExpandEntry>{param, monitor},
col_sampler_{std::move(column_sampler)},
evaluator_{std::make_unique<HistMultiEvaluator>(ctx, info, param, col_sampler_)},
ctx_{ctx},
model_{model} {}
task_{task} {}
};

struct HistBuilder : public UpdateTreeMixIn<HistBuilder, CPUExpandEntry> {
public:
// constructor
explicit HistBuilder(Context const *ctx, std::shared_ptr<common::ColumnSampler> column_sampler,
TrainParam const *param, DMatrix const *fmat, LearnerModelParam const *model,
TrainParam const *param, DMatrix const *fmat, ObjInfo const *task,
common::Monitor *monitor)
: UpdateTreeMixIn<HistBuilder, CPUExpandEntry>{param, monitor},
col_sampler_{std::move(column_sampler)},
evaluator_{std::make_unique<HistEvaluator<CPUExpandEntry>>(ctx, param, fmat->Info(),
col_sampler_)},
p_last_fmat_(fmat),
histogram_builder_{new HistogramBuilder<CPUExpandEntry>},
model_{model},
task_{task},
ctx_{ctx} {
monitor_->Init("Quantile::Builder");
}
Expand Down Expand Up @@ -489,7 +489,7 @@ struct HistBuilder : public UpdateTreeMixIn<HistBuilder, CPUExpandEntry> {
void LeafPartition(RegTree const &tree, linalg::MatrixView<GradientPair const> gpair,
std::vector<bst_node_t> *p_out_position) {
monitor_->Start(__func__);
if (!model_->task.UpdateTreeLeaf()) {
if (!task_->UpdateTreeLeaf()) {
return;
}
for (auto const &part : partitioner_) {
Expand All @@ -508,7 +508,7 @@ struct HistBuilder : public UpdateTreeMixIn<HistBuilder, CPUExpandEntry> {
DMatrix const *const p_last_fmat_;

std::unique_ptr<HistogramBuilder<CPUExpandEntry>> histogram_builder_;
LearnerModelParam const *model_;
ObjInfo const *task_;
// Context for number of threads
Context const *ctx_;
};
Expand All @@ -520,11 +520,11 @@ class QuantileHistMaker : public TreeUpdater {
std::shared_ptr<common::ColumnSampler> column_sampler_ =
std::make_shared<common::ColumnSampler>();
common::Monitor monitor_;
LearnerModelParam const *model_;
ObjInfo const *task_;

public:
explicit QuantileHistMaker(Context const *ctx, LearnerModelParam const *model)
: TreeUpdater{ctx}, model_{model} {}
explicit QuantileHistMaker(Context const *ctx, ObjInfo const *task)
: TreeUpdater{ctx}, task_{task} {}
void Configure(const Args &) override {}

void LoadConfig(Json const &) override {}
Expand All @@ -539,11 +539,12 @@ class QuantileHistMaker : public TreeUpdater {
CHECK(param->monotone_constraints.empty()) << "monotone constraint" << MTNotImplemented();
if (!p_mtimpl_) {
this->p_mtimpl_ = std::make_unique<MultiTargetHistBuilder>(
ctx_, p_fmat->Info(), param, column_sampler_, model_, &monitor_);
ctx_, p_fmat->Info(), param, column_sampler_, task_, &monitor_);
}
} else {
if (!p_impl_) {
p_impl_.reset(new HistBuilder(ctx_, column_sampler_, param, p_fmat, model_, &monitor_));
p_impl_ =
std::make_unique<HistBuilder>(ctx_, column_sampler_, param, p_fmat, task_, &monitor_);
}
}

Expand Down Expand Up @@ -591,7 +592,7 @@ class QuantileHistMaker : public TreeUpdater {

XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker")
.describe("Grow tree using quantized histogram.")
.set_body([](Context const *ctx, LearnerModelParam const *model) {
return new QuantileHistMaker{ctx, model};
.set_body([](Context const *ctx, ObjInfo const *task) {
return new QuantileHistMaker{ctx, task};
});
} // namespace xgboost::tree
12 changes: 6 additions & 6 deletions tests/cpp/tree/test_gpu_hist.cu
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ void TestHistogramIndexImpl() {

// Build 2 matrices and build a histogram maker with that
Context ctx(CreateEmptyGenericParam(0));
LearnerModelParam model;
tree::GPUHistMaker hist_maker{&ctx, &model}, hist_maker_ext{&ctx, &model};
ObjInfo task{ObjInfo::kRegression};
tree::GPUHistMaker hist_maker{&ctx, &task}, hist_maker_ext{&ctx, &task};
std::unique_ptr<DMatrix> hist_maker_dmat(
CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true));

Expand Down Expand Up @@ -240,8 +240,8 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
param.UpdateAllowUnknown(args);

Context ctx(CreateEmptyGenericParam(0));
LearnerModelParam model;
tree::GPUHistMaker hist_maker{&ctx, &model};
ObjInfo task{ObjInfo::kRegression};
tree::GPUHistMaker hist_maker{&ctx, &task};

std::vector<HostDeviceVector<bst_node_t>> position(1);
hist_maker.Update(&param, gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
Expand Down Expand Up @@ -386,8 +386,8 @@ TEST(GpuHist, ExternalMemoryWithSampling) {

TEST(GpuHist, ConfigIO) {
Context ctx(CreateEmptyGenericParam(0));
LearnerModelParam model;
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_gpu_hist", &ctx, &model)};
ObjInfo task{ObjInfo::kRegression};
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_gpu_hist", &ctx, &task)};
updater->Configure(Args{});

Json j_updater { Object() };
Expand Down
Loading

0 comments on commit b06792c

Please sign in to comment.