From b030c63726e237b7fdb041fe51cf3a649ead8261 Mon Sep 17 00:00:00 2001
From: dragon_bra
Date: Thu, 18 Jul 2024 17:07:09 +0800
Subject: [PATCH 01/11] basic gpu_linear_tree_learner implementation

---
 src/treelearner/gpu_linear_tree_learner.cpp | 381 ++++++++++++++++++++
 src/treelearner/gpu_linear_tree_learner.h   | 127 +++++++
 2 files changed, 508 insertions(+)
 create mode 100644 src/treelearner/gpu_linear_tree_learner.cpp
 create mode 100644 src/treelearner/gpu_linear_tree_learner.h

diff --git a/src/treelearner/gpu_linear_tree_learner.cpp b/src/treelearner/gpu_linear_tree_learner.cpp
new file mode 100644
index 000000000000..c6811ee091bd
--- /dev/null
+++ b/src/treelearner/gpu_linear_tree_learner.cpp
@@ -0,0 +1,381 @@
+/*!
+ * Copyright (c) 2024 Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT License. See LICENSE file in the project root for license information.
+ */
+#include "gpu_linear_tree_learner.h"
+
+#include <Eigen/Dense>
+
+#include <algorithm>
+
+namespace LightGBM {
+
+void GPULinearTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian) {
+  GPUTreeLearner::Init(train_data, is_constant_hessian);
+  GPULinearTreeLearner::InitLinear(train_data, config_->num_leaves);
+}
+
+void GPULinearTreeLearner::InitLinear(const Dataset* train_data, const int max_leaves) {
+  leaf_map_ = std::vector<int>(train_data->num_data(), -1);
+  contains_nan_ = std::vector<int8_t>(train_data->num_features(), 0);
+  // identify features containing nans
+#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
+  for (int feat = 0; feat < train_data->num_features(); ++feat) {
+    auto bin_mapper = train_data_->FeatureBinMapper(feat);
+    if (bin_mapper->bin_type() == BinType::NumericalBin) {
+      const float* feat_ptr = train_data_->raw_index(feat);
+      for (int i = 0; i < train_data->num_data(); ++i) {
+        if (std::isnan(feat_ptr[i])) {
+          contains_nan_[feat] = 1;
+          break;
+        }
+      }
+    }
+  }
+  any_nan_ = false;
+  for (int feat = 0; feat < train_data->num_features(); ++feat) {
+    if (contains_nan_[feat]) {
+      any_nan_ = true;
+      break;
+    }
+  }
+  // preallocate the matrix used to calculate linear model coefficients
+  int max_num_feat = std::min(max_leaves, train_data_->num_numeric_features());
+  XTHX_.clear();
+  XTg_.clear();
+  for (int i = 0; i < max_leaves; ++i) {
+    // store only upper triangular half of matrix as an array, in row-major order
+    // this requires (max_num_feat + 1) * (max_num_feat + 2) / 2 entries (including the constant terms of the regression)
+    // we add another 8 to ensure cache lines are not shared among processors
+    XTHX_.push_back(std::vector<float>((max_num_feat + 1) * (max_num_feat + 2) / 2 + 8, 0));
+    XTg_.push_back(std::vector<float>(max_num_feat + 9, 0.0));
+  }
+  XTHX_by_thread_.clear();
+  XTg_by_thread_.clear();
+  int max_threads = OMP_NUM_THREADS();
+  for (int i = 0; i < max_threads; ++i) {
+    XTHX_by_thread_.push_back(XTHX_);
+    XTg_by_thread_.push_back(XTg_);
+  }
+}
+
+Tree* GPULinearTreeLearner::Train(const score_t* gradients, const score_t *hessians, bool is_first_tree) {
+  Common::FunctionTimer fun_timer("SerialTreeLearner::Train", global_timer);
+  gradients_ = gradients;
+  hessians_ = hessians;
+  int num_threads = OMP_NUM_THREADS();
+  if (share_state_->num_threads != num_threads && share_state_->num_threads > 0) {
+    Log::Warning(
+        "Detected that num_threads changed during training (from %d to %d), "
+        "it may cause unexpected errors.",
+        share_state_->num_threads, num_threads);
+  }
+  share_state_->num_threads = num_threads;
+
+  // some initial works before training
+  BeforeTrain();
+
+  auto tree = std::unique_ptr<Tree>(new Tree(config_->num_leaves, true, true));
+  auto tree_ptr = tree.get();
+  constraints_->ShareTreePointer(tree_ptr);
+
+  // root leaf
+  int left_leaf = 0;
+  int cur_depth = 1;
+  // only root leaf can be splitted on first time
+  int right_leaf = -1;
+
+  int init_splits = ForceSplits(tree_ptr, &left_leaf, &right_leaf, &cur_depth);
+
+  for (int split = init_splits; split < config_->num_leaves - 1; ++split) {
+    // some initial works before finding best split
+    if (BeforeFindBestSplit(tree_ptr, left_leaf, right_leaf)) {
+      // find best threshold for every feature
+      FindBestSplits(tree_ptr);
+    }
+    // Get a leaf with max split gain
+    int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(best_split_per_leaf_));
+    // Get split information for best leaf
+    const SplitInfo& best_leaf_SplitInfo = best_split_per_leaf_[best_leaf];
+    // cannot split, quit
+    if (best_leaf_SplitInfo.gain <= 0.0) {
+      Log::Warning("No further splits with positive gain, best gain: %f", best_leaf_SplitInfo.gain);
+      break;
+    }
+    // split tree with best leaf
+    Split(tree_ptr, best_leaf, &left_leaf, &right_leaf);
+    cur_depth = std::max(cur_depth, tree->leaf_depth(left_leaf));
+  }
+
+  bool has_nan = false;
+  if (any_nan_) {
+    for (int i = 0; i < tree->num_leaves() - 1 ; ++i) {
+      if (contains_nan_[tree_ptr->split_feature_inner(i)]) {
+        has_nan = true;
+        break;
+      }
+    }
+  }
+
+  GetLeafMap(tree_ptr);
+
+  if (has_nan) {
+    CalculateLinear<true>(tree_ptr, false, gradients_, hessians_, is_first_tree);
+  } else {
+    CalculateLinear<false>(tree_ptr, false, gradients_, hessians_, is_first_tree);
+  }
+
+  Log::Debug("Trained a tree with leaves = %d and depth = %d", tree->num_leaves(), cur_depth);
+  return tree.release();
+}
+
+Tree* GPULinearTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t *hessians) const {
+  auto tree = SerialTreeLearner::FitByExistingTree(old_tree, gradients, hessians);
+  bool has_nan = false;
+  if (any_nan_) {
+    for (int i = 0; i < tree->num_leaves() - 1 ; ++i) {
+      if (contains_nan_[train_data_->InnerFeatureIndex(tree->split_feature(i))]) {
+        has_nan = true;
+        break;
+      }
+    }
+  }
+  GetLeafMap(tree);
+  if (has_nan) {
+    CalculateLinear<true>(tree, true, gradients, hessians, false);
+  } else {
+    CalculateLinear<false>(tree, true, gradients, hessians, false);
+  }
+  return tree;
+}
+
+Tree* GPULinearTreeLearner::FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred,
+                                              const score_t* gradients, const score_t *hessians) const {
+  data_partition_->ResetByLeafPred(leaf_pred, old_tree->num_leaves());
+  return GPULinearTreeLearner::FitByExistingTree(old_tree, gradients, hessians);
+}
+
+void GPULinearTreeLearner::GetLeafMap(Tree* tree) const {
+  std::fill(leaf_map_.begin(), leaf_map_.end(), -1);
+  // map data to leaf number
+  const data_size_t* ind = data_partition_->indices();
+#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(dynamic)
+  for (int i = 0; i < tree->num_leaves(); ++i) {
+    data_size_t idx = data_partition_->leaf_begin(i);
+    for (int j = 0; j < data_partition_->leaf_count(i); ++j) {
+      leaf_map_[ind[idx + j]] = i;
+    }
+  }
+}
+
+
+template<bool HAS_NAN>
+void GPULinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t* gradients, const score_t* hessians, bool is_first_tree) const {
+  tree->SetIsLinear(true);
+  int num_leaves = tree->num_leaves();
+  int num_threads = OMP_NUM_THREADS();
+  if (is_first_tree) {
+    for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
+      tree->SetLeafConst(leaf_num, tree->LeafOutput(leaf_num));
+    }
+    return;
+  }
+
+  // calculate coefficients using the method described in Eq 3 of https://arxiv.org/pdf/1802.05640.pdf
+  // the coefficients vector is given by
+  // - (X_T * H * X + lambda) ^ (-1) * (X_T * g)
+  // where:
+  // X is the matrix where the first column is the feature values and the second is all ones,
+  // H is the diagonal matrix of the hessian,
+  // lambda is the diagonal matrix with diagonal entries equal to the regularisation term linear_lambda
+  // g is the vector of gradients
+  // the subscript _T denotes the transpose
+
+  // create array of pointers to raw data, and coefficient matrices, for each leaf
+  std::vector<std::vector<int>> leaf_features;
+  std::vector<int> leaf_num_features;
+  std::vector<std::vector<const float*>> raw_data_ptr;
+  size_t max_num_features = 0;
+  for (int i = 0; i < num_leaves; ++i) {
+    std::vector<int> raw_features;
+    if (is_refit) {
+      raw_features = tree->LeafFeatures(i);
+    } else {
+      raw_features = tree->branch_features(i);
+    }
+    std::sort(raw_features.begin(), raw_features.end());
+    auto new_end = std::unique(raw_features.begin(), raw_features.end());
+    raw_features.erase(new_end, raw_features.end());
+    std::vector<int> numerical_features;
+    std::vector<const float*> data_ptr;
+    for (size_t j = 0; j < raw_features.size(); ++j) {
+      int feat = train_data_->InnerFeatureIndex(raw_features[j]);
+      auto bin_mapper = train_data_->FeatureBinMapper(feat);
+      if (bin_mapper->bin_type() == BinType::NumericalBin) {
+        numerical_features.push_back(feat);
+        data_ptr.push_back(train_data_->raw_index(feat));
+      }
+    }
+    leaf_features.push_back(numerical_features);
+    raw_data_ptr.push_back(data_ptr);
+    leaf_num_features.push_back(static_cast<int>(numerical_features.size()));
+    if (numerical_features.size() > max_num_features) {
+      max_num_features = numerical_features.size();
+    }
+  }
+  // clear the coefficient matrices
+#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
+  for (int i = 0; i < num_threads; ++i) {
+    for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
+      size_t num_feat = leaf_features[leaf_num].size();
+      std::fill(XTHX_by_thread_[i][leaf_num].begin(), XTHX_by_thread_[i][leaf_num].begin() + (num_feat + 1) * (num_feat + 2) / 2, 0.0f);
+      std::fill(XTg_by_thread_[i][leaf_num].begin(), XTg_by_thread_[i][leaf_num].begin() + num_feat + 1, 0.0f);
+    }
+  }
+#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
+  for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
+    size_t num_feat = leaf_features[leaf_num].size();
+    std::fill(XTHX_[leaf_num].begin(), XTHX_[leaf_num].begin() + (num_feat + 1) * (num_feat + 2) / 2, 0.0f);
+    std::fill(XTg_[leaf_num].begin(), XTg_[leaf_num].begin() + num_feat + 1, 0.0f);
+  }
+  std::vector<std::vector<int>> num_nonzero;
+  for (int i = 0; i < num_threads; ++i) {
+    if (HAS_NAN) {
+      num_nonzero.push_back(std::vector<int>(num_leaves, 0));
+    }
+  }
+  OMP_INIT_EX();
+#pragma omp parallel num_threads(OMP_NUM_THREADS()) if (num_data_ > 1024)
+  {
+    std::vector<float> curr_row(max_num_features + 1);
+    int tid = omp_get_thread_num();
+#pragma omp for schedule(static)
+    for (int i = 0; i < num_data_; ++i) {
+      OMP_LOOP_EX_BEGIN();
+      int leaf_num = leaf_map_[i];
+      if (leaf_num < 0) {
+        continue;
+      }
+      bool nan_found = false;
+      int num_feat = leaf_num_features[leaf_num];
+      for (int feat = 0; feat < num_feat; ++feat) {
+        if (HAS_NAN) {
+          float val = raw_data_ptr[leaf_num][feat][i];
+          if (std::isnan(val)) {
+            nan_found = true;
+            break;
+          }
+          num_nonzero[tid][leaf_num] += 1;
+          curr_row[feat] = val;
+        } else {
+          curr_row[feat] = raw_data_ptr[leaf_num][feat][i];
+        }
+      }
+      if (HAS_NAN) {
+        if (nan_found) {
+          continue;
+        }
+      }
+      curr_row[num_feat] = 1.0;
+      float h = static_cast<float>(hessians[i]);
+      float g = static_cast<float>(gradients[i]);
+      int j = 0;
+      for (int feat1 = 0; feat1 < num_feat + 1; ++feat1) {
+        double f1_val = static_cast<double>(curr_row[feat1]);
+        XTg_by_thread_[tid][leaf_num][feat1] += f1_val * g;
+        f1_val *= h;
+        for (int feat2 = feat1; feat2 < num_feat + 1; ++feat2) {
+          XTHX_by_thread_[tid][leaf_num][j] += f1_val * curr_row[feat2];
+          ++j;
+        }
+      }
+      OMP_LOOP_EX_END();
+    }
+  }
+  OMP_THROW_EX();
+  auto total_nonzero = std::vector<int>(tree->num_leaves());
+  // aggregate results from different threads
+  for (int tid = 0; tid < num_threads; ++tid) {
+#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
+    for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
+      size_t num_feat = leaf_features[leaf_num].size();
+      for (size_t j = 0; j < (num_feat + 1) * (num_feat + 2) / 2; ++j) {
+        XTHX_[leaf_num][j] += XTHX_by_thread_[tid][leaf_num][j];
+      }
+      for (size_t feat1 = 0; feat1 < num_feat + 1; ++feat1) {
+        XTg_[leaf_num][feat1] += XTg_by_thread_[tid][leaf_num][feat1];
+      }
+      if (HAS_NAN) {
+        total_nonzero[leaf_num] += num_nonzero[tid][leaf_num];
+      }
+    }
+  }
+  if (!HAS_NAN) {
+    for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
+      total_nonzero[leaf_num] = data_partition_->leaf_count(leaf_num);
+    }
+  }
+  double shrinkage = tree->shrinkage();
+  double decay_rate = config_->refit_decay_rate;
+  // copy into eigen matrices and solve
+#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
+  for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
+    if (total_nonzero[leaf_num] < static_cast<int>(leaf_features[leaf_num].size()) + 1) {
+      if (is_refit) {
+        double old_const = tree->LeafConst(leaf_num);
+        tree->SetLeafConst(leaf_num, decay_rate * old_const + (1.0 - decay_rate) * tree->LeafOutput(leaf_num) * shrinkage);
+        tree->SetLeafCoeffs(leaf_num, std::vector<double>(leaf_features[leaf_num].size(), 0));
+        tree->SetLeafFeaturesInner(leaf_num, leaf_features[leaf_num]);
+      } else {
+        tree->SetLeafConst(leaf_num, tree->LeafOutput(leaf_num));
+      }
+      continue;
+    }
+    size_t num_feat = leaf_features[leaf_num].size();
+    Eigen::MatrixXd XTHX_mat(num_feat + 1, num_feat + 1);
+    Eigen::MatrixXd XTg_mat(num_feat + 1, 1);
+    size_t j = 0;
+    for (size_t feat1 = 0; feat1 < num_feat + 1; ++feat1) {
+      for (size_t feat2 = feat1; feat2 < num_feat + 1; ++feat2) {
+        XTHX_mat(feat1, feat2) = XTHX_[leaf_num][j];
+        XTHX_mat(feat2, feat1) = XTHX_mat(feat1, feat2);
+        if ((feat1 == feat2) && (feat1 < num_feat)) {
+          XTHX_mat(feat1, feat2) += config_->linear_lambda;
+        }
+        ++j;
+      }
+      XTg_mat(feat1) = XTg_[leaf_num][feat1];
+    }
+    Eigen::MatrixXd coeffs = - XTHX_mat.fullPivLu().inverse() * XTg_mat;
+    std::vector<double> coeffs_vec;
+    std::vector<int> features_new;
+    std::vector<double> old_coeffs = tree->LeafCoeffs(leaf_num);
+    for (size_t i = 0; i < leaf_features[leaf_num].size(); ++i) {
+      if (is_refit) {
+        features_new.push_back(leaf_features[leaf_num][i]);
+        coeffs_vec.push_back(decay_rate * old_coeffs[i] + (1.0 - decay_rate) * coeffs(i) * shrinkage);
+      } else {
+        if (coeffs(i) < -kZeroThreshold || coeffs(i) > kZeroThreshold) {
+          coeffs_vec.push_back(coeffs(i));
+          int feat = leaf_features[leaf_num][i];
+          features_new.push_back(feat);
+        }
+      }
+    }
+    // update the tree properties
+    tree->SetLeafFeaturesInner(leaf_num, features_new);
+    std::vector<int> features_raw(features_new.size());
+    for (size_t i = 0; i < features_new.size(); ++i) {
+      features_raw[i] = train_data_->RealFeatureIndex(features_new[i]);
+    }
+    tree->SetLeafFeatures(leaf_num, features_raw);
+    tree->SetLeafCoeffs(leaf_num, coeffs_vec);
+    if (is_refit) {
+      double old_const = tree->LeafConst(leaf_num);
+      tree->SetLeafConst(leaf_num, decay_rate * old_const + (1.0 - decay_rate) * coeffs(num_feat) * shrinkage);
+    } else {
+      tree->SetLeafConst(leaf_num, coeffs(num_feat));
+    }
+  }
+}
+}  // namespace LightGBM
diff --git a/src/treelearner/gpu_linear_tree_learner.h b/src/treelearner/gpu_linear_tree_learner.h
new file mode 100644
index 000000000000..e6efe4a7bdb9
--- /dev/null
+++ b/src/treelearner/gpu_linear_tree_learner.h
@@ -0,0 +1,127 @@
+/*!
+ * Copyright (c) 2024 Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT License. See LICENSE file in the project root for license information.
+ */
+#ifndef LIGHTGBM_TREELEARNER_GPU_LINEAR_TREE_LEARNER_H_
+#define LIGHTGBM_TREELEARNER_GPU_LINEAR_TREE_LEARNER_H_
+
+#include <LightGBM/utils/array_args.h>
+#include <algorithm>
+#include <cmath>
+#include <memory>
+#include <vector>
+
+#include "gpu_tree_learner.h"
+
+namespace LightGBM {
+
+class GPULinearTreeLearner: public GPUTreeLearner {
+ public:
+  explicit GPULinearTreeLearner(const Config* config) : GPUTreeLearner(config) {}
+
+  void Init(const Dataset* train_data, bool is_constant_hessian) override;
+
+  void InitLinear(const Dataset* train_data, const int max_leaves) override;
+
+  Tree* Train(const score_t* gradients, const score_t *hessians, bool is_first_tree) override;
+
+  /*! \brief Create array mapping dataset to leaf index, used for linear trees */
+  void GetLeafMap(Tree* tree) const;
+
+  template<bool HAS_NAN>
+  void CalculateLinear(Tree* tree, bool is_refit, const score_t* gradients, const score_t* hessians, bool is_first_tree) const;
+
+  Tree* FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t* hessians) const override;
+
+  Tree* FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred,
+                          const score_t* gradients, const score_t* hessians) const override;
+
+  void AddPredictionToScore(const Tree* tree,
+                            double* out_score) const override {
+    CHECK_LE(tree->num_leaves(), data_partition_->num_leaves());
+    bool has_nan = false;
+    if (any_nan_) {
+      for (int i = 0; i < tree->num_leaves() - 1 ; ++i) {
+        // use split_feature because split_feature_inner doesn't work when refitting existing tree
+        if (contains_nan_[train_data_->InnerFeatureIndex(tree->split_feature(i))]) {
+          has_nan = true;
+          break;
+        }
+      }
+    }
+    if (has_nan) {
+      AddPredictionToScoreInner<true>(tree, out_score);
+    } else {
+      AddPredictionToScoreInner<false>(tree, out_score);
+    }
+  }
+
+  template<bool HAS_NAN>
+  void AddPredictionToScoreInner(const Tree* tree, double* out_score) const {
+    int num_leaves = tree->num_leaves();
+    std::vector<double> leaf_const(num_leaves);
+    std::vector<std::vector<double>> leaf_coeff(num_leaves);
+    std::vector<std::vector<const float*>> feat_ptr(num_leaves);
+    std::vector<double> leaf_output(num_leaves);
+    std::vector<int> leaf_num_features(num_leaves);
+    for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
+      leaf_const[leaf_num] = tree->LeafConst(leaf_num);
+      leaf_coeff[leaf_num] = tree->LeafCoeffs(leaf_num);
+      leaf_output[leaf_num] = tree->LeafOutput(leaf_num);
+      for (int feat : tree->LeafFeaturesInner(leaf_num)) {
+        feat_ptr[leaf_num].push_back(train_data_->raw_index(feat));
+      }
+      leaf_num_features[leaf_num] = static_cast<int>(feat_ptr[leaf_num].size());
+    }
+    OMP_INIT_EX();
+#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) if (num_data_ > 1024)
+    for (int i = 0; i < num_data_; ++i) {
+      OMP_LOOP_EX_BEGIN();
+      int leaf_num = leaf_map_[i];
+      if (leaf_num < 0) {
+        continue;
+      }
+      double output = leaf_const[leaf_num];
+      int num_feat = leaf_num_features[leaf_num];
+      if (HAS_NAN) {
+        bool nan_found = false;
+        for (int feat_ind = 0; feat_ind < num_feat; ++feat_ind) {
+          float val = feat_ptr[leaf_num][feat_ind][i];
+          if (std::isnan(val)) {
+            nan_found = true;
+            break;
+          }
+          output += val * leaf_coeff[leaf_num][feat_ind];
+        }
+        if (nan_found) {
+          out_score[i] += leaf_output[leaf_num];
+        } else {
+          out_score[i] += output;
+        }
+      } else {
+        for (int feat_ind = 0; feat_ind < num_feat; ++feat_ind) {
+          output += feat_ptr[leaf_num][feat_ind][i] * leaf_coeff[leaf_num][feat_ind];
+        }
+        out_score[i] += output;
+      }
+      OMP_LOOP_EX_END();
+    }
+    OMP_THROW_EX();
+  }
+
+ protected:
+  /*! \brief whether numerical features contain any nan values */
+  std::vector<int8_t> contains_nan_;
+  /*! whether any numerical feature contains a nan value */
+  bool any_nan_;
+  /*! \brief map dataset to leaves */
+  mutable std::vector<int> leaf_map_;
+  /*! \brief temporary storage for calculating linear model coefficients */
+  mutable std::vector<std::vector<float>> XTHX_;
+  mutable std::vector<std::vector<float>> XTg_;
+  mutable std::vector<std::vector<std::vector<float>>> XTHX_by_thread_;
+  mutable std::vector<std::vector<std::vector<float>>> XTg_by_thread_;
+};
+
+}  // namespace LightGBM
+#endif  // LightGBM_TREELEARNER_GPU_LINEAR_TREE_LEARNER_H_

From 80f61e651eb3e4f0f56a9bf9cdd36acedfe0a055 Mon Sep 17 00:00:00 2001
From: dragon_bra
Date: Thu, 18 Jul 2024 17:07:30 +0800
Subject: [PATCH 02/11] corresponding config of gpu linear tree

---
 R-package/src/Makevars.in        | 1 +
 R-package/src/Makevars.win.in    | 1 +
 src/io/config.cpp                | 4 ++--
 src/treelearner/tree_learner.cpp | 7 ++++++-
 4 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in
index ec2c067b64b8..0040d86bf58b 100644
--- a/R-package/src/Makevars.in
+++ b/R-package/src/Makevars.in
@@ -49,6 +49,7 @@ OBJECTS = \
     treelearner/feature_histogram.o \
     treelearner/feature_parallel_tree_learner.o \
     treelearner/gpu_tree_learner.o \
+    treelearner/gpu_linear_tree_learner.o \
     treelearner/gradient_discretizer.o \
     treelearner/linear_tree_learner.o \
     treelearner/serial_tree_learner.o \
diff --git a/R-package/src/Makevars.win.in b/R-package/src/Makevars.win.in
index ebcb40d1372a..5fcc720fdd4a 100644
--- a/R-package/src/Makevars.win.in
+++ b/R-package/src/Makevars.win.in
@@ -50,6 +50,7 @@ OBJECTS = \
    treelearner/feature_histogram.o \
    treelearner/feature_parallel_tree_learner.o \
    treelearner/gpu_tree_learner.o \
+    treelearner/gpu_linear_tree_learner.o \
    treelearner/gradient_discretizer.o \
    treelearner/linear_tree_learner.o \
    treelearner/serial_tree_learner.o \
diff --git a/src/io/config.cpp b/src/io/config.cpp
index c63de70fc16b..1a6b099dded8 100644
--- a/src/io/config.cpp
+++ b/src/io/config.cpp
@@ -417,9 +417,9 @@ void Config::CheckParamConflict(const std::unordered_map<std::string, std::string>& params) {
diff --git a/src/treelearner/tree_learner.cpp b/src/treelearner/tree_learner.cpp
--- a/src/treelearner/tree_learner.cpp
+++ b/src/treelearner/tree_learner.cpp
@@ -5,6 +5,7 @@
 #include <LightGBM/tree_learner.h>
 
 #include "gpu_tree_learner.h"
+#include "gpu_linear_tree_learner.h"
 #include "linear_tree_learner.h"
 #include "parallel_tree_learner.h"
 #include "serial_tree_learner.h"
@@ -30,7 +31,11 @@ TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, con
     }
   } else if (device_type == std::string("gpu")) {
     if (learner_type == std::string("serial")) {
-      return new GPUTreeLearner(config);
+      if (config->linear_tree) {
+        return new GPULinearTreeLearner(config);
+      } else {
+        return new GPUTreeLearner(config);
+      }
     } else if (learner_type == std::string("feature")) {
       return new FeatureParallelTreeLearner<GPUTreeLearner>(config);
     } else if (learner_type == std::string("data")) {

From 94ad3e5e864c207270b7a172b2eec36abad03050 Mon Sep 17 00:00:00 2001
From: dragonbra
Date: Thu, 25 Jul 2024 17:47:42 +0800
Subject: [PATCH 03/11] Update src/io/config.cpp

Co-authored-by: Nikita Titov
---
 src/io/config.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/io/config.cpp b/src/io/config.cpp
index 1a6b099dded8..bc01ea8d57af 100644
--- a/src/io/config.cpp
+++ b/src/io/config.cpp
@@ -419,7 +419,7 @@ void Config::CheckParamConflict(const std::unordered_map<std::string, std::string>& params) {

From: Yu Shi
Date: Tue, 1 Oct 2024 14:09:40 +0000
Subject: [PATCH 04/11] work around for gpu linear tree learner without gpu
 enabled

---
 src/treelearner/gpu_linear_tree_learner.h | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/src/treelearner/gpu_linear_tree_learner.h b/src/treelearner/gpu_linear_tree_learner.h
index e6efe4a7bdb9..d43f42751d8e 100644
--- a/src/treelearner/gpu_linear_tree_learner.h
+++ b/src/treelearner/gpu_linear_tree_learner.h
@@ -15,6 +15,8 @@
 
 namespace LightGBM {
 
+#ifdef USE_GPU
+
 class GPULinearTreeLearner: public GPUTreeLearner {
  public:
   explicit GPULinearTreeLearner(const Config* config) : GPUTreeLearner(config) {}
@@ -123,5 +125,19 @@ class GPULinearTreeLearner: public GPUTreeLearner {
   mutable std::vector<std::vector<std::vector<float>>> XTg_by_thread_;
 };
 
+#else  // USE_GPU
+
+class GPULinearTreeLearner: public GPUTreeLearner {
+ public:
+  #ifdef _MSC_VER
+  #pragma warning(disable : 4702)
+  #endif
+  explicit GPULinearTreeLearner(const Config* tree_config) : GPUTreeLearner(tree_config) {
+    Log::Fatal("GPU Tree Linear Learner was not enabled in this build.\n"
+               "Please recompile with CMake option -DUSE_GPU=1");
+  }
+};
+
 }  // namespace LightGBM
+
 #endif  // LightGBM_TREELEARNER_GPU_LINEAR_TREE_LEARNER_H_

From b16a1270e256bbd5332d29deb35c045512f9bc40 Mon Sep 17 00:00:00 2001
From: Yu Shi
Date: Wed, 2 Oct 2024 05:07:10 +0000
Subject: [PATCH 05/11] add #endif

---
 src/treelearner/gpu_linear_tree_learner.h | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/treelearner/gpu_linear_tree_learner.h b/src/treelearner/gpu_linear_tree_learner.h
index d43f42751d8e..dcf7e60966f1 100644
--- a/src/treelearner/gpu_linear_tree_learner.h
+++ b/src/treelearner/gpu_linear_tree_learner.h
@@ -138,6 +138,8 @@ class GPULinearTreeLearner: public GPUTreeLearner {
   }
 };
 
+#endif  // USE_GPU
+
 }  // namespace LightGBM
 
 #endif  // LightGBM_TREELEARNER_GPU_LINEAR_TREE_LEARNER_H_

From 3a05c78567cbabadd63e2d4ae1a0acd539f65e47 Mon Sep 17 00:00:00 2001
From: Yu Shi
Date: Wed, 2 Oct 2024 05:22:46 +0000
Subject: [PATCH 06/11] add #ifdef USE_GPU

---
 src/treelearner/gpu_linear_tree_learner.cpp | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/src/treelearner/gpu_linear_tree_learner.cpp b/src/treelearner/gpu_linear_tree_learner.cpp
index c6811ee091bd..8ff2da3fd44e 100644
--- a/src/treelearner/gpu_linear_tree_learner.cpp
+++ b/src/treelearner/gpu_linear_tree_learner.cpp
@@ -2,6 +2,9 @@ /*!
  * Copyright (c) 2024 Microsoft Corporation. All rights reserved.
  * Licensed under the MIT License. See LICENSE file in the project root for license information.
  */
+
+#ifdef USE_GPU
+
 #include "gpu_linear_tree_learner.h"
 
 #include <Eigen/Dense>
 
 #include <algorithm>
 
 namespace LightGBM {
@@ -379,3 +382,6 @@ void GPULinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const scor
 }
 }
 }  // namespace LightGBM
+
+
+#endif  // USE_GPU
\ No newline at end of file

From edfc8cbbd36f39f51392fc88e80e99f294070c72 Mon Sep 17 00:00:00 2001
From: Yu Shi
Date: Wed, 2 Oct 2024 06:30:00 +0000
Subject: [PATCH 07/11] fix lint problems

---
 src/treelearner/gpu_linear_tree_learner.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/treelearner/gpu_linear_tree_learner.cpp b/src/treelearner/gpu_linear_tree_learner.cpp
index 8ff2da3fd44e..ac9cb015353f 100644
--- a/src/treelearner/gpu_linear_tree_learner.cpp
+++ b/src/treelearner/gpu_linear_tree_learner.cpp
@@ -384,4 +384,4 @@ void GPULinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const scor
 }  // namespace LightGBM
 
 
-#endif  // USE_GPU
\ No newline at end of file
+#endif  // USE_GPU

From 4fdbe07a7d449fb09dc5558057e1b12f9acaf5d7 Mon Sep 17 00:00:00 2001
From: Yu Shi
Date: Wed, 9 Oct 2024 09:43:25 +0000
Subject: [PATCH 08/11] fix compilation when USE_GPU is OFF

---
 src/treelearner/gpu_linear_tree_learner.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/treelearner/gpu_linear_tree_learner.cpp b/src/treelearner/gpu_linear_tree_learner.cpp
index ac9cb015353f..ada4243492f9 100644
--- a/src/treelearner/gpu_linear_tree_learner.cpp
+++ b/src/treelearner/gpu_linear_tree_learner.cpp
@@ -3,10 +3,10 @@
  * Licensed under the MIT License. See LICENSE file in the project root for license information.
  */
 
-#ifdef USE_GPU
-
 #include "gpu_linear_tree_learner.h"
 
+#ifdef USE_GPU
+
 #include <Eigen/Dense>
 
 #include <algorithm>

From 7988c814cb10b8cc7cfbd7bd0dc45a4b272f77de Mon Sep 17 00:00:00 2001
From: Yu Shi
Date: Thu, 10 Oct 2024 02:59:56 +0000
Subject: [PATCH 09/11] add destructor

---
 src/treelearner/gpu_linear_tree_learner.h | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/treelearner/gpu_linear_tree_learner.h b/src/treelearner/gpu_linear_tree_learner.h
index dcf7e60966f1..53c4e40cf5a4 100644
--- a/src/treelearner/gpu_linear_tree_learner.h
+++ b/src/treelearner/gpu_linear_tree_learner.h
@@ -21,6 +21,8 @@ class GPULinearTreeLearner: public GPUTreeLearner {
  public:
   explicit GPULinearTreeLearner(const Config* config) : GPUTreeLearner(config) {}
 
+  ~GPULinearTreeLearner() {}
+
   void Init(const Dataset* train_data, bool is_constant_hessian) override;

From d07d9cc18720480884ee716d592e8e48cf80ecec Mon Sep 17 00:00:00 2001
From: Yu Shi
Date: Fri, 11 Oct 2024 11:49:06 +0000
Subject: [PATCH 10/11] add gpu_linear_tree_learner.cpp in make file list

---
 CMakeLists.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index b2859a96d351..52aff98235e3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -439,6 +439,7 @@ set(
     src/treelearner/data_parallel_tree_learner.cpp
     src/treelearner/feature_histogram.cpp
     src/treelearner/feature_parallel_tree_learner.cpp
+    src/treelearner/gpu_linear_tree_learner.cpp
     src/treelearner/gpu_tree_learner.cpp
     src/treelearner/gradient_discretizer.cpp
     src/treelearner/linear_tree_learner.cpp

From 8818b7beea14acf098424702ae4d3c5f7d44cd9b Mon Sep 17 00:00:00 2001
From: Yu Shi
Date: Fri, 18 Oct 2024 03:49:49 +0000
Subject: [PATCH 11/11] use template for linear tree learner

---
 CMakeLists.txt                              |   1 -
 R-package/src/Makevars.in                   |   1 -
 R-package/src/Makevars.win.in               |   1 -
 src/treelearner/gpu_linear_tree_learner.cpp | 387 --------
 src/treelearner/gpu_linear_tree_learner.h   | 147 --------
 src/treelearner/linear_tree_learner.cpp     | 112 +++---
 src/treelearner/linear_tree_learner.h       |  16 +-
 src/treelearner/tree_learner.cpp            |   5 +-
 8 files changed, 78 insertions(+), 592 deletions(-)
 delete mode 100644 src/treelearner/gpu_linear_tree_learner.cpp
 delete mode 100644 src/treelearner/gpu_linear_tree_learner.h

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 52aff98235e3..b2859a96d351 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -439,7 +439,6 @@ set(
     src/treelearner/data_parallel_tree_learner.cpp
     src/treelearner/feature_histogram.cpp
     src/treelearner/feature_parallel_tree_learner.cpp
-    src/treelearner/gpu_linear_tree_learner.cpp
     src/treelearner/gpu_tree_learner.cpp
     src/treelearner/gradient_discretizer.cpp
     src/treelearner/linear_tree_learner.cpp
diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in
index 0040d86bf58b..ec2c067b64b8 100644
--- a/R-package/src/Makevars.in
+++ b/R-package/src/Makevars.in
@@ -49,7 +49,6 @@ OBJECTS = \
     treelearner/feature_histogram.o \
     treelearner/feature_parallel_tree_learner.o \
     treelearner/gpu_tree_learner.o \
-    treelearner/gpu_linear_tree_learner.o \
     treelearner/gradient_discretizer.o \
     treelearner/linear_tree_learner.o \
     treelearner/serial_tree_learner.o \
diff --git a/R-package/src/Makevars.win.in b/R-package/src/Makevars.win.in
index 5fcc720fdd4a..ebcb40d1372a 100644
--- a/R-package/src/Makevars.win.in
+++ b/R-package/src/Makevars.win.in
@@ -50,7 +50,6 @@ OBJECTS = \
    treelearner/feature_histogram.o \
    treelearner/feature_parallel_tree_learner.o \
    treelearner/gpu_tree_learner.o \
-    treelearner/gpu_linear_tree_learner.o \
    treelearner/gradient_discretizer.o \
    treelearner/linear_tree_learner.o \
    treelearner/serial_tree_learner.o \
diff --git a/src/treelearner/gpu_linear_tree_learner.cpp b/src/treelearner/gpu_linear_tree_learner.cpp
deleted file mode 100644
index ada4243492f9..000000000000
--- a/src/treelearner/gpu_linear_tree_learner.cpp
+++ /dev/null
@@ -1,387 +0,0 @@
-/*!
- * Copyright (c) 2024 Microsoft Corporation. All rights reserved.
- * Licensed under the MIT License. See LICENSE file in the project root for license information.
- */
-
-#include "gpu_linear_tree_learner.h"
-
-#ifdef USE_GPU
-
-#include <Eigen/Dense>
-
-#include <algorithm>
-
-namespace LightGBM {
-
-void GPULinearTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian) {
-  GPUTreeLearner::Init(train_data, is_constant_hessian);
-  GPULinearTreeLearner::InitLinear(train_data, config_->num_leaves);
-}
-
-void GPULinearTreeLearner::InitLinear(const Dataset* train_data, const int max_leaves) {
-  leaf_map_ = std::vector<int>(train_data->num_data(), -1);
-  contains_nan_ = std::vector<int8_t>(train_data->num_features(), 0);
-  // identify features containing nans
-#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
-  for (int feat = 0; feat < train_data->num_features(); ++feat) {
-    auto bin_mapper = train_data_->FeatureBinMapper(feat);
-    if (bin_mapper->bin_type() == BinType::NumericalBin) {
-      const float* feat_ptr = train_data_->raw_index(feat);
-      for (int i = 0; i < train_data->num_data(); ++i) {
-        if (std::isnan(feat_ptr[i])) {
-          contains_nan_[feat] = 1;
-          break;
-        }
-      }
-    }
-  }
-  any_nan_ = false;
-  for (int feat = 0; feat < train_data->num_features(); ++feat) {
-    if (contains_nan_[feat]) {
-      any_nan_ = true;
-      break;
-    }
-  }
-  // preallocate the matrix used to calculate linear model coefficients
-  int max_num_feat = std::min(max_leaves, train_data_->num_numeric_features());
-  XTHX_.clear();
-  XTg_.clear();
-  for (int i = 0; i < max_leaves; ++i) {
-    // store only upper triangular half of matrix as an array, in row-major order
-    // this requires (max_num_feat + 1) * (max_num_feat + 2) / 2 entries (including the constant terms of the regression)
-    // we add another 8 to ensure cache lines are not shared among processors
-    XTHX_.push_back(std::vector<float>((max_num_feat + 1) * (max_num_feat + 2) / 2 + 8, 0));
-    XTg_.push_back(std::vector<float>(max_num_feat + 9, 0.0));
-  }
-  XTHX_by_thread_.clear();
-  XTg_by_thread_.clear();
-  int max_threads = OMP_NUM_THREADS();
-  for (int i = 0; i < max_threads; ++i) {
-    XTHX_by_thread_.push_back(XTHX_);
-    XTg_by_thread_.push_back(XTg_);
-  }
-}
-
-Tree* GPULinearTreeLearner::Train(const score_t* gradients, const score_t *hessians, bool is_first_tree) {
-  Common::FunctionTimer fun_timer("SerialTreeLearner::Train", global_timer);
-  gradients_ = gradients;
-  hessians_ = hessians;
-  int num_threads = OMP_NUM_THREADS();
-  if (share_state_->num_threads != num_threads && share_state_->num_threads > 0) {
-    Log::Warning(
-        "Detected that num_threads changed during training (from %d to %d), "
-        "it may cause unexpected errors.",
-        share_state_->num_threads, num_threads);
-  }
-  share_state_->num_threads = num_threads;
-
-  // some initial works before training
-  BeforeTrain();
-
-  auto tree = std::unique_ptr<Tree>(new Tree(config_->num_leaves, true, true));
-  auto tree_ptr = tree.get();
-  constraints_->ShareTreePointer(tree_ptr);
-
-  // root leaf
-  int left_leaf = 0;
-  int cur_depth = 1;
-  // only root leaf can be splitted on first time
-  int right_leaf = -1;
-
-  int init_splits = ForceSplits(tree_ptr, &left_leaf, &right_leaf, &cur_depth);
-
-  for (int split = init_splits; split < config_->num_leaves - 1; ++split) {
-    // some initial works before finding best split
-    if (BeforeFindBestSplit(tree_ptr, left_leaf, right_leaf)) {
-      // find best threshold for every feature
-      FindBestSplits(tree_ptr);
-    }
-    // Get a leaf with max split gain
-    int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(best_split_per_leaf_));
-    // Get split information for best leaf
-    const SplitInfo& best_leaf_SplitInfo = best_split_per_leaf_[best_leaf];
-    // cannot split, quit
-    if (best_leaf_SplitInfo.gain <= 0.0) {
-      Log::Warning("No further splits with positive gain, best gain: %f", best_leaf_SplitInfo.gain);
-      break;
-    }
-    // split tree with best leaf
-    Split(tree_ptr, best_leaf, &left_leaf, &right_leaf);
-    cur_depth = std::max(cur_depth, tree->leaf_depth(left_leaf));
-  }
-
-  bool has_nan = false;
-  if (any_nan_) {
-    for (int i = 0; i < tree->num_leaves() - 1 ; ++i) {
-      if (contains_nan_[tree_ptr->split_feature_inner(i)]) {
-        has_nan = true;
-        break;
-      }
-    }
-  }
-
-  GetLeafMap(tree_ptr);
-
-  if (has_nan) {
-    CalculateLinear<true>(tree_ptr, false, gradients_, hessians_, is_first_tree);
-  } else {
-    CalculateLinear<false>(tree_ptr, false, gradients_, hessians_, is_first_tree);
-  }
-
-  Log::Debug("Trained a tree with leaves = %d and depth = %d", tree->num_leaves(), cur_depth);
-  return tree.release();
-}
-
-Tree* GPULinearTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t *hessians) const {
-  auto tree = SerialTreeLearner::FitByExistingTree(old_tree, gradients, hessians);
-  bool has_nan = false;
-  if (any_nan_) {
-    for (int i = 0; i < tree->num_leaves() - 1 ; ++i) {
-      if (contains_nan_[train_data_->InnerFeatureIndex(tree->split_feature(i))]) {
-        has_nan = true;
-        break;
-      }
-    }
-  }
-  GetLeafMap(tree);
-  if (has_nan) {
-    CalculateLinear<true>(tree, true, gradients, hessians, false);
-  } else {
-    CalculateLinear<false>(tree, true, gradients, hessians, false);
-  }
-  return tree;
-}
-
-Tree* GPULinearTreeLearner::FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred,
-                                              const score_t* gradients, const score_t *hessians) const {
-  data_partition_->ResetByLeafPred(leaf_pred, old_tree->num_leaves());
-  return GPULinearTreeLearner::FitByExistingTree(old_tree, gradients, hessians);
-}
-
-void GPULinearTreeLearner::GetLeafMap(Tree* tree) const {
-  std::fill(leaf_map_.begin(), leaf_map_.end(), -1);
-  // map data to leaf number
-  const data_size_t* ind = data_partition_->indices();
-#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(dynamic)
-  for (int i = 0; i < tree->num_leaves(); ++i) {
-    data_size_t idx = data_partition_->leaf_begin(i);
-    for (int j = 0; j < data_partition_->leaf_count(i); ++j) {
-      leaf_map_[ind[idx + j]] = i;
-    }
-  }
-}
-
-
-template<bool HAS_NAN>
-void GPULinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t* gradients, const score_t* hessians, bool is_first_tree) const {
-  tree->SetIsLinear(true);
-  int num_leaves = tree->num_leaves();
-  int num_threads = OMP_NUM_THREADS();
-  if (is_first_tree) {
-    for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
-      tree->SetLeafConst(leaf_num, tree->LeafOutput(leaf_num));
-    }
-    return;
-  }
-
-  // calculate coefficients using the method described in Eq 3 of https://arxiv.org/pdf/1802.05640.pdf
-  // the coefficients vector is given by
-  // - (X_T * H * X + lambda) ^ (-1) * (X_T * g)
-  // where:
-  // X is the matrix where the first column is the feature values and the second is all ones,
-  // H is the diagonal matrix of the hessian,
-  // lambda is the diagonal matrix with diagonal entries equal to the regularisation term linear_lambda
-  // g is the vector of gradients
-  // the subscript _T denotes the transpose
-
-  // create array of pointers to raw data, and coefficient matrices, for each leaf
-  std::vector<std::vector<int>> leaf_features;
-  std::vector<int> leaf_num_features;
-  std::vector<std::vector<const float*>> raw_data_ptr;
-  size_t max_num_features = 0;
-  for (int i = 0; i < num_leaves; ++i) {
-    std::vector<int> raw_features;
-    if (is_refit) {
-      raw_features = tree->LeafFeatures(i);
-    } else {
-      raw_features = tree->branch_features(i);
-    }
-    std::sort(raw_features.begin(), raw_features.end());
-    auto new_end = std::unique(raw_features.begin(), raw_features.end());
-    raw_features.erase(new_end, raw_features.end());
-    std::vector<int> numerical_features;
-    std::vector<const float*> data_ptr;
-    for (size_t j = 0; j < raw_features.size(); ++j) {
-      int feat = train_data_->InnerFeatureIndex(raw_features[j]);
-      auto bin_mapper = train_data_->FeatureBinMapper(feat);
-      if (bin_mapper->bin_type() == BinType::NumericalBin) {
-        numerical_features.push_back(feat);
-        data_ptr.push_back(train_data_->raw_index(feat));
-      }
-    }
-    leaf_features.push_back(numerical_features);
-    raw_data_ptr.push_back(data_ptr);
-    leaf_num_features.push_back(static_cast<int>(numerical_features.size()));
-    if (numerical_features.size() > max_num_features) {
-      max_num_features = numerical_features.size();
-    }
-  }
-  // clear the coefficient matrices
-#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
-  for (int i = 0; i < num_threads; ++i) {
-    for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
-      size_t num_feat = leaf_features[leaf_num].size();
-      std::fill(XTHX_by_thread_[i][leaf_num].begin(), XTHX_by_thread_[i][leaf_num].begin() + (num_feat + 1) * (num_feat + 2) / 2, 0.0f);
-      std::fill(XTg_by_thread_[i][leaf_num].begin(), XTg_by_thread_[i][leaf_num].begin() + num_feat + 1, 0.0f);
-    }
-  }
-#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
-  for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
-    size_t num_feat = leaf_features[leaf_num].size();
-    std::fill(XTHX_[leaf_num].begin(), XTHX_[leaf_num].begin() + (num_feat + 1) * (num_feat + 2) / 2, 0.0f);
-    std::fill(XTg_[leaf_num].begin(), XTg_[leaf_num].begin() + num_feat + 1, 0.0f);
-  }
-  std::vector<std::vector<int>> num_nonzero;
-  for (int i = 0; i < num_threads; ++i) {
-    if (HAS_NAN) {
-      num_nonzero.push_back(std::vector<int>(num_leaves, 0));
-    }
-  }
-  OMP_INIT_EX();
-#pragma omp parallel num_threads(OMP_NUM_THREADS()) if (num_data_ > 1024)
-  {
-    std::vector<float> curr_row(max_num_features + 1);
-    int tid = omp_get_thread_num();
-#pragma omp for schedule(static)
-    for (int i = 0; i < num_data_; ++i) {
-      OMP_LOOP_EX_BEGIN();
-      int leaf_num = leaf_map_[i];
-      if (leaf_num < 0) {
-        continue;
-      }
-      bool nan_found = false;
-      int num_feat = leaf_num_features[leaf_num];
-      for (int feat = 0; feat < num_feat; ++feat) {
-        if (HAS_NAN) {
-          float val = raw_data_ptr[leaf_num][feat][i];
-          if (std::isnan(val)) {
-            nan_found = true;
-            break;
-          }
-          num_nonzero[tid][leaf_num] += 1;
-          curr_row[feat] = val;
-        } else {
-          curr_row[feat] = raw_data_ptr[leaf_num][feat][i];
-        }
-      }
-      if (HAS_NAN) {
-        if (nan_found) {
-          continue;
-        }
-      }
-      curr_row[num_feat] = 1.0;
-      float h = static_cast<float>(hessians[i]);
-      float g = static_cast<float>(gradients[i]);
-      int j = 0;
-      for (int feat1 = 0; feat1 < num_feat + 1; ++feat1) {
-        double f1_val = static_cast<double>(curr_row[feat1]);
-        XTg_by_thread_[tid][leaf_num][feat1] += f1_val * g;
-        f1_val *= h;
-        for (int feat2 = feat1; feat2 < num_feat + 1; ++feat2) {
-          XTHX_by_thread_[tid][leaf_num][j] += f1_val * curr_row[feat2];
-          ++j;
-        }
-      }
-      OMP_LOOP_EX_END();
-    }
-  }
-  OMP_THROW_EX();
-  auto total_nonzero = std::vector<int>(tree->num_leaves());
-  // aggregate results from different threads
-  for (int tid = 0; tid < num_threads; ++tid) {
-#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
-    for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
-      size_t num_feat = leaf_features[leaf_num].size();
-      for (size_t j = 0; j < (num_feat + 1) * (num_feat + 2) / 2; ++j) {
-        XTHX_[leaf_num][j] += XTHX_by_thread_[tid][leaf_num][j];
-      }
-      for (size_t feat1 = 0; feat1 < num_feat + 1; ++feat1) {
-        XTg_[leaf_num][feat1] += XTg_by_thread_[tid][leaf_num][feat1];
-      }
-      if (HAS_NAN) {
-        total_nonzero[leaf_num] += num_nonzero[tid][leaf_num];
-      }
-    }
-  }
-  if (!HAS_NAN) {
-    for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
-      total_nonzero[leaf_num] = data_partition_->leaf_count(leaf_num);
-    }
-  }
-  double shrinkage = tree->shrinkage();
-  double decay_rate = config_->refit_decay_rate;
-  // copy into eigen matrices and solve
-#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
-  for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
-    if (total_nonzero[leaf_num] < static_cast<int>(leaf_features[leaf_num].size()) + 1) {
-      if (is_refit) {
-        double old_const = tree->LeafConst(leaf_num);
-        tree->SetLeafConst(leaf_num, decay_rate * old_const + (1.0 - decay_rate) * tree->LeafOutput(leaf_num) * shrinkage);
-        tree->SetLeafCoeffs(leaf_num, std::vector<double>(leaf_features[leaf_num].size(), 0));
-        tree->SetLeafFeaturesInner(leaf_num, leaf_features[leaf_num]);
-      } else {
-        tree->SetLeafConst(leaf_num, tree->LeafOutput(leaf_num));
-      }
-      continue;
-    }
-    size_t num_feat = leaf_features[leaf_num].size();
-    Eigen::MatrixXd XTHX_mat(num_feat + 1, num_feat + 1);
-    Eigen::MatrixXd XTg_mat(num_feat + 1, 1);
-    size_t j = 0;
-    for (size_t feat1 = 0; feat1 < num_feat + 1; ++feat1) {
-      for (size_t feat2 = feat1; feat2 < num_feat + 1; ++feat2) {
-        XTHX_mat(feat1, feat2) = XTHX_[leaf_num][j];
-        XTHX_mat(feat2, feat1) = XTHX_mat(feat1, feat2);
-        if ((feat1 == feat2) && (feat1 < num_feat)) {
-          XTHX_mat(feat1, feat2) += config_->linear_lambda;
-        }
-        ++j;
-      }
-      XTg_mat(feat1) = XTg_[leaf_num][feat1];
-    }
-    Eigen::MatrixXd coeffs = - XTHX_mat.fullPivLu().inverse() * XTg_mat;
-    std::vector<double> coeffs_vec;
-    std::vector<int> features_new;
-    std::vector<double> old_coeffs = tree->LeafCoeffs(leaf_num);
-    for (size_t i = 0; i < leaf_features[leaf_num].size(); ++i) {
-      if (is_refit) {
-        features_new.push_back(leaf_features[leaf_num][i]);
-        coeffs_vec.push_back(decay_rate * old_coeffs[i] + (1.0 - decay_rate) * coeffs(i) * shrinkage);
-      } else {
-        if (coeffs(i) < -kZeroThreshold || coeffs(i) > kZeroThreshold) {
-          coeffs_vec.push_back(coeffs(i));
-          int feat = leaf_features[leaf_num][i];
-          features_new.push_back(feat);
-        }
-      }
-    }
-    // update the tree properties
-    tree->SetLeafFeaturesInner(leaf_num, features_new);
-    std::vector<int> features_raw(features_new.size());
-    for (size_t i = 0; i < features_new.size(); ++i) {
-      features_raw[i] = train_data_->RealFeatureIndex(features_new[i]);
-    }
-    tree->SetLeafFeatures(leaf_num, features_raw);
-    tree->SetLeafCoeffs(leaf_num, coeffs_vec);
-    if (is_refit) {
-      double old_const = tree->LeafConst(leaf_num);
-      tree->SetLeafConst(leaf_num, decay_rate * old_const + (1.0 - decay_rate) * coeffs(num_feat) * shrinkage);
-    } else {
-      tree->SetLeafConst(leaf_num, coeffs(num_feat));
-    }
-  }
-}
-}  // namespace LightGBM
-
-
-#endif  // USE_GPU
diff --git a/src/treelearner/gpu_linear_tree_learner.h b/src/treelearner/gpu_linear_tree_learner.h
deleted file mode 100644
index 53c4e40cf5a4..000000000000
--- a/src/treelearner/gpu_linear_tree_learner.h
+++ /dev/null
@@ -1,147 +0,0 @@
-/*!
- * Copyright (c) 2024 Microsoft Corporation. All rights reserved.
- * Licensed under the MIT License. See LICENSE file in the project root for license information.
- */
-#ifndef LIGHTGBM_TREELEARNER_GPU_LINEAR_TREE_LEARNER_H_
-#define LIGHTGBM_TREELEARNER_GPU_LINEAR_TREE_LEARNER_H_
-
-#include <LightGBM/utils/array_args.h>
-#include <algorithm>
-#include <cmath>
-#include <memory>
-#include <vector>
-
-#include "gpu_tree_learner.h"
-
-namespace LightGBM {
-
-#ifdef USE_GPU
-
-class GPULinearTreeLearner: public GPUTreeLearner {
- public:
-  explicit GPULinearTreeLearner(const Config* config) : GPUTreeLearner(config) {}
-
-  ~GPULinearTreeLearner() {}
-
-  void Init(const Dataset* train_data, bool is_constant_hessian) override;
-
-  void InitLinear(const Dataset* train_data, const int max_leaves) override;
-
-  Tree* Train(const score_t* gradients, const score_t *hessians, bool is_first_tree) override;
-
-  /*! \brief Create array mapping dataset to leaf index, used for linear trees */
-  void GetLeafMap(Tree* tree) const;
-
-  template<bool HAS_NAN>
-  void CalculateLinear(Tree* tree, bool is_refit, const score_t* gradients, const score_t* hessians, bool is_first_tree) const;
-
-  Tree* FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t* hessians) const override;
-
-  Tree* FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred,
-                          const score_t* gradients, const score_t* hessians) const override;
-
-  void AddPredictionToScore(const Tree* tree,
-                            double* out_score) const override {
-    CHECK_LE(tree->num_leaves(), data_partition_->num_leaves());
-    bool has_nan = false;
-    if (any_nan_) {
-      for (int i = 0; i < tree->num_leaves() - 1 ; ++i) {
-        // use split_feature because split_feature_inner doesn't work when refitting existing tree
-        if (contains_nan_[train_data_->InnerFeatureIndex(tree->split_feature(i))]) {
-          has_nan = true;
-          break;
-        }
-      }
-    }
-    if (has_nan) {
-      AddPredictionToScoreInner<true>(tree, out_score);
-    } else {
-      AddPredictionToScoreInner<false>(tree, out_score);
-    }
-  }
-
-  template<bool HAS_NAN>
-  void AddPredictionToScoreInner(const Tree* tree, double* out_score) const {
-    int num_leaves = tree->num_leaves();
-    std::vector<double> leaf_const(num_leaves);
-    std::vector<std::vector<double>> leaf_coeff(num_leaves);
-    std::vector<std::vector<const float*>> feat_ptr(num_leaves);
-    std::vector<double> leaf_output(num_leaves);
-    std::vector<int> leaf_num_features(num_leaves);
-    for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
-      leaf_const[leaf_num] = tree->LeafConst(leaf_num);
-      leaf_coeff[leaf_num] = tree->LeafCoeffs(leaf_num);
-      leaf_output[leaf_num] = tree->LeafOutput(leaf_num);
-      for (int feat : tree->LeafFeaturesInner(leaf_num)) {
-        feat_ptr[leaf_num].push_back(train_data_->raw_index(feat));
-      }
-      leaf_num_features[leaf_num] = static_cast<int>(feat_ptr[leaf_num].size());
-    }
-    OMP_INIT_EX();
-#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) if (num_data_ > 1024)
-    for (int i = 0; i < num_data_; ++i) {
-      OMP_LOOP_EX_BEGIN();
-      int leaf_num = leaf_map_[i];
-      if (leaf_num < 0) {
-        continue;
-      }
-      double output = leaf_const[leaf_num];
-      int num_feat = leaf_num_features[leaf_num];
-      if (HAS_NAN) {
-        bool nan_found = false;
-        for (int feat_ind = 0; feat_ind < num_feat; ++feat_ind) {
-          float val = feat_ptr[leaf_num][feat_ind][i];
-          if (std::isnan(val)) {
-            nan_found = true;
-            break;
-          }
-          output += val * leaf_coeff[leaf_num][feat_ind];
-        }
-        if (nan_found) {
-          out_score[i] += leaf_output[leaf_num];
-        } else {
-          out_score[i] += output;
-        }
-      } else {
-        for (int feat_ind = 0; feat_ind < num_feat; ++feat_ind) {
-          output += feat_ptr[leaf_num][feat_ind][i] * leaf_coeff[leaf_num][feat_ind];
-        }
-        out_score[i] += output;
-      }
-      OMP_LOOP_EX_END();
-    }
-    OMP_THROW_EX();
-  }
-
- protected:
-  /*! \brief whether numerical features contain any nan values */
-  std::vector<int8_t> contains_nan_;
-  /*! whether any numerical feature contains a nan value */
-  bool any_nan_;
-  /*! \brief map dataset to leaves */
-  mutable std::vector<int> leaf_map_;
-  /*! \brief temporary storage for calculating linear model coefficients */
-  mutable std::vector<std::vector<float>> XTHX_;
-  mutable std::vector<std::vector<float>> XTg_;
-  mutable std::vector<std::vector<std::vector<float>>> XTHX_by_thread_;
-  mutable std::vector<std::vector<std::vector<float>>> XTg_by_thread_;
-};
-
-#else  // USE_GPU
-
-class GPULinearTreeLearner: public GPUTreeLearner {
- public:
-  #ifdef _MSC_VER
-  #pragma warning(disable : 4702)
-  #endif
-  explicit GPULinearTreeLearner(const Config* tree_config) : GPUTreeLearner(tree_config) {
-    Log::Fatal("GPU Tree Linear Learner was not enabled in this build.\n"
-               "Please recompile with CMake option -DUSE_GPU=1");
-  }
-};
-
-#endif  // USE_GPU
-
-}  // namespace LightGBM
-
-#endif  // LightGBM_TREELEARNER_GPU_LINEAR_TREE_LEARNER_H_
diff --git a/src/treelearner/linear_tree_learner.cpp b/src/treelearner/linear_tree_learner.cpp
index c96bce64d644..22c6d1caf7f3 100644
--- a/src/treelearner/linear_tree_learner.cpp
+++ b/src/treelearner/linear_tree_learner.cpp
@@ -10,20 +10,22 @@
 
 namespace LightGBM {
 
-void LinearTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian) {
-  SerialTreeLearner::Init(train_data, is_constant_hessian);
-  LinearTreeLearner::InitLinear(train_data, config_->num_leaves);
+template <typename TREE_LEARNER_TYPE>
+void LinearTreeLearner<TREE_LEARNER_TYPE>::Init(const Dataset* train_data, bool is_constant_hessian) {
+  TREE_LEARNER_TYPE::Init(train_data, is_constant_hessian);
+  LinearTreeLearner::InitLinear(train_data, this->config_->num_leaves);
 }
 
-void LinearTreeLearner::InitLinear(const Dataset* train_data, const int max_leaves) {
+template <typename TREE_LEARNER_TYPE>
+void LinearTreeLearner<TREE_LEARNER_TYPE>::InitLinear(const Dataset* train_data, const int max_leaves) {
   leaf_map_ = std::vector<int>(train_data->num_data(), -1);
   contains_nan_ = std::vector<int8_t>(train_data->num_features(), 0);
   // identify features containing nans
 #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
   for (int feat = 0; feat < train_data->num_features(); ++feat) {
-    auto bin_mapper = train_data_->FeatureBinMapper(feat);
+    auto bin_mapper = this->train_data_->FeatureBinMapper(feat);
     if (bin_mapper->bin_type() == BinType::NumericalBin) {
-      const float* feat_ptr = train_data_->raw_index(feat);
+      const float* feat_ptr = this->train_data_->raw_index(feat);
       for (int i = 0; i < train_data->num_data(); ++i) {
         if (std::isnan(feat_ptr[i])) {
           contains_nan_[feat] = 1;
@@ -40,7 +42,7 @@ void LinearTreeLearner::InitLinear(const Dataset* train_data, const int max_leav
     }
   }
   // preallocate the matrix used to calculate linear model coefficients
-  int max_num_feat = std::min(max_leaves, train_data_->num_numeric_features());
+  int max_num_feat = std::min(max_leaves, this->train_data_->num_numeric_features());
   XTHX_.clear();
   XTg_.clear();
   for (int i = 0; i < max_leaves; ++i) {
@@ -59,25 +61,26 @@ void LinearTreeLearner::InitLinear(const Dataset* train_data, const int max_leav
     XTHX_by_thread_.push_back(XTHX_);
     XTg_by_thread_.push_back(XTg_);
   }
 }
 
-Tree* LinearTreeLearner::Train(const score_t* gradients, const score_t *hessians, bool is_first_tree) {
+template <typename TREE_LEARNER_TYPE>
+Tree* LinearTreeLearner<TREE_LEARNER_TYPE>::Train(const score_t* gradients, const score_t *hessians, bool is_first_tree) {
   Common::FunctionTimer fun_timer("SerialTreeLearner::Train", global_timer);
-  gradients_ = gradients;
-  hessians_ = hessians;
+  this->gradients_ = gradients;
+  this->hessians_ = hessians;
   int num_threads = OMP_NUM_THREADS();
-  if (share_state_->num_threads != num_threads && share_state_->num_threads > 0) {
+  if (this->share_state_->num_threads != num_threads && this->share_state_->num_threads > 0) {
     Log::Warning(
         "Detected that num_threads changed during training (from %d to %d), "
         "it may cause unexpected errors.",
-        share_state_->num_threads, num_threads);
+        this->share_state_->num_threads, num_threads);
   }
-  share_state_->num_threads = num_threads;
+  this->share_state_->num_threads = num_threads;
 
   // some initial works before training
-  BeforeTrain();
+  this->BeforeTrain();
 
-  auto tree = std::unique_ptr<Tree>(new Tree(config_->num_leaves, true, true));
+  auto tree = std::unique_ptr<Tree>(new Tree(this->config_->num_leaves, true, true));
   auto tree_ptr = tree.get();
-  constraints_->ShareTreePointer(tree_ptr);
+  this->constraints_->ShareTreePointer(tree_ptr);
 
   // root leaf
   int left_leaf = 0;
@@ -85,25 +88,25 @@ Tree* LinearTreeLearner::Train(const score_t* gradients, const score_t *hessians
   // only root leaf can be splitted on first time
   int right_leaf = -1;
 
-  int init_splits = ForceSplits(tree_ptr, &left_leaf, &right_leaf, &cur_depth);
+  int init_splits = this->ForceSplits(tree_ptr, &left_leaf, &right_leaf, &cur_depth);
 
-  for (int split = init_splits; split < config_->num_leaves - 1; ++split) {
+  for (int split = init_splits; split < this->config_->num_leaves - 1; ++split) {
     // some initial works before finding best split
-    if (BeforeFindBestSplit(tree_ptr, left_leaf, right_leaf)) {
+    if (this->BeforeFindBestSplit(tree_ptr, left_leaf, right_leaf)) {
       // find best threshold for every feature
-      FindBestSplits(tree_ptr);
+      this->FindBestSplits(tree_ptr);
     }
     // Get a leaf with max split gain
-    int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(best_split_per_leaf_));
+    int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(this->best_split_per_leaf_));
     // Get split information for best leaf
-    const SplitInfo& best_leaf_SplitInfo = best_split_per_leaf_[best_leaf];
+    const SplitInfo& best_leaf_SplitInfo = this->best_split_per_leaf_[best_leaf];
     // cannot split, quit
     if (best_leaf_SplitInfo.gain <= 0.0) {
       Log::Warning("No further splits with positive gain, best gain: %f", best_leaf_SplitInfo.gain);
       break;
     }
     // split tree with best leaf
-    Split(tree_ptr, best_leaf, &left_leaf, &right_leaf);
+    this->Split(tree_ptr, best_leaf, &left_leaf, &right_leaf);
     cur_depth = std::max(cur_depth, tree->leaf_depth(left_leaf));
   }
 
@@ -120,21 +123,22 @@ Tree* LinearTreeLearner::Train(const score_t* gradients, const score_t *hessians
   GetLeafMap(tree_ptr);
 
   if (has_nan) {
-    CalculateLinear<true>(tree_ptr, false, gradients_, hessians_, is_first_tree);
+    CalculateLinear<true>(tree_ptr, false, this->gradients_, this->hessians_, is_first_tree);
   } else {
-    CalculateLinear<false>(tree_ptr, false, gradients_, hessians_, is_first_tree);
+    CalculateLinear<false>(tree_ptr, false, this->gradients_, this->hessians_, is_first_tree);
   }
 
   Log::Debug("Trained a tree with leaves = %d and depth = %d", tree->num_leaves(), cur_depth);
   return tree.release();
 }
 
-Tree* LinearTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t *hessians) const {
-  auto tree = SerialTreeLearner::FitByExistingTree(old_tree, gradients, hessians);
+template <typename TREE_LEARNER_TYPE>
+Tree* LinearTreeLearner<TREE_LEARNER_TYPE>::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t *hessians) const {
+  auto tree = TREE_LEARNER_TYPE::FitByExistingTree(old_tree, gradients, hessians);
   bool has_nan = false;
   if (any_nan_) {
     for (int i = 0; i < tree->num_leaves() - 1 ; ++i) {
-      if (contains_nan_[train_data_->InnerFeatureIndex(tree->split_feature(i))]) {
+      if (contains_nan_[this->train_data_->InnerFeatureIndex(tree->split_feature(i))]) {
         has_nan = true;
         break;
       }
@@ -149,28 +153,31 @@ Tree* LinearTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t*
   return tree;
 }
 
-Tree* LinearTreeLearner::FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred,
+template <typename TREE_LEARNER_TYPE>
+Tree* LinearTreeLearner<TREE_LEARNER_TYPE>::FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred,
                                            const score_t* gradients, const score_t *hessians) const {
-  data_partition_->ResetByLeafPred(leaf_pred, old_tree->num_leaves());
+  this->data_partition_->ResetByLeafPred(leaf_pred, old_tree->num_leaves());
   return LinearTreeLearner::FitByExistingTree(old_tree, gradients, hessians);
 }
 
-void LinearTreeLearner::GetLeafMap(Tree* tree) const {
+template <typename TREE_LEARNER_TYPE>
+void LinearTreeLearner<TREE_LEARNER_TYPE>::GetLeafMap(Tree* tree) const {
   std::fill(leaf_map_.begin(), leaf_map_.end(), -1);
   // map data to leaf number
-  const data_size_t* ind = data_partition_->indices();
+  const data_size_t* ind = this->data_partition_->indices();
 #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(dynamic)
   for (int i = 0; i < tree->num_leaves(); ++i) {
-    data_size_t idx = data_partition_->leaf_begin(i);
-    for (int j = 0; j < data_partition_->leaf_count(i); ++j) {
+    data_size_t idx = this->data_partition_->leaf_begin(i);
+    for (int j = 0; j < this->data_partition_->leaf_count(i); ++j) {
       leaf_map_[ind[idx + j]] = i;
     }
   }
 }
 
 
-template<bool HAS_NAN>
-void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t* gradients, const score_t* hessians, bool is_first_tree) const {
+template <typename TREE_LEARNER_TYPE>
+template <bool HAS_NAN>
+void LinearTreeLearner<TREE_LEARNER_TYPE>::CalculateLinear(Tree* tree, bool is_refit, const score_t* gradients, const score_t* hessians, bool is_first_tree) const {
   tree->SetIsLinear(true);
   int num_leaves = tree->num_leaves();
   int num_threads = OMP_NUM_THREADS();
@@ -209,11 +216,11 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t
     std::vector<int> numerical_features;
     std::vector<const float*> data_ptr;
     for (size_t j = 0; j < raw_features.size(); ++j) {
-      int feat = train_data_->InnerFeatureIndex(raw_features[j]);
-      auto bin_mapper = train_data_->FeatureBinMapper(feat);
+      int feat = this->train_data_->InnerFeatureIndex(raw_features[j]);
+      auto bin_mapper = this->train_data_->FeatureBinMapper(feat);
       if (bin_mapper->bin_type() == BinType::NumericalBin) {
         numerical_features.push_back(feat);
-        data_ptr.push_back(train_data_->raw_index(feat));
+        data_ptr.push_back(this->train_data_->raw_index(feat));
       }
     }
     leaf_features.push_back(numerical_features);
@@ -245,12 +252,12 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t
     }
   }
   OMP_INIT_EX();
-#pragma omp parallel num_threads(OMP_NUM_THREADS()) if (num_data_ > 1024)
+#pragma omp parallel num_threads(OMP_NUM_THREADS()) if (this->num_data_ > 1024)
   {
     std::vector<float> curr_row(max_num_features + 1);
     int tid = omp_get_thread_num();
 #pragma omp for schedule(static)
-    for (int i = 0; i < num_data_; ++i) {
+    for (int i = 0; i < this->num_data_; ++i) {
       OMP_LOOP_EX_BEGIN();
       int leaf_num = leaf_map_[i];
       if (leaf_num < 0) {
@@ -312,11 +319,11 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t
     }
   }
   if (!HAS_NAN) {
     for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
-      total_nonzero[leaf_num] = data_partition_->leaf_count(leaf_num);
+      total_nonzero[leaf_num] = this->data_partition_->leaf_count(leaf_num);
     }
   }
   double shrinkage = tree->shrinkage();
-  double decay_rate = config_->refit_decay_rate;
+  double decay_rate = this->config_->refit_decay_rate;
   // copy into eigen matrices and solve
 #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
   for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
@@ -340,7 +347,7 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t
         XTHX_mat(feat1, feat2) = XTHX_[leaf_num][j];
         XTHX_mat(feat2, feat1) = XTHX_mat(feat1, feat2);
         if ((feat1 == feat2) && (feat1 < num_feat)) {
-          XTHX_mat(feat1, feat2) += config_->linear_lambda;
+          XTHX_mat(feat1, feat2) += this->config_->linear_lambda;
         }
         ++j;
       }
@@ -366,7 +373,7 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t
     tree->SetLeafFeaturesInner(leaf_num, features_new);
     std::vector<int> features_raw(features_new.size());
     for (size_t i = 0; i < features_new.size(); ++i) {
-      features_raw[i] = train_data_->RealFeatureIndex(features_new[i]);
+      features_raw[i] = this->train_data_->RealFeatureIndex(features_new[i]);
     }
     tree->SetLeafFeatures(leaf_num, features_raw);
     tree->SetLeafCoeffs(leaf_num, coeffs_vec);
@@ -378,4 +385,19 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t
     }
   }
 }
+
+template void LinearTreeLearner<SerialTreeLearner>::Init(const Dataset* train_data, bool is_constant_hessian);
+template void LinearTreeLearner<SerialTreeLearner>::InitLinear(const Dataset* train_data, const int max_leaves);
+template Tree* LinearTreeLearner<SerialTreeLearner>::Train(const score_t* gradients, const score_t *hessians, bool is_first_tree);
+template Tree* LinearTreeLearner<SerialTreeLearner>::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t *hessians) const;
+template Tree* LinearTreeLearner<SerialTreeLearner>::FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred,
+    const score_t* gradients, const score_t *hessians) const;
+
+template void LinearTreeLearner<GPUTreeLearner>::Init(const Dataset* train_data, bool is_constant_hessian);
+template void LinearTreeLearner<GPUTreeLearner>::InitLinear(const Dataset* train_data, const int max_leaves);
+template Tree* LinearTreeLearner<GPUTreeLearner>::Train(const score_t* gradients, const score_t *hessians, bool is_first_tree);
+template Tree* LinearTreeLearner<GPUTreeLearner>::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t *hessians) const;
+template Tree* LinearTreeLearner<GPUTreeLearner>::FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred,
+    const score_t* gradients, const score_t *hessians) const;
+
 }  // namespace LightGBM
diff --git a/src/treelearner/linear_tree_learner.h b/src/treelearner/linear_tree_learner.h
index e20a80ad42d3..376040cc6583 100644
--- a/src/treelearner/linear_tree_learner.h
+++ b/src/treelearner/linear_tree_learner.h
@@ -11,13 +11,15 @@
 #include <memory>
 #include <vector>
 
+#include "gpu_tree_learner.h"
 #include "serial_tree_learner.h"
 
 namespace LightGBM {
 
-class LinearTreeLearner: public SerialTreeLearner {
+template <typename TREE_LEARNER_TYPE>
+class LinearTreeLearner: public TREE_LEARNER_TYPE {
  public:
-  explicit LinearTreeLearner(const Config* config) : SerialTreeLearner(config) {}
+  explicit LinearTreeLearner(const Config* config) : TREE_LEARNER_TYPE(config) {}
 
   void Init(const Dataset* train_data, bool is_constant_hessian) override;
 
@@ -38,12 +40,12 @@ class LinearTreeLearner: public SerialTreeLearner {
 
   void AddPredictionToScore(const Tree* tree,
                             double* out_score) const override {
-    CHECK_LE(tree->num_leaves(), data_partition_->num_leaves());
+    CHECK_LE(tree->num_leaves(), this->data_partition_->num_leaves());
     bool has_nan = false;
     if (any_nan_) {
      for (int i = 0; i < tree->num_leaves() - 1 ; ++i) {
         // use split_feature because split_feature_inner doesn't work when refitting existing tree
-        if (contains_nan_[train_data_->InnerFeatureIndex(tree->split_feature(i))]) {
+        if (contains_nan_[this->train_data_->InnerFeatureIndex(tree->split_feature(i))]) {
           has_nan = true;
           break;
         }
@@ -69,13 +71,13 @@ class LinearTreeLearner: public SerialTreeLearner {
       leaf_coeff[leaf_num] = tree->LeafCoeffs(leaf_num);
       leaf_output[leaf_num] = tree->LeafOutput(leaf_num);
       for (int feat : tree->LeafFeaturesInner(leaf_num)) {
-        feat_ptr[leaf_num].push_back(train_data_->raw_index(feat));
+        feat_ptr[leaf_num].push_back(this->train_data_->raw_index(feat));
       }
       leaf_num_features[leaf_num] = static_cast<int>(feat_ptr[leaf_num].size());
     }
     OMP_INIT_EX();
-#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) if (num_data_ > 1024)
-    for (int i = 0; i < num_data_; ++i) {
+#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) if (this->num_data_ > 1024)
+    for (int i = 0; i < this->num_data_; ++i) {
       OMP_LOOP_EX_BEGIN();
       int leaf_num = leaf_map_[i];
       if (leaf_num < 0) {
diff --git a/src/treelearner/tree_learner.cpp b/src/treelearner/tree_learner.cpp
index 9c5eef9580f1..13d607a2ee5f 100644
--- a/src/treelearner/tree_learner.cpp
+++ b/src/treelearner/tree_learner.cpp
@@ -5,7 +5,6 @@
 #include <LightGBM/tree_learner.h>
 
 #include "gpu_tree_learner.h"
-#include "gpu_linear_tree_learner.h"
 #include "linear_tree_learner.h"
 #include "parallel_tree_learner.h"
 #include "serial_tree_learner.h"
@@ -18,7 +17,7 @@ TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, con
   if (device_type == std::string("cpu")) {
     if (learner_type == std::string("serial")) {
       if (config->linear_tree) {
-        return new LinearTreeLearner(config);
+        return new LinearTreeLearner<SerialTreeLearner>(config);
       } else {
         return new SerialTreeLearner(config);
       }
@@ -32,7 +31,7 @@ TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, con
   } else if (device_type == std::string("gpu")) {
     if (learner_type == std::string("serial")) {
       if (config->linear_tree) {
-        return new GPULinearTreeLearner(config);
+        return new LinearTreeLearner<GPUTreeLearner>(config);
       } else {
         return new GPUTreeLearner(config);
       }
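A note on the pattern patch 11 adopts: stacking the linear-model logic on an arbitrary base learner via a class template is what lets the dedicated GPULinearTreeLearner files be deleted outright. The standalone sketch below (illustrative names only — SerialLearner, GpuLearner, and LinearLearner are not LightGBM classes) shows the same mixin shape, including why the patch must add 'this->' in front of every inherited member:

#include <iostream>
#include <memory>
#include <string>

// Stand-ins for SerialTreeLearner / GPUTreeLearner.
struct SerialLearner {
  virtual ~SerialLearner() = default;
  virtual std::string device() const { return "cpu"; }
};

struct GpuLearner : SerialLearner {
  std::string device() const override { return "gpu"; }
};

// One linear-model layer stacked on any base learner, mirroring
// LinearTreeLearner<TREE_LEARNER_TYPE> from patch 11.
template <typename Base>
struct LinearLearner : Base {
  std::string describe() const {
    // 'this->' is required to reach members of a dependent base class;
    // this is exactly why patch 11 sprinkles 'this->' through LinearTreeLearner.
    return "linear model on " + this->device();
  }
};

// Factory dispatch in the style of TreeLearner::CreateTreeLearner.
std::unique_ptr<SerialLearner> Create(bool gpu, bool linear) {
  if (gpu) {
    if (linear) return std::make_unique<LinearLearner<GpuLearner>>();
    return std::make_unique<GpuLearner>();
  }
  if (linear) return std::make_unique<LinearLearner<SerialLearner>>();
  return std::make_unique<SerialLearner>();
}

int main() {
  std::cout << LinearLearner<GpuLearner>{}.describe() << "\n";     // linear model on gpu
  std::cout << LinearLearner<SerialLearner>{}.describe() << "\n";  // linear model on cpu
}

Because the templated LinearTreeLearner's definitions now live in a .cpp file, the explicit instantiations appended to linear_tree_learner.cpp are what keep the linker satisfied for the two base learners actually used.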
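The per-leaf solve in CalculateLinear is compact enough to check numerically in isolation. Below is a minimal sketch of the same computation, coeffs = -(X^T H X + lambda)^(-1) (X^T g), assuming only Eigen is available; it uses solve() where the patch uses an explicit fullPivLu().inverse(), which is mathematically equivalent here, and all data values are made up for illustration:

#include <Eigen/Dense>

#include <iostream>

int main() {
  const int num_feat = 2;
  // X: one row per data point in the leaf; the trailing all-ones column
  // makes the last coefficient the leaf constant.
  Eigen::MatrixXd X(4, num_feat + 1);
  X << 1.0, 2.0, 1.0,
       2.0, 1.0, 1.0,
       3.0, 0.5, 1.0,
       4.0, 3.0, 1.0;
  Eigen::VectorXd g(4), h(4);
  g << 0.3, -0.1, 0.4, -0.2;   // gradients of the leaf's data points
  h << 1.0, 1.0, 1.0, 1.0;     // hessians
  const double lambda = 0.01;  // plays the role of linear_lambda

  // X^T H X with ridge regularisation on the features but not the constant,
  // matching the (feat1 < num_feat) guard in CalculateLinear.
  Eigen::MatrixXd XTHX = X.transpose() * h.asDiagonal() * X;
  for (int k = 0; k < num_feat; ++k) {
    XTHX(k, k) += lambda;
  }
  Eigen::VectorXd coeffs = -XTHX.fullPivLu().solve(X.transpose() * g);
  std::cout << coeffs.transpose() << std::endl;  // last entry = leaf constant
}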