diff --git a/docs/Parameters.rst b/docs/Parameters.rst index 99d5a2bb1c82..f86dc797b3bf 100644 --- a/docs/Parameters.rst +++ b/docs/Parameters.rst @@ -240,6 +240,10 @@ Learning Control Parameters - ``<= 0`` means disable +- ``first_metric_only`` :raw-html:`🔗︎`, default = ``false``, type = bool + + - set this to ``true`` if you want to use only the first metric for early stopping + - ``max_delta_step`` :raw-html:`🔗︎`, default = ``0.0``, type = double, aliases: ``max_tree_output``, ``max_leaf_output`` - used to limit the max output of tree leaves diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h index 2a6896f82dc0..b6c985bd2f51 100644 --- a/include/LightGBM/config.h +++ b/include/LightGBM/config.h @@ -260,6 +260,9 @@ struct Config { // desc = ``<= 0`` means disable int early_stopping_round = 0; + // desc = set this to ``true`` if you want to use only the first metric for early stopping + bool first_metric_only = false; + // alias = max_tree_output, max_leaf_output // desc = used to limit the max output of tree leaves // desc = ``<= 0`` means no constraint diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp index b8c2f28de47d..aca1cdfe7d63 100644 --- a/src/boosting/gbdt.cpp +++ b/src/boosting/gbdt.cpp @@ -22,6 +22,7 @@ GBDT::GBDT() : iter_(0), train_data_(nullptr), objective_function_(nullptr), early_stopping_round_(0), +es_first_metric_only_(false), max_feature_idx_(0), num_tree_per_iteration_(1), num_class_(1), @@ -51,6 +52,7 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective num_class_ = config->num_class; config_ = std::unique_ptr(new Config(*config)); early_stopping_round_ = config_->early_stopping_round; + es_first_metric_only_ = config_->first_metric_only; shrinkage_rate_ = config_->learning_rate; std::string forced_splits_path = config->forcedsplits_filename; @@ -129,20 +131,18 @@ void GBDT::AddValidDataset(const Dataset* valid_data, } valid_score_updater_.push_back(std::move(new_score_updater)); 
valid_metrics_.emplace_back(); - if (early_stopping_round_ > 0) { - best_iter_.emplace_back(); - best_score_.emplace_back(); - best_msg_.emplace_back(); - } for (const auto& metric : valid_metrics) { valid_metrics_.back().push_back(metric); - if (early_stopping_round_ > 0) { - best_iter_.back().push_back(0); - best_score_.back().push_back(kMinScore); - best_msg_.back().emplace_back(); - } } valid_metrics_.back().shrink_to_fit(); + + if (early_stopping_round_ > 0) { + auto num_metrics = valid_metrics.size(); + if (es_first_metric_only_) { num_metrics = 1; } + best_iter_.emplace_back(num_metrics, 0); + best_score_.emplace_back(num_metrics, kMinScore); + best_msg_.emplace_back(num_metrics); + } } void GBDT::Boosting() { @@ -514,6 +514,7 @@ std::string GBDT::OutputMetric(int iter) { msg_buf << tmp_buf.str() << '\n'; } } + if (es_first_metric_only_ && j > 0) { continue; } if (ret.empty() && early_stopping_round_ > 0) { auto cur_score = valid_metrics_[i][j]->factor_to_bigger_better() * test_scores.back(); if (cur_score > best_score_[i][j]) { diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h index 78784488362b..7bfbfcd8748d 100644 --- a/src/boosting/gbdt.h +++ b/src/boosting/gbdt.h @@ -434,6 +434,8 @@ class GBDT : public GBDTBase { std::vector> valid_metrics_; /*! \brief Number of rounds for early stopping */ int early_stopping_round_; + /*! \brief Only use first metric for early stopping */ + bool es_first_metric_only_; /*! \brief Best iteration(s) for early stopping */ std::vector> best_iter_; /*! 
\brief Best score(s) for early stopping */ diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp index ca0ed85b9dd0..9d5c3f853526 100644 --- a/src/io/config_auto.cpp +++ b/src/io/config_auto.cpp @@ -181,6 +181,7 @@ std::unordered_set Config::parameter_set({ "feature_fraction", "feature_fraction_seed", "early_stopping_round", + "first_metric_only", "max_delta_step", "lambda_l1", "lambda_l2", @@ -312,6 +313,8 @@ void Config::GetMembersFromString(const std::unordered_map