From f77e0adf59bfff6b64e2d04e84a0da3079cbe145 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 17 Mar 2022 05:03:53 +0100 Subject: [PATCH] [python] make `early_stopping` callback pickleable (#5012) * Turn `early_stopping` into a Callable class * Fix * Lint * Remove print * Fix order * Revert "Lint" This reverts commit 7ca8b557572446888cf793c0082d9a7efd1e29a7. * Apply suggestion from code review * Nit * Lint * Move callable class outside the func for pickling * Move _pickle and _unpickle to tests utils * Add early stopping callback picklability test * Nit * Fix * Lint * Improve type hint * Lint * Lint * Add cloudpickle to test_windows * Update tests/python_package_test/test_engine.py * Fix * Apply suggestions from code review --- .ci/test_windows.ps1 | 2 +- python-package/lightgbm/callback.py | 245 +++++++++++---------- tests/python_package_test/test_callback.py | 22 ++ tests/python_package_test/test_dask.py | 47 +--- tests/python_package_test/utils.py | 29 +++ 5 files changed, 184 insertions(+), 161 deletions(-) create mode 100644 tests/python_package_test/test_callback.py diff --git a/.ci/test_windows.ps1 b/.ci/test_windows.ps1 index d4c5012a1b87..2e69513a2abb 100644 --- a/.ci/test_windows.ps1 +++ b/.ci/test_windows.ps1 @@ -50,7 +50,7 @@ if ($env:TASK -eq "swig") { Exit 0 } -conda install -q -y -n $env:CONDA_ENV joblib matplotlib numpy pandas psutil pytest scikit-learn scipy ; Check-Output $? +conda install -q -y -n $env:CONDA_ENV cloudpickle joblib matplotlib numpy pandas psutil pytest scikit-learn scipy ; Check-Output $? # python-graphviz has to be installed separately to prevent conda from downgrading to pypy conda install -q -y -n $env:CONDA_ENV libxml2 python-graphviz ; Check-Output $? diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 3eee0ba499d2..2fc301b0e509 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -12,14 +12,6 @@ ] -def _gt_delta(curr_score: float, best_score: float, delta: float) -> bool: - return curr_score > best_score + delta - - -def _lt_delta(curr_score: float, best_score: float, delta: float) -> bool: - return curr_score < best_score - delta - - class EarlyStopException(Exception): """Exception of early stopping.""" @@ -199,156 +191,165 @@ def _callback(env: CallbackEnv) -> None: return _callback -def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True, min_delta: Union[float, List[float]] = 0.0) -> Callable: - """Create a callback that activates early stopping. - - Activates early stopping. - The model will train until the validation score doesn't improve by at least ``min_delta``. - Validation score needs to improve at least every ``stopping_rounds`` round(s) - to continue training. - Requires at least one validation data and one metric. - If there's more than one, will check all of them. But the training data is ignored anyway. - To check only the first metric set ``first_metric_only`` to True. - The index of iteration that has the best performance will be saved in the ``best_iteration`` attribute of a model. - - Parameters - ---------- - stopping_rounds : int - The possible number of rounds without the trend occurrence. - first_metric_only : bool, optional (default=False) - Whether to use only the first metric for early stopping. - verbose : bool, optional (default=True) - Whether to log message with early stopping information. - By default, standard output resource is used. - Use ``register_logger()`` function to register a custom logger. - min_delta : float or list of float, optional (default=0.0) - Minimum improvement in score to keep training. - If float, this single value is used for all metrics. - If list, its length should match the total number of metrics. - - Returns - ------- - callback : callable - The callback that activates early stopping. - """ - best_score = [] - best_iter = [] - best_score_list: list = [] - cmp_op = [] - enabled = True - first_metric = '' - - def _init(env: CallbackEnv) -> None: - nonlocal best_score - nonlocal best_iter - nonlocal best_score_list - nonlocal cmp_op - nonlocal enabled - nonlocal first_metric - enabled = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias - in _ConfigAliases.get("boosting")) - if not enabled: +class _EarlyStoppingCallback: + """Internal early stopping callable class.""" + + def __init__( + self, + stopping_rounds: int, + first_metric_only: bool = False, + verbose: bool = True, + min_delta: Union[float, List[float]] = 0.0 + ) -> None: + self.order = 30 + self.before_iteration = False + + self.stopping_rounds = stopping_rounds + self.first_metric_only = first_metric_only + self.verbose = verbose + self.min_delta = min_delta + + self.enabled = True + self._reset_storages() + + def _reset_storages(self) -> None: + self.best_score = [] + self.best_iter = [] + self.best_score_list = [] + self.cmp_op = [] + self.first_metric = '' + + def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool: + return curr_score > best_score + delta + + def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool: + return curr_score < best_score - delta + + def _init(self, env: CallbackEnv) -> None: + self.enabled = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias + in _ConfigAliases.get("boosting")) + if not self.enabled: _log_warning('Early stopping is not available in dart mode') return if not env.evaluation_result_list: raise ValueError('For early stopping, ' 'at least one dataset and eval metric is required for evaluation') - if stopping_rounds <= 0: + if self.stopping_rounds <= 0: raise ValueError("stopping_rounds should be greater than zero.") - if verbose: - _log_info(f"Training until validation scores don't improve for {stopping_rounds} rounds") + if self.verbose: + _log_info(f"Training until validation scores don't improve for {self.stopping_rounds} rounds") - # reset storages - best_score = [] - best_iter = [] - best_score_list = [] - cmp_op = [] - first_metric = '' + self._reset_storages() n_metrics = len(set(m[1] for m in env.evaluation_result_list)) n_datasets = len(env.evaluation_result_list) // n_metrics - if isinstance(min_delta, list): - if not all(t >= 0 for t in min_delta): + if isinstance(self.min_delta, list): + if not all(t >= 0 for t in self.min_delta): raise ValueError('Values for early stopping min_delta must be non-negative.') - if len(min_delta) == 0: - if verbose: + if len(self.min_delta) == 0: + if self.verbose: _log_info('Disabling min_delta for early stopping.') deltas = [0.0] * n_datasets * n_metrics - elif len(min_delta) == 1: - if verbose: - _log_info(f'Using {min_delta[0]} as min_delta for all metrics.') - deltas = min_delta * n_datasets * n_metrics + elif len(self.min_delta) == 1: + if self.verbose: + _log_info(f'Using {self.min_delta[0]} as min_delta for all metrics.') + deltas = self.min_delta * n_datasets * n_metrics else: - if len(min_delta) != n_metrics: + if len(self.min_delta) != n_metrics: raise ValueError('Must provide a single value for min_delta or as many as metrics.') - if first_metric_only and verbose: - _log_info(f'Using only {min_delta[0]} as early stopping min_delta.') - deltas = min_delta * n_datasets + if self.first_metric_only and self.verbose: + _log_info(f'Using only {self.min_delta[0]} as early stopping min_delta.') + deltas = self.min_delta * n_datasets else: - if min_delta < 0: + if self.min_delta < 0: raise ValueError('Early stopping min_delta must be non-negative.') - if min_delta > 0 and n_metrics > 1 and not first_metric_only and verbose: - _log_info(f'Using {min_delta} as min_delta for all metrics.') - deltas = [min_delta] * n_datasets * n_metrics + if self.min_delta > 0 and n_metrics > 1 and not self.first_metric_only and self.verbose: + _log_info(f'Using {self.min_delta} as min_delta for all metrics.') + deltas = [self.min_delta] * n_datasets * n_metrics # split is needed for " " case (e.g. "train l1") - first_metric = env.evaluation_result_list[0][1].split(" ")[-1] + self.first_metric = env.evaluation_result_list[0][1].split(" ")[-1] for eval_ret, delta in zip(env.evaluation_result_list, deltas): - best_iter.append(0) - best_score_list.append(None) + self.best_iter.append(0) + self.best_score_list.append(None) if eval_ret[3]: # greater is better - best_score.append(float('-inf')) - cmp_op.append(partial(_gt_delta, delta=delta)) + self.best_score.append(float('-inf')) + self.cmp_op.append(partial(self._gt_delta, delta=delta)) else: - best_score.append(float('inf')) - cmp_op.append(partial(_lt_delta, delta=delta)) + self.best_score.append(float('inf')) + self.cmp_op.append(partial(self._lt_delta, delta=delta)) - def _final_iteration_check(env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None: - nonlocal best_iter - nonlocal best_score_list + def _final_iteration_check(self, env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None: if env.iteration == env.end_iteration - 1: - if verbose: - best_score_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]]) + if self.verbose: + best_score_str = '\t'.join([_format_eval_result(x) for x in self.best_score_list[i]]) _log_info('Did not meet early stopping. ' - f'Best iteration is:\n[{best_iter[i] + 1}]\t{best_score_str}') - if first_metric_only: + f'Best iteration is:\n[{self.best_iter[i] + 1}]\t{best_score_str}') + if self.first_metric_only: _log_info(f"Evaluated only: {eval_name_splitted[-1]}") - raise EarlyStopException(best_iter[i], best_score_list[i]) + raise EarlyStopException(self.best_iter[i], self.best_score_list[i]) - def _callback(env: CallbackEnv) -> None: - nonlocal best_score - nonlocal best_iter - nonlocal best_score_list - nonlocal cmp_op - nonlocal enabled - nonlocal first_metric + def __call__(self, env: CallbackEnv) -> None: if env.iteration == env.begin_iteration: - _init(env) - if not enabled: + self._init(env) + if not self.enabled: return for i in range(len(env.evaluation_result_list)): score = env.evaluation_result_list[i][2] - if best_score_list[i] is None or cmp_op[i](score, best_score[i]): - best_score[i] = score - best_iter[i] = env.iteration - best_score_list[i] = env.evaluation_result_list + if self.best_score_list[i] is None or self.cmp_op[i](score, self.best_score[i]): + self.best_score[i] = score + self.best_iter[i] = env.iteration + self.best_score_list[i] = env.evaluation_result_list # split is needed for " " case (e.g. "train l1") eval_name_splitted = env.evaluation_result_list[i][1].split(" ") - if first_metric_only and first_metric != eval_name_splitted[-1]: + if self.first_metric_only and self.first_metric != eval_name_splitted[-1]: continue # use only the first metric for early stopping if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train" - or env.evaluation_result_list[i][0] == env.model._train_data_name)): - _final_iteration_check(env, eval_name_splitted, i) + or env.evaluation_result_list[i][0] == env.model._train_data_name)): + self._final_iteration_check(env, eval_name_splitted, i) continue # train data for lgb.cv or sklearn wrapper (underlying lgb.train) - elif env.iteration - best_iter[i] >= stopping_rounds: - if verbose: - eval_result_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]]) - _log_info(f"Early stopping, best iteration is:\n[{best_iter[i] + 1}]\t{eval_result_str}") - if first_metric_only: + elif env.iteration - self.best_iter[i] >= self.stopping_rounds: + if self.verbose: + eval_result_str = '\t'.join([_format_eval_result(x) for x in self.best_score_list[i]]) + _log_info(f"Early stopping, best iteration is:\n[{self.best_iter[i] + 1}]\t{eval_result_str}") + if self.first_metric_only: _log_info(f"Evaluated only: {eval_name_splitted[-1]}") - raise EarlyStopException(best_iter[i], best_score_list[i]) - _final_iteration_check(env, eval_name_splitted, i) - _callback.order = 30 # type: ignore - return _callback + raise EarlyStopException(self.best_iter[i], self.best_score_list[i]) + self._final_iteration_check(env, eval_name_splitted, i) + + +def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True, min_delta: Union[float, List[float]] = 0.0) -> _EarlyStoppingCallback: + """Create a callback that activates early stopping. + + Activates early stopping. + The model will train until the validation score doesn't improve by at least ``min_delta``. + Validation score needs to improve at least every ``stopping_rounds`` round(s) + to continue training. + Requires at least one validation data and one metric. + If there's more than one, will check all of them. But the training data is ignored anyway. + To check only the first metric set ``first_metric_only`` to True. + The index of iteration that has the best performance will be saved in the ``best_iteration`` attribute of a model. + + Parameters + ---------- + stopping_rounds : int + The possible number of rounds without the trend occurrence. + first_metric_only : bool, optional (default=False) + Whether to use only the first metric for early stopping. + verbose : bool, optional (default=True) + Whether to log message with early stopping information. + By default, standard output resource is used. + Use ``register_logger()`` function to register a custom logger. + min_delta : float or list of float, optional (default=0.0) + Minimum improvement in score to keep training. + If float, this single value is used for all metrics. + If list, its length should match the total number of metrics. + + Returns + ------- + callback : _EarlyStoppingCallback + The callback that activates early stopping. + """ + return _EarlyStoppingCallback(stopping_rounds=stopping_rounds, first_metric_only=first_metric_only, verbose=verbose, min_delta=min_delta) diff --git a/tests/python_package_test/test_callback.py b/tests/python_package_test/test_callback.py new file mode 100644 index 000000000000..0f339aa3a53e --- /dev/null +++ b/tests/python_package_test/test_callback.py @@ -0,0 +1,22 @@ +# coding: utf-8 +import pytest + +import lightgbm as lgb + +from .utils import pickle_obj, unpickle_obj + + +@pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"]) +def test_early_stopping_callback_is_picklable(serializer, tmp_path): + callback = lgb.early_stopping(stopping_rounds=5) + tmp_file = tmp_path / "early_stopping.pkl" + pickle_obj( + obj=callback, + filepath=tmp_file, + serializer=serializer + ) + callback_from_disk = unpickle_obj( + filepath=tmp_file, + serializer=serializer + ) + assert callback.stopping_rounds == callback_from_disk.stopping_rounds diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index d005f950f07a..b56c206ccf60 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -2,7 +2,6 @@ """Tests for lightgbm.dask module""" import inspect -import pickle import random import socket from itertools import groupby @@ -24,10 +23,8 @@ if not lgb.compat.DASK_INSTALLED: pytest.skip('Dask is not installed', allow_module_level=True) -import cloudpickle import dask.array as da import dask.dataframe as dd -import joblib import numpy as np import pandas as pd import sklearn.utils.estimator_checks as sklearn_checks @@ -37,7 +34,7 @@ from scipy.stats import spearmanr from sklearn.datasets import make_blobs, make_regression -from .utils import make_ranking +from .utils import make_ranking, pickle_obj, unpickle_obj tasks = ['binary-classification', 'multiclass-classification', 'regression', 'ranking'] distributed_training_algorithms = ['data', 'voting'] @@ -234,32 +231,6 @@ def _constant_metric(y_true, y_pred): return metric_name, value, is_higher_better -def _pickle(obj, filepath, serializer): - if serializer == 'pickle': - with open(filepath, 'wb') as f: - pickle.dump(obj, f) - elif serializer == 'joblib': - joblib.dump(obj, filepath) - elif serializer == 'cloudpickle': - with open(filepath, 'wb') as f: - cloudpickle.dump(obj, f) - else: - raise ValueError(f'Unrecognized serializer type: {serializer}') - - -def _unpickle(filepath, serializer): - if serializer == 'pickle': - with open(filepath, 'rb') as f: - return pickle.load(f) - elif serializer == 'joblib': - return joblib.load(filepath) - elif serializer == 'cloudpickle': - with open(filepath, 'rb') as f: - return cloudpickle.load(f) - else: - raise ValueError(f'Unrecognized serializer type: {serializer}') - - def _objective_least_squares(y_true, y_pred): grad = y_pred - y_true hess = np.ones(len(y_true)) @@ -1341,23 +1312,23 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici assert getattr(local_model, "client", None) is None tmp_file = tmp_path / "model-1.pkl" - _pickle( + pickle_obj( obj=dask_model, filepath=tmp_file, serializer=serializer ) - model_from_disk = _unpickle( + model_from_disk = unpickle_obj( filepath=tmp_file, serializer=serializer ) local_tmp_file = tmp_path / "local-model-1.pkl" - _pickle( + pickle_obj( obj=local_model, filepath=local_tmp_file, serializer=serializer ) - local_model_from_disk = _unpickle( + local_model_from_disk = unpickle_obj( filepath=local_tmp_file, serializer=serializer ) @@ -1397,23 +1368,23 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici local_model.client_ tmp_file2 = tmp_path / "model-2.pkl" - _pickle( + pickle_obj( obj=dask_model, filepath=tmp_file2, serializer=serializer ) - fitted_model_from_disk = _unpickle( + fitted_model_from_disk = unpickle_obj( filepath=tmp_file2, serializer=serializer ) local_tmp_file2 = tmp_path / "local-model-2.pkl" - _pickle( + pickle_obj( obj=local_model, filepath=local_tmp_file2, serializer=serializer ) - local_fitted_model_from_disk = _unpickle( + local_fitted_model_from_disk = unpickle_obj( filepath=local_tmp_file2, serializer=serializer ) diff --git a/tests/python_package_test/utils.py b/tests/python_package_test/utils.py index f0a9ada31ffd..63950d471608 100644 --- a/tests/python_package_test/utils.py +++ b/tests/python_package_test/utils.py @@ -1,6 +1,9 @@ # coding: utf-8 +import pickle from functools import lru_cache +import cloudpickle +import joblib import numpy as np import sklearn.datasets from sklearn.utils import check_random_state @@ -131,3 +134,29 @@ def sklearn_multiclass_custom_objective(y_true, y_pred): factor = num_class / (num_class - 1) hess = factor * prob * (1 - prob) return grad, hess + + +def pickle_obj(obj, filepath, serializer): + if serializer == 'pickle': + with open(filepath, 'wb') as f: + pickle.dump(obj, f) + elif serializer == 'joblib': + joblib.dump(obj, filepath) + elif serializer == 'cloudpickle': + with open(filepath, 'wb') as f: + cloudpickle.dump(obj, f) + else: + raise ValueError(f'Unrecognized serializer type: {serializer}') + + +def unpickle_obj(filepath, serializer): + if serializer == 'pickle': + with open(filepath, 'rb') as f: + return pickle.load(f) + elif serializer == 'joblib': + return joblib.load(filepath) + elif serializer == 'cloudpickle': + with open(filepath, 'rb') as f: + return cloudpickle.load(f) + else: + raise ValueError(f'Unrecognized serializer type: {serializer}')