diff --git a/docs/Parameters.rst b/docs/Parameters.rst
index 99d5a2bb1c82..f86dc797b3bf 100644
--- a/docs/Parameters.rst
+++ b/docs/Parameters.rst
@@ -240,6 +240,10 @@ Learning Control Parameters
- ``<= 0`` means disable
+- ``first_metric_only`` :raw-html:`🔗︎`, default = ``false``, type = bool
+
+ - set this to ``true``, if you want to use only the first metric for early stopping
+
- ``max_delta_step`` :raw-html:`🔗︎`, default = ``0.0``, type = double, aliases: ``max_tree_output``, ``max_leaf_output``
- used to limit the max output of tree leaves
diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h
index 2a6896f82dc0..b6c985bd2f51 100644
--- a/include/LightGBM/config.h
+++ b/include/LightGBM/config.h
@@ -260,6 +260,9 @@ struct Config {
// desc = ``<= 0`` means disable
int early_stopping_round = 0;
+ // desc = set this to ``true``, if you want to use only the first metric for early stopping
+ bool first_metric_only = false;
+
// alias = max_tree_output, max_leaf_output
// desc = used to limit the max output of tree leaves
// desc = ``<= 0`` means no constraint
diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp
index b8c2f28de47d..aca1cdfe7d63 100644
--- a/src/boosting/gbdt.cpp
+++ b/src/boosting/gbdt.cpp
@@ -22,6 +22,7 @@ GBDT::GBDT() : iter_(0),
train_data_(nullptr),
objective_function_(nullptr),
early_stopping_round_(0),
+es_first_metric_only_(false),
max_feature_idx_(0),
num_tree_per_iteration_(1),
num_class_(1),
@@ -51,6 +52,7 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
num_class_ = config->num_class;
config_ = std::unique_ptr(new Config(*config));
early_stopping_round_ = config_->early_stopping_round;
+ es_first_metric_only_ = config_->first_metric_only;
shrinkage_rate_ = config_->learning_rate;
std::string forced_splits_path = config->forcedsplits_filename;
@@ -129,20 +131,18 @@ void GBDT::AddValidDataset(const Dataset* valid_data,
}
valid_score_updater_.push_back(std::move(new_score_updater));
valid_metrics_.emplace_back();
- if (early_stopping_round_ > 0) {
- best_iter_.emplace_back();
- best_score_.emplace_back();
- best_msg_.emplace_back();
- }
for (const auto& metric : valid_metrics) {
valid_metrics_.back().push_back(metric);
- if (early_stopping_round_ > 0) {
- best_iter_.back().push_back(0);
- best_score_.back().push_back(kMinScore);
- best_msg_.back().emplace_back();
- }
}
valid_metrics_.back().shrink_to_fit();
+
+ if (early_stopping_round_ > 0) {
+ auto num_metrics = valid_metrics.size();
+ if (es_first_metric_only_) { num_metrics = 1; }
+ best_iter_.emplace_back(num_metrics, 0);
+ best_score_.emplace_back(num_metrics, kMinScore);
+ best_msg_.emplace_back(num_metrics);
+ }
}
void GBDT::Boosting() {
@@ -514,6 +514,7 @@ std::string GBDT::OutputMetric(int iter) {
msg_buf << tmp_buf.str() << '\n';
}
}
+ if (es_first_metric_only_ && j > 0) { continue; }
if (ret.empty() && early_stopping_round_ > 0) {
auto cur_score = valid_metrics_[i][j]->factor_to_bigger_better() * test_scores.back();
if (cur_score > best_score_[i][j]) {
diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h
index 78784488362b..7bfbfcd8748d 100644
--- a/src/boosting/gbdt.h
+++ b/src/boosting/gbdt.h
@@ -434,6 +434,8 @@ class GBDT : public GBDTBase {
std::vector> valid_metrics_;
/*! \brief Number of rounds for early stopping */
int early_stopping_round_;
+ /*! \brief Only use first metric for early stopping */
+ bool es_first_metric_only_;
/*! \brief Best iteration(s) for early stopping */
std::vector> best_iter_;
/*! \brief Best score(s) for early stopping */
diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp
index ca0ed85b9dd0..9d5c3f853526 100644
--- a/src/io/config_auto.cpp
+++ b/src/io/config_auto.cpp
@@ -181,6 +181,7 @@ std::unordered_set Config::parameter_set({
"feature_fraction",
"feature_fraction_seed",
"early_stopping_round",
+ "first_metric_only",
"max_delta_step",
"lambda_l1",
"lambda_l2",
@@ -312,6 +313,8 @@ void Config::GetMembersFromString(const std::unordered_map