diff --git a/docs/Python-Intro.rst b/docs/Python-Intro.rst index 72ef7221c8ce..c9672c5d0169 100644 --- a/docs/Python-Intro.rst +++ b/docs/Python-Intro.rst @@ -200,6 +200,7 @@ Note that ``train()`` will return a model from the best iteration. This works with both metrics to minimize (L2, log loss, etc.) and to maximize (NDCG, AUC, etc.). Note that if you specify more than one evaluation metric, all of them will be used for early stopping. +However, you can change this behavior and make LightGBM check only the first metric for early stopping by creating ``early_stopping`` callback with ``first_metric_only=True``. Prediction ---------- diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 6e5e8c5eb70e..0af06540dd30 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -150,7 +150,7 @@ def _callback(env): return _callback -def early_stopping(stopping_rounds, verbose=True): +def early_stopping(stopping_rounds, first_metric_only=False, verbose=True): """Create a callback that activates early stopping. Note @@ -161,11 +161,14 @@ def early_stopping(stopping_rounds, verbose=True): to continue training. Requires at least one validation data and one metric. If there's more than one, will check all of them. But the training data is ignored anyway. + To check only the first metric set ``first_metric_only`` to True. Parameters ---------- stopping_rounds : int The possible number of rounds without the trend occurrence. + first_metric_only : bool, optional (default=False) + Whether to use only the first metric for early stopping. verbose : bool, optional (default=True) Whether to print message with early stopping information. @@ -227,5 +230,7 @@ def _callback(env): print('Did not meet early stopping. Best iteration is:\n[%d]\t%s' % ( best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]]))) raise EarlyStopException(best_iter[i], best_score_list[i]) + if first_metric_only: # the only first metric is used for early stopping + break _callback.order = 30 return _callback diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py index b0d6002ba715..9792d1c9d827 100644 --- a/python-package/lightgbm/engine.py +++ b/python-package/lightgbm/engine.py @@ -66,6 +66,8 @@ def train(params, train_set, num_boost_round=100, to continue training. Requires at least one validation data and one metric. If there's more than one, will check all of them. But the training data is ignored anyway. + To check only the first metric you can pass in ``callbacks`` + ``early_stopping`` callback with ``first_metric_only=True``. The index of iteration that has the best performance will be saved in the ``best_iteration`` field if early stopping logic is enabled by setting ``early_stopping_rounds``. evals_result: dict or None, optional (default=None) @@ -391,6 +393,8 @@ def cv(params, train_set, num_boost_round=100, CV score needs to improve at least every ``early_stopping_rounds`` round(s) to continue. Requires at least one metric. If there's more than one, will check all of them. + To check only the first metric you can pass in ``callbacks`` + ``early_stopping`` callback with ``first_metric_only=True``. Last entry in evaluation history is the one from the best iteration. fpreproc : callable or None, optional (default=None) Preprocessing function that takes (dtrain, dtest, params) diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index f9b1697e256d..887eedf3c3cb 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -369,6 +369,8 @@ def fit(self, X, y, to continue training. Requires at least one validation data and one metric. If there's more than one, will check all of them. But the training data is ignored anyway. + To check only the first metric you can pass in ``callbacks`` + ``early_stopping`` callback with ``first_metric_only=True``. verbose : bool or int, optional (default=True) Requires at least one evaluation data. If True, the eval metric on the eval set is printed at each boosting stage. diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index c5c6e0ef7f0b..86cb403bfd6a 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -1,6 +1,7 @@ # coding: utf-8 # pylint: skip-file import copy +import itertools import math import os import psutil @@ -1318,3 +1319,45 @@ def test_get_split_value_histogram(self): np.testing.assert_almost_equal(bin_edges[1:][mask], hist[:, 0]) # test histogram is disabled for categorical features self.assertRaises(lgb.basic.LightGBMError, gbm.get_split_value_histogram, 2) + + def test_early_stopping_for_only_first_metric(self): + X, y = load_boston(True) + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42) + params = { + 'objective': 'regression', + 'metric': 'None', + 'verbose': -1 + } + lgb_train = lgb.Dataset(X_train, y_train) + lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train) + + decreasing_generator = itertools.count(0, -1) + + def decreasing_metric(preds, train_data): + return ('decreasing_metric', next(decreasing_generator), False) + + def constant_metric(preds, train_data): + return ('constant_metric', 0.0, False) + + # test that all metrics are checked (default behaviour) + early_stop_callback = lgb.early_stopping(5, verbose=False) + gbm = lgb.train(params, lgb_train, num_boost_round=20, valid_sets=[lgb_eval], + feval=lambda preds, train_data: [decreasing_metric(preds, train_data), + constant_metric(preds, train_data)], + callbacks=[early_stop_callback]) + self.assertEqual(gbm.best_iteration, 1) + + # test that only the first metric is checked + early_stop_callback = lgb.early_stopping(5, first_metric_only=True, verbose=False) + gbm = lgb.train(params, lgb_train, num_boost_round=20, valid_sets=[lgb_eval], + feval=lambda preds, train_data: [decreasing_metric(preds, train_data), + constant_metric(preds, train_data)], + callbacks=[early_stop_callback]) + self.assertEqual(gbm.best_iteration, 20) + # ... change the order of metrics + early_stop_callback = lgb.early_stopping(5, first_metric_only=True, verbose=False) + gbm = lgb.train(params, lgb_train, num_boost_round=20, valid_sets=[lgb_eval], + feval=lambda preds, train_data: [constant_metric(preds, train_data), + decreasing_metric(preds, train_data)], + callbacks=[early_stop_callback]) + self.assertEqual(gbm.best_iteration, 1)