Skip to content

Commit

Permalink
[python-package] support saving and loading CVBooster (fixes #3556) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
nyanp authored Aug 16, 2022
1 parent 6b695c2 commit 4a9b08e
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 23 deletions.
123 changes: 119 additions & 4 deletions python-package/lightgbm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""Library with training routines of LightGBM."""
import collections
import copy
import json
from operator import attrgetter
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
Expand Down Expand Up @@ -271,9 +272,14 @@ def train(
class CVBooster:
"""CVBooster in LightGBM.
Auxiliary data structure to hold and redirect all boosters of ``cv`` function.
Auxiliary data structure to hold and redirect all boosters of ``cv()`` function.
This class has the same methods as Booster class.
All method calls are actually performed for underlying Boosters and then all returned results are returned in a list.
All method calls, except for the following methods, are actually performed for underlying Boosters and
then all returned results are returned in a list.
- ``model_from_string()``
- ``model_to_string()``
- ``save_model()``
Attributes
----------
Expand All @@ -283,18 +289,43 @@ class CVBooster:
The best iteration of fitted model.
"""

def __init__(self):
def __init__(
self,
model_file: Optional[Union[str, Path]] = None
):
"""Initialize the CVBooster.
Generally, no need to instantiate manually.
Parameters
----------
model_file : str, pathlib.Path or None, optional (default=None)
Path to the CVBooster model file.
"""
self.boosters = []
self.best_iteration = -1

if model_file is not None:
with open(model_file, "r") as file:
self._from_dict(json.load(file))

def _append(self, booster: Booster) -> None:
"""Add a booster to CVBooster."""
self.boosters.append(booster)

def _from_dict(self, models: Dict[str, Any]) -> None:
"""Load CVBooster from dict."""
self.best_iteration = models["best_iteration"]
self.boosters = []
for model_str in models["boosters"]:
self._append(Booster(model_str=model_str))

def _to_dict(self, num_iteration: Optional[int], start_iteration: int, importance_type: str) -> Dict[str, Any]:
"""Serialize CVBooster to dict."""
models_str = []
for booster in self.boosters:
models_str.append(booster.model_to_string(num_iteration=num_iteration, start_iteration=start_iteration,
importance_type=importance_type))
return {"boosters": models_str, "best_iteration": self.best_iteration}

def __getattr__(self, name: str) -> Callable[[Any, Any], List[Any]]:
"""Redirect methods call of CVBooster."""
def handler_function(*args: Any, **kwargs: Any) -> List[Any]:
Expand All @@ -305,6 +336,90 @@ def handler_function(*args: Any, **kwargs: Any) -> List[Any]:
return ret
return handler_function

def __getstate__(self) -> Dict[str, Any]:
return vars(self)

def __setstate__(self, state: Dict[str, Any]) -> None:
vars(self).update(state)

def model_from_string(self, model_str: str) -> "CVBooster":
"""Load CVBooster from a string.
Parameters
----------
model_str : str
Model will be loaded from this string.
Returns
-------
self : CVBooster
Loaded CVBooster object.
"""
self._from_dict(json.loads(model_str))
return self

def model_to_string(
self,
num_iteration: Optional[int] = None,
start_iteration: int = 0,
importance_type: str = 'split'
) -> str:
"""Save CVBooster to JSON string.
Parameters
----------
num_iteration : int or None, optional (default=None)
Index of the iteration that should be saved.
If None, if the best iteration exists, it is saved; otherwise, all iterations are saved.
If <= 0, all iterations are saved.
start_iteration : int, optional (default=0)
Start index of the iteration that should be saved.
importance_type : str, optional (default="split")
What type of feature importance should be saved.
If "split", result contains numbers of times the feature is used in a model.
If "gain", result contains total gains of splits which use the feature.
Returns
-------
str_repr : str
JSON string representation of CVBooster.
"""
return json.dumps(self._to_dict(num_iteration, start_iteration, importance_type))

def save_model(
self,
filename: Union[str, Path],
num_iteration: Optional[int] = None,
start_iteration: int = 0,
importance_type: str = 'split'
) -> "CVBooster":
"""Save CVBooster to a file as JSON text.
Parameters
----------
filename : str or pathlib.Path
Filename to save CVBooster.
num_iteration : int or None, optional (default=None)
Index of the iteration that should be saved.
If None, if the best iteration exists, it is saved; otherwise, all iterations are saved.
If <= 0, all iterations are saved.
start_iteration : int, optional (default=0)
Start index of the iteration that should be saved.
importance_type : str, optional (default="split")
What type of feature importance should be saved.
If "split", result contains numbers of times the feature is used in a model.
If "gain", result contains total gains of splits which use the feature.
Returns
-------
self : CVBooster
Returns self.
"""
with open(filename, "w") as file:
json.dump(self._to_dict(num_iteration, start_iteration, importance_type), file)

return self


def _make_n_folds(
full_data: Dataset,
Expand Down
18 changes: 1 addition & 17 deletions tests/python_package_test/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,7 @@

import lightgbm as lgb

from .utils import pickle_obj, unpickle_obj

SERIALIZERS = ["pickle", "joblib", "cloudpickle"]


def pickle_and_unpickle_object(obj, serializer):
with lgb.basic._TempFile() as tmp_file:
pickle_obj(
obj=obj,
filepath=tmp_file.name,
serializer=serializer
)
obj_from_disk = unpickle_obj(
filepath=tmp_file.name,
serializer=serializer
)
return obj_from_disk
from .utils import SERIALIZERS, pickle_and_unpickle_object, pickle_obj, unpickle_obj


def reset_feature_fraction(boosting_round):
Expand Down
68 changes: 66 additions & 2 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
import lightgbm as lgb
from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame

from .utils import (dummy_obj, load_boston, load_breast_cancer, load_digits, load_iris, logistic_sigmoid,
make_synthetic_regression, mse_obj, sklearn_multiclass_custom_objective, softmax)
from .utils import (SERIALIZERS, dummy_obj, load_boston, load_breast_cancer, load_digits, load_iris, logistic_sigmoid,
make_synthetic_regression, mse_obj, pickle_and_unpickle_object, sklearn_multiclass_custom_objective,
softmax)

decreasing_generator = itertools.count(0, -1)

Expand Down Expand Up @@ -1073,6 +1074,69 @@ def test_cvbooster():
assert ret < 0.15


def test_cvbooster_save_load(tmp_path):
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, _ = train_test_split(X, y, test_size=0.1, random_state=42)
params = {
'objective': 'binary',
'metric': 'binary_logloss',
'verbose': -1,
}
nfold = 3
lgb_train = lgb.Dataset(X_train, y_train)

cv_res = lgb.cv(params, lgb_train,
num_boost_round=10,
nfold=nfold,
callbacks=[lgb.early_stopping(stopping_rounds=5)],
return_cvbooster=True)
cvbooster = cv_res['cvbooster']
preds = cvbooster.predict(X_test)
best_iteration = cvbooster.best_iteration

model_path_txt = str(tmp_path / 'lgb.model')

cvbooster.save_model(model_path_txt)
model_string = cvbooster.model_to_string()
del cvbooster

cvbooster_from_txt_file = lgb.CVBooster(model_file=model_path_txt)
cvbooster_from_string = lgb.CVBooster().model_from_string(model_string)
for cvbooster_loaded in [cvbooster_from_txt_file, cvbooster_from_string]:
assert best_iteration == cvbooster_loaded.best_iteration
np.testing.assert_array_equal(preds, cvbooster_loaded.predict(X_test))


@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_cvbooster_picklable(serializer):
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, _ = train_test_split(X, y, test_size=0.1, random_state=42)
params = {
'objective': 'binary',
'metric': 'binary_logloss',
'verbose': -1,
}
nfold = 3
lgb_train = lgb.Dataset(X_train, y_train)

cv_res = lgb.cv(params, lgb_train,
num_boost_round=10,
nfold=nfold,
callbacks=[lgb.early_stopping(stopping_rounds=5)],
return_cvbooster=True)
cvbooster = cv_res['cvbooster']
preds = cvbooster.predict(X_test)
best_iteration = cvbooster.best_iteration

cvbooster_from_disk = pickle_and_unpickle_object(obj=cvbooster, serializer=serializer)
del cvbooster

assert best_iteration == cvbooster_from_disk.best_iteration

preds_from_disk = cvbooster_from_disk.predict(X_test)
np.testing.assert_array_equal(preds, preds_from_disk)


def test_feature_name():
X_train, y_train = make_synthetic_regression()
params = {'verbose': -1}
Expand Down
18 changes: 18 additions & 0 deletions tests/python_package_test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
import sklearn.datasets
from sklearn.utils import check_random_state

import lightgbm as lgb

SERIALIZERS = ["pickle", "joblib", "cloudpickle"]


@lru_cache(maxsize=None)
def load_boston(**kwargs):
Expand Down Expand Up @@ -179,3 +183,17 @@ def unpickle_obj(filepath, serializer):
return cloudpickle.load(f)
else:
raise ValueError(f'Unrecognized serializer type: {serializer}')


def pickle_and_unpickle_object(obj, serializer):
with lgb.basic._TempFile() as tmp_file:
pickle_obj(
obj=obj,
filepath=tmp_file.name,
serializer=serializer
)
obj_from_disk = unpickle_obj(
filepath=tmp_file.name,
serializer=serializer
)
return obj_from_disk

0 comments on commit 4a9b08e

Please sign in to comment.