Commit 0961133

configuration.

trivialfis committed Mar 6, 2023
1 parent 4d5ff74 commit 0961133

Showing 10 changed files with 91 additions and 77 deletions.
6 changes: 2 additions & 4 deletions demo/guide-python/multioutput_regression.py
@@ -53,10 +53,8 @@ def rmse_model(plot_result: bool):
subsample=0.6,
)
reg.fit(X, y, eval_set=[(X, y)])
-# reg.save_model("model.json")

y_predt = reg.predict(X)
-# print("y_predt:", y_predt, y)
if plot_result:
plot_predt(y, y_predt, "multi")

@@ -90,13 +88,13 @@ def rmse(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]:
X, y = gen_circle()
Xy = xgb.DMatrix(X, y)
results: Dict[str, Dict[str, List[float]]] = {}
-# Make sure the `num_class` is passed to XGBoost when custom objective is used.
+# Make sure the `num_target` is passed to XGBoost when custom objective is used.
# When builtin objective is used, XGBoost can figure out the number of targets
# automatically.
booster = xgb.train(
{
"tree_method": "hist",
"num_class": y.shape[1],
"num_target": y.shape[1],
"multi_strategy": "mono",
"objective": "reg:squarederror", # fixme
},
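Editor's note (illustrative sketch, not part of this commit): with a builtin objective XGBoost can infer the number of targets from the label shape, but a custom objective bypasses that inference, hence the `num_target` parameter above. A minimal end-to-end example, assuming the parameter names this commit introduces (`num_target`, `multi_strategy: "mono"`) and a flattened gradient layout:

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(1994)
X = rng.normal(size=(128, 8))
y = np.stack([X[:, 0] ** 2, -X[:, 1]], axis=1)  # two regression targets
Xy = xgb.DMatrix(X, y)

def squared_error(predt: np.ndarray, dtrain: xgb.DMatrix):
    # Gradient/Hessian of 0.5 * (predt - y)^2, flattened so the total size
    # equals num_rows * OutputLength(), matching the CHECK_EQ added in
    # GBTree::BoostNewTrees below.
    labels = dtrain.get_label().reshape(predt.shape)
    grad = (predt - labels).reshape(-1)
    hess = np.ones_like(grad)
    return grad, hess

booster = xgb.train(
    {
        "tree_method": "hist",
        "num_target": y.shape[1],  # required with a custom objective
        "multi_strategy": "mono",  # vector leaves: one tree covers all targets
    },
    Xy,
    num_boost_round=4,
    obj=squared_error,
)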
29 changes: 22 additions & 7 deletions include/xgboost/learner.h
@@ -276,6 +276,13 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
};

struct LearnerModelParamLegacy;
+/**
+* \brief Strategy for building multi-target models.
+*/
+enum class Strategy : std::int32_t {
+kComposite = 0,
+kMono = 1,
+};

/*
* \brief Basic Model Parameters, used to describe the booster.
@@ -305,6 +312,10 @@ struct LearnerModelParam {
* \brief Current task, determined by objective.
*/
ObjInfo task{ObjInfo::kRegression};
+/**
+* \brief Strategy for building multi-target models.
+*/
+Strategy multi_strategy{Strategy::kComposite};

LearnerModelParam() = default;
// As the old `LearnerModelParamLegacy` is still used by binary IO, we keep
@@ -313,23 +324,27 @@
linalg::Tensor<float, 1> base_margin, ObjInfo t);
LearnerModelParam(LearnerModelParamLegacy const& user_param, ObjInfo t);
LearnerModelParam(bst_feature_t n_features, linalg::Tensor<float, 1> base_margin,
-uint32_t n_groups, bst_target_t n_targets)
+uint32_t n_groups, bst_target_t n_targets, Strategy multi_strategy)
: base_score_{std::move(base_margin)},
num_feature{n_features},
num_output_group{n_groups},
-num_target{n_targets} {}
+num_target{n_targets},
+multi_strategy{multi_strategy} {}

linalg::TensorView<float const, 1> BaseScore(Context const* ctx) const;
-linalg::TensorView<float const, 1> BaseScore(int32_t device) const;
+[[nodiscard]] linalg::TensorView<float const, 1> BaseScore(int32_t device) const;

void Copy(LearnerModelParam const& that);
-bool IsVectorLeaf() const { return num_output_group == 1 && num_target > 1; }
-bst_target_t OutputLength() const {
-return this->IsVectorLeaf() ? this->num_target : this->num_output_group;
+[[nodiscard]] bool IsVectorLeaf() const noexcept { return multi_strategy == Strategy::kMono; }
+[[nodiscard]] bst_target_t OutputLength() const noexcept {
+return this->num_target * this->num_output_group;
}
+[[nodiscard]] bst_target_t LeafLength() const noexcept {
+return this->IsVectorLeaf() ? this->OutputLength() : 1;
+}

/* \brief Whether this parameter is initialized with LearnerModelParamLegacy. */
-bool Initialized() const { return num_feature != 0 && num_output_group != 0; }
+[[nodiscard]] bool Initialized() const { return num_feature != 0 && num_output_group != 0; }
};

} // namespace xgboost
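Editor's note: a small logic sketch (simplified mirror of the accessors above, not a public API) showing how `OutputLength` and `LeafLength` relate under the two strategies:

def output_length(num_output_group: int, num_target: int) -> int:
    # Total outputs per row; the commit rejects num_output_group > 1
    # combined with num_target > 1 (multi-class multi-target).
    return num_output_group * num_target

def leaf_length(multi_strategy: str, num_output_group: int, num_target: int) -> int:
    # "mono": a single tree carries a vector leaf covering every output.
    # "compo": scalar leaves, one tree per output.
    if multi_strategy == "mono":
        return output_length(num_output_group, num_target)
    return 1

assert leaf_length("mono", 1, 3) == 3   # vector leaves of width 3
assert leaf_length("compo", 1, 3) == 1  # three scalar-leaf trees instead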
10 changes: 5 additions & 5 deletions include/xgboost/tree_model.h
@@ -493,20 +493,20 @@ class RegTree : public Model {
linalg::VectorView<float const> left_weight,
linalg::VectorView<float const> right_weight);

-bool HasCategoricalSplit() const {
+[[nodiscard]] bool HasCategoricalSplit() const {
return !split_categories_.empty();
}
/**
* \brief Whether this is a multi-target tree.
*/
-bool IsMultiTarget() const { return static_cast<bool>(p_mt_tree_); }
-bst_target_t NumTargets() const {
+[[nodiscard]] bool IsMultiTarget() const { return static_cast<bool>(p_mt_tree_); }
+[[nodiscard]] bst_target_t NumTargets() const {
if (IsMultiTarget()) {
return this->p_mt_tree_->NumTargets();
}
return 1;
}
-auto GetMultiTargetTree() const {
+[[nodiscard]] auto GetMultiTargetTree() const {
CHECK(IsMultiTarget());
return p_mt_tree_.get();
}
@@ -515,7 +515,7 @@
* \brief get current depth
* \param nid node id
*/
-int GetDepth(int nid) const {
+[[nodiscard]] std::int32_t GetDepth(bst_node_t nid) const {
if (IsMultiTarget()) {
return this->p_mt_tree_->Depth(nid);
}
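Editor's note: the accessors above all follow one dispatch pattern. A simplified sketch (assumption: the real class is C++; this mirrors the logic only):

class RegTreeSketch:
    """Simplified mirror of RegTree's multi-target dispatch."""

    def __init__(self, multi_target_tree=None):
        self._mt = multi_target_tree  # set only for vector-leaf trees

    def is_multi_target(self) -> bool:
        return self._mt is not None

    def num_targets(self) -> int:
        # Scalar-leaf trees always report a single target.
        return self._mt.num_targets() if self.is_multi_target() else 1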
29 changes: 14 additions & 15 deletions src/gbm/gbtree.cc
@@ -259,10 +259,10 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
// `gpu_id` be the single source of determining what algorithms to run, but that will
// break a lots of existing code.
auto device = tparam_.tree_method != TreeMethod::kGPUHist ? Context::kCpuId : ctx_->gpu_id;
-auto out = linalg::TensorView<float, 2>{
+auto out = linalg::MakeTensorView(
+ctx_,
device == Context::kCpuId ? predt->predictions.HostSpan() : predt->predictions.DeviceSpan(),
-{static_cast<size_t>(p_fmat->Info().num_row_), static_cast<size_t>(ngroup)},
-device};
+p_fmat->Info().num_row_, static_cast<size_t>(model_.learner_model_param->OutputLength()));
CHECK_NE(ngroup, 0);

if (!p_fmat->SingleColBlock() && obj->Task().UpdateTreeLeaf()) {
@@ -279,7 +279,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
UpdateTreeLeaf(p_fmat, predt->predictions, obj, 1, node_position, &ret);
// No update prediction cache yet.
new_trees.push_back(std::move(ret));
-} else if (ngroup == 1) {
+} else if (model_.learner_model_param->OutputLength() == 1) {
std::vector<std::unique_ptr<RegTree>> ret;
BoostNewTrees(in_gpair, p_fmat, 0, &node_position, &ret);
UpdateTreeLeaf(p_fmat, predt->predictions, obj, 0, node_position, &ret);
@@ -372,8 +372,8 @@ void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fma
<< "can not be used to create new trees. "
<< "Set `process_type` to `update` if you want to update existing "
"trees.";
-// create new tree
-std::unique_ptr<RegTree> ptr(new RegTree{this->model_.learner_model_param->num_target,
+// create new tree.
+std::unique_ptr<RegTree> ptr(new RegTree{this->model_.learner_model_param->LeafLength(),
this->model_.learner_model_param->num_feature});
ptr->param.UpdateAllowUnknown(this->cfg_);
new_trees.push_back(ptr.get());
@@ -397,13 +397,8 @@ void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fma
}

// update the trees
-if (model_.learner_model_param->IsVectorLeaf()) {
-CHECK_EQ(gpair->Size(), model_.learner_model_param->num_target * p_fmat->Info().num_row_);
-} else {
-CHECK_EQ(gpair->Size(), p_fmat->Info().num_row_)
-<< "Mismatching size between number of rows from input data and size of "
-"gradient vector.";
-}
+CHECK_EQ(gpair->Size(), model_.learner_model_param->OutputLength() * p_fmat->Info().num_row_)
+<< "Mismatching size between number of rows from input data and size of gradient vector.";

out_position->resize(new_trees.size());
for (auto& up : updaters_) {
@@ -413,8 +408,12 @@ void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fma

void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees) {
monitor_.Start("CommitModel");
-for (uint32_t gid = 0; gid < model_.learner_model_param->num_output_group; ++gid) {
-model_.CommitModel(std::move(new_trees[gid]), gid);
+if (this->model_.learner_model_param->IsVectorLeaf()) {
+model_.CommitModel(std::move(new_trees[0]), 0);
+} else {
+for (std::uint32_t gid = 0; gid < model_.learner_model_param->OutputLength(); ++gid) {
+model_.CommitModel(std::move(new_trees[gid]), gid);
+}
}
monitor_.Stop("CommitModel");
}
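Editor's note: both hunks above collapse the scalar and vector-leaf paths into one invariant. A simplified sketch (assumption: not the real implementation) of the gradient-size check and the tree-commit routing:

def check_gradient_size(n_gpair: int, n_rows: int, output_length: int) -> None:
    # Unified invariant: scalar-leaf models have output_length == 1, so the
    # old per-row check is a special case of the new one.
    assert n_gpair == n_rows * output_length, "Mismatching gradient size."

def commit_model(new_trees, is_vector_leaf: bool, output_length: int, commit):
    if is_vector_leaf:
        commit(new_trees[0], 0)  # one vector-leaf tree covers every output
    else:
        for gid in range(output_length):  # one scalar-leaf tree per output
            commit(new_trees[gid], gid)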
18 changes: 9 additions & 9 deletions src/gbm/gbtree.h
@@ -154,7 +154,7 @@ inline std::pair<uint32_t, uint32_t> LayerToTree(gbm::GBTreeModel const& model,
tree_begin = layer_begin * model.param.num_parallel_tree;
tree_end = layer_end * model.param.num_parallel_tree;
} else {
-bst_group_t groups = model.learner_model_param->num_output_group;
+bst_group_t groups = model.learner_model_param->OutputLength();
tree_begin = layer_begin * groups * model.param.num_parallel_tree;
tree_end = layer_end * groups * model.param.num_parallel_tree;
}
@@ -248,25 +248,25 @@ class GBTree : public GradientBooster {
void LoadModel(Json const& in) override;

// Number of trees per layer.
-auto LayerTrees() const {
-auto n_trees = model_.learner_model_param->num_output_group * model_.param.num_parallel_tree;
-return n_trees;
+[[nodiscard]] std::uint32_t LayerTrees() const {
+if (model_.learner_model_param->IsVectorLeaf()) {
+return model_.param.num_parallel_tree;
+}
+return model_.param.num_parallel_tree * model_.learner_model_param->OutputLength();
}

// slice the trees, out must be already allocated
void Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
GradientBooster *out, bool* out_of_bound) const override;

-int32_t BoostedRounds() const override {
+[[nodiscard]] std::int32_t BoostedRounds() const override {
CHECK_NE(model_.param.num_parallel_tree, 0);
CHECK_NE(model_.learner_model_param->num_output_group, 0);

-return this->model_.learner_model_param->IsVectorLeaf()
-? (model_.trees.size() / model_.param.num_parallel_tree)
-: (model_.trees.size() / this->LayerTrees());
+return model_.trees.size() / this->LayerTrees();
}

-bool ModelFitted() const override {
+[[nodiscard]] bool ModelFitted() const override {
return !model_.trees.empty() || !model_.trees_to_update.empty();
}

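Editor's note: with `LayerTrees` now strategy-aware, `BoostedRounds` no longer needs its own vector-leaf branch. A runnable sketch (simplified bookkeeping, not the real API):

def layer_trees(num_parallel_tree: int, is_vector_leaf: bool, output_length: int) -> int:
    if is_vector_leaf:
        return num_parallel_tree  # one vector-leaf tree per parallel tree
    return num_parallel_tree * output_length  # scalar leaves: one tree per output

def boosted_rounds(n_trees: int, trees_per_layer: int) -> int:
    return n_trees // trees_per_layer

# Three outputs with num_parallel_tree = 2: "compo" stores six trees per
# round, "mono" stores two, and both report the same boosted-round count.
assert boosted_rounds(12, layer_trees(2, False, 3)) == 2
assert boosted_rounds(4, layer_trees(2, True, 3)) == 2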
4 changes: 2 additions & 2 deletions src/gbm/gbtree_model.cc
@@ -100,8 +100,8 @@ void GBTreeModel::LoadModel(Json const& in) {
common::ParallelFor(trees_json.size(), ctx_->Threads(), [&](auto t) {
auto tree_id = get<Integer>(trees_json[t]["id"]);
CHECK(this->learner_model_param->Initialized());
-trees.at(tree_id).reset(
-new RegTree{this->learner_model_param->num_target, this->learner_model_param->num_feature});
+trees.at(tree_id).reset(new RegTree{this->learner_model_param->LeafLength(),
+this->learner_model_param->num_feature});
trees.at(tree_id)->LoadModel(trees_json[t]);
});

46 changes: 29 additions & 17 deletions src/learner.cc
@@ -57,11 +57,6 @@ const char* kMaxDeltaStepDefaultValue = "0.7";
} // anonymous namespace

namespace xgboost {
-enum class Strategy : std::int32_t {
-kComposite = 0,
-kMono = 1,
-};
-
std::string StrategyStr(Strategy s) { return s == Strategy::kComposite ? "compo" : "mono"; }
} // namespace xgboost
DECLARE_FIELD_ENUM_CLASS(xgboost::Strategy);
@@ -98,6 +93,10 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
/*! \brief the version of XGBoost. */
std::uint32_t major_version;
std::uint32_t minor_version;
+/**
+* \brief Number of target variables.
+*/
+std::int32_t num_target;
/*! \brief Number of output targets, 1 if the strategy is composite. */
Strategy multi_strategy;
/**
Expand All @@ -111,14 +110,15 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
*/
std::int32_t boost_from_average{true};
/*! \brief reserved field */
-int reserved[25];
+int reserved[24];
/*! \brief constructor */
LearnerModelParamLegacy() {
std::memset(this, 0, sizeof(LearnerModelParamLegacy));
base_score = ObjFunction::DefaultBaseScore();
multi_strategy = Strategy::kComposite;
major_version = std::get<0>(Version::Self());
minor_version = std::get<1>(Version::Self());
+num_target = 1;
boost_from_average = true;
static_assert(sizeof(LearnerModelParamLegacy) == 136,
"Do not change the size of this struct, as it will break binary IO.");
@@ -143,6 +143,10 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
CHECK(ret.ec == std::errc());
obj["num_class"] = std::string{integers, static_cast<size_t>(std::distance(integers, ret.ptr))};

+ret = to_chars(integers, integers + NumericLimits<int64_t>::kToCharsSize,
+static_cast<int64_t>(num_target));
+obj["num_target"] = std::string{integers, static_cast<size_t>(std::distance(integers, ret.ptr))};
+
obj["multi_strategy"] = StrategyStr(multi_strategy);

ret = to_chars(integers, integers + NumericLimits<std::int64_t>::kToCharsSize,
@@ -157,6 +161,7 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
std::map<std::string, std::string> m;
m["num_feature"] = get<String const>(j_param.at("num_feature"));
m["num_class"] = get<String const>(j_param.at("num_class"));
m["num_target"] = get<String const>(j_param.at("num_target"));

auto strategy_it = j_param.find("multi_strategy");
if (strategy_it != j_param.cend()) {
@@ -239,6 +244,10 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
DMLC_DECLARE_FIELD(num_class).set_default(0).set_lower_bound(0).describe(
"Number of class option for multi-class classifier. "
" By default equals 0 and corresponds to binary classifier.");
+DMLC_DECLARE_FIELD(num_target)
+.set_default(1)
+.set_lower_bound(1)
+.describe("Number of output targets.");
DMLC_DECLARE_FIELD(multi_strategy)
.add_enum("compo", Strategy::kComposite)
.add_enum("mono", Strategy::kMono)
@@ -252,12 +261,15 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>

LearnerModelParam::LearnerModelParam(LearnerModelParamLegacy const& user_param, ObjInfo t)
: num_feature{user_param.num_feature},
-num_output_group{user_param.multi_strategy == Strategy::kComposite
-? std::max(user_param.num_class, 1)
-: 1u},
-num_target{user_param.multi_strategy == Strategy::kMono ? std::max(user_param.num_class, 1)
-: 1u},
-task{t} {}
+num_output_group{static_cast<std::uint32_t>(std::max(user_param.num_class, 1))},
+num_target{static_cast<bst_target_t>(std::max(user_param.num_target, 1))},
+task{t},
+multi_strategy{user_param.multi_strategy} {
+if (num_output_group > 1 && num_target > 1) {
+LOG(FATAL) << "multi-target-multi-class is not yet supported. Output groups:"
+<< num_output_group << ", output targets:" << num_target;
+}
+}

LearnerModelParam::LearnerModelParam(Context const* ctx, LearnerModelParamLegacy const& user_param,
linalg::Tensor<float, 1> base_margin, ObjInfo t)
@@ -819,7 +831,7 @@ class LearnerConfiguration : public Learner {
*/
void ConfigureTargets() {
CHECK(this->obj_);
-if (mparam_.num_class > 1) {
+if (mparam_.num_target > 1) {
return;
}
auto const& cache = prediction_container_.Container();
@@ -832,12 +844,12 @@ class LearnerConfiguration : public Learner {
CHECK(n_targets == t || 1 == t) << "Inconsistent labels.";
}
}
-if (mparam_.num_class > 1) {
-CHECK(n_targets == 1 || n_targets == static_cast<bst_target_t>(mparam_.num_class))
+if (mparam_.num_target > 1) {
+CHECK(n_targets == 1 || n_targets == static_cast<bst_target_t>(mparam_.num_target))
<< "Inconsistent configuration of num_target. Configuration result from input data:"
-<< n_targets << ", configuration from parameter:" << mparam_.num_class;
+<< n_targets << ", configuration from parameter:" << mparam_.num_target;
}
-mparam_.num_class = n_targets;
+mparam_.num_target = n_targets;
}
};

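Editor's note: target configuration now flows through `num_target` instead of overloading `num_class`. A simplified sketch of the inference rule (assumption: condensed from `ConfigureTargets`, not the real API):

def configure_targets(num_target: int, cached_label_shapes) -> int:
    # An explicit user-set num_target (> 1) wins outright.
    if num_target > 1:
        return num_target
    n_targets = 1
    for shape in cached_label_shapes:  # labels of cached DMatrix objects
        t = shape[1] if len(shape) > 1 else 1
        assert t == n_targets or 1 in (t, n_targets), "Inconsistent labels."
        n_targets = max(n_targets, t)
    return n_targets

assert configure_targets(1, [(100, 3), (50, 3)]) == 3  # inferred from data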
6 changes: 3 additions & 3 deletions src/predictor/cpu_predictor.cc
@@ -304,7 +304,7 @@ void PredictBatchByBlockOfRowsKernel(DataView batch, gbm::GBTreeModel const &mod
// process block of rows through all trees to keep cache locality
if (model.learner_model_param->IsVectorLeaf()) {
multi::PredictByAllTrees(model, tree_begin, tree_end, batch_offset + batch.base_rowid,
-thread_temp, fvec_offset, block_size, out_predt);
+thread_temp, fvec_offset, block_size, out_predt);
} else {
scalar::PredictByAllTrees(model, tree_begin, tree_end, batch_offset + batch.base_rowid,
thread_temp, fvec_offset, block_size, out_predt);
@@ -803,7 +803,7 @@ class CPUPredictor : public Predictor {
std::vector<bst_float> const *tree_weights, bool approximate,
int condition, unsigned condition_feature) const override {
CHECK(!model.learner_model_param->IsVectorLeaf())
<< "predict contribution" << MTNotImplemented();
<< "Predict contribution" << MTNotImplemented();
auto const n_threads = this->ctx_->Threads();
const int num_feature = model.learner_model_param->num_feature;
std::vector<RegTree::FVec> feat_vecs;
@@ -884,7 +884,7 @@ class CPUPredictor : public Predictor {
std::vector<bst_float> const *tree_weights,
bool approximate) const override {
CHECK(!model.learner_model_param->IsVectorLeaf())
<< "predict interaction contribution" << MTNotImplemented();
<< "Predict interaction contribution" << MTNotImplemented();
const MetaInfo& info = p_fmat->Info();
const int ngroup = model.learner_model_param->num_output_group;
size_t const ncolumns = model.learner_model_param->num_feature;