diff --git a/.ci/test.sh b/.ci/test.sh index 416eec422cd7..80b2bea93f70 100755 --- a/.ci/test.sh +++ b/.ci/test.sh @@ -95,7 +95,7 @@ if [[ $TASK == "swig" ]]; then exit 0 fi -conda install -q -y -n $CONDA_ENV dask distributed joblib matplotlib numpy pandas psutil pytest scikit-learn scipy +conda install -q -y -n $CONDA_ENV cloudpickle dask distributed joblib matplotlib numpy pandas psutil pytest scikit-learn scipy # graphviz must come from conda-forge to avoid this on some linux distros: # https://github.com/conda-forge/graphviz-feedstock/issues/18 diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index c5f4049b0d5f..d8945fa5fa38 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -9,7 +9,7 @@ import socket from collections import defaultdict from copy import deepcopy -from typing import Any, Dict, Iterable, List, Optional, Type, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union from urllib.parse import urlparse import numpy as np @@ -17,7 +17,7 @@ from .basic import _choose_param_value, _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError from .compat import (PANDAS_INSTALLED, pd_DataFrame, pd_Series, concat, - SKLEARN_INSTALLED, + SKLEARN_INSTALLED, LGBMNotFittedError, DASK_INSTALLED, dask_DataFrame, dask_Array, dask_Series, delayed, Client, default_client, get_worker, wait) from .sklearn import LGBMClassifier, LGBMModel, LGBMRegressor, LGBMRanker @@ -27,6 +27,25 @@ _PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]] +def _get_dask_client(client: Optional[Client]) -> Client: + """Choose a Dask client to use. + + Parameters + ---------- + client : dask.distributed.Client or None + Dask client. + + Returns + ------- + client : dask.distributed.Client + A Dask client. + """ + if client is None: + return default_client() + else: + return client + + def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Iterable[int]) -> int: """Find an open port. @@ -434,6 +453,29 @@ def _predict( class _DaskLGBMModel: + @property + def client_(self) -> Client: + """Dask client. + + This property can be passed in the constructor or updated + with ``model.set_params(client=client)``. + """ + if not getattr(self, "fitted_", False): + raise LGBMNotFittedError('Cannot access property client_ before calling fit().') + + return _get_dask_client(client=self.client) + + def _lgb_getstate(self) -> Dict[Any, Any]: + """Remove un-picklable attributes before serialization.""" + client = self.__dict__.pop("client", None) + self.__dict__.pop("_client", None) + self._other_params.pop("client", None) + out = deepcopy(self.__dict__) + out.update({"_client": None, "client": None}) + self._client = client + self.client = client + return out + def _fit( self, model_factory: Type[LGBMModel], @@ -441,18 +483,16 @@ def _fit( y: _DaskCollection, sample_weight: Optional[_DaskCollection] = None, group: Optional[_DaskCollection] = None, - client: Optional[Client] = None, **kwargs: Any ) -> "_DaskLGBMModel": if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)): raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask') - if client is None: - client = default_client() params = self.get_params(True) + params.pop("client", None) model = _train( - client=client, + client=_get_dask_client(self.client), data=X, label=y, params=params, @@ -468,8 +508,11 @@ def _fit( return self def _to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel: - model = model_factory(**self.get_params()) + params = self.get_params() + params.pop("client", None) + model = model_factory(**params) self._copy_extra_params(self, model) + model._other_params.pop("client", None) return model @staticmethod @@ -478,18 +521,82 @@ def _copy_extra_params(source: Union["_DaskLGBMModel", LGBMModel], dest: Union[" attributes = source.__dict__ extra_param_names = set(attributes.keys()).difference(params.keys()) for name in extra_param_names: - setattr(dest, name, attributes[name]) + if name != "_client": + setattr(dest, name, attributes[name]) class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): """Distributed version of lightgbm.LGBMClassifier.""" + def __init__( + self, + boosting_type: str = 'gbdt', + num_leaves: int = 31, + max_depth: int = -1, + learning_rate: float = 0.1, + n_estimators: int = 100, + subsample_for_bin: int = 200000, + objective: Optional[Union[Callable, str]] = None, + class_weight: Optional[Union[dict, str]] = None, + min_split_gain: float = 0., + min_child_weight: float = 1e-3, + min_child_samples: int = 20, + subsample: float = 1., + subsample_freq: int = 0, + colsample_bytree: float = 1., + reg_alpha: float = 0., + reg_lambda: float = 0., + random_state: Optional[Union[int, np.random.RandomState]] = None, + n_jobs: int = -1, + silent: bool = True, + importance_type: str = 'split', + client: Optional[Client] = None, + **kwargs: Any + ): + """Docstring is inherited from the lightgbm.LGBMClassifier.__init__.""" + self._client = client + self.client = client + super().__init__( + boosting_type=boosting_type, + num_leaves=num_leaves, + max_depth=max_depth, + learning_rate=learning_rate, + n_estimators=n_estimators, + subsample_for_bin=subsample_for_bin, + objective=objective, + class_weight=class_weight, + min_split_gain=min_split_gain, + min_child_weight=min_child_weight, + min_child_samples=min_child_samples, + subsample=subsample, + subsample_freq=subsample_freq, + colsample_bytree=colsample_bytree, + reg_alpha=reg_alpha, + reg_lambda=reg_lambda, + random_state=random_state, + n_jobs=n_jobs, + silent=silent, + importance_type=importance_type, + **kwargs + ) + + _base_doc = LGBMClassifier.__init__.__doc__ + _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') + __init__.__doc__ = ( + _before_kwargs + + 'client : dask.distributed.Client or None, optional (default=None)\n' + + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.\n' + + ' ' * 8 + _kwargs + _after_kwargs + ) + + def __getstate__(self) -> Dict[Any, Any]: + return self._lgb_getstate() + def fit( self, X: _DaskMatrixLike, y: _DaskCollection, sample_weight: Optional[_DaskCollection] = None, - client: Optional[Client] = None, **kwargs: Any ) -> "DaskLGBMClassifier": """Docstring is inherited from the lightgbm.LGBMClassifier.fit.""" @@ -498,16 +605,10 @@ def fit( X=X, y=y, sample_weight=sample_weight, - client=client, **kwargs ) - _base_doc = LGBMClassifier.fit.__doc__ - _before_init_score, _init_score, _after_init_score = _base_doc.partition('init_score :') - fit.__doc__ = (_before_init_score - + 'client : dask.distributed.Client or None, optional (default=None)\n' - + ' ' * 12 + 'Dask client.\n' - + ' ' * 8 + _init_score + _after_init_score) + fit.__doc__ = LGBMClassifier.fit.__doc__ def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array: """Docstring is inherited from the lightgbm.LGBMClassifier.predict.""" @@ -545,6 +646,70 @@ def to_local(self) -> LGBMClassifier: class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): """Distributed version of lightgbm.LGBMRegressor.""" + def __init__( + self, + boosting_type: str = 'gbdt', + num_leaves: int = 31, + max_depth: int = -1, + learning_rate: float = 0.1, + n_estimators: int = 100, + subsample_for_bin: int = 200000, + objective: Optional[Union[Callable, str]] = None, + class_weight: Optional[Union[dict, str]] = None, + min_split_gain: float = 0., + min_child_weight: float = 1e-3, + min_child_samples: int = 20, + subsample: float = 1., + subsample_freq: int = 0, + colsample_bytree: float = 1., + reg_alpha: float = 0., + reg_lambda: float = 0., + random_state: Optional[Union[int, np.random.RandomState]] = None, + n_jobs: int = -1, + silent: bool = True, + importance_type: str = 'split', + client: Optional[Client] = None, + **kwargs: Any + ): + """Docstring is inherited from the lightgbm.LGBMRegressor.__init__.""" + self._client = client + self.client = client + super().__init__( + boosting_type=boosting_type, + num_leaves=num_leaves, + max_depth=max_depth, + learning_rate=learning_rate, + n_estimators=n_estimators, + subsample_for_bin=subsample_for_bin, + objective=objective, + class_weight=class_weight, + min_split_gain=min_split_gain, + min_child_weight=min_child_weight, + min_child_samples=min_child_samples, + subsample=subsample, + subsample_freq=subsample_freq, + colsample_bytree=colsample_bytree, + reg_alpha=reg_alpha, + reg_lambda=reg_lambda, + random_state=random_state, + n_jobs=n_jobs, + silent=silent, + importance_type=importance_type, + **kwargs + ) + + _base_doc = LGBMRegressor.__init__.__doc__ + _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') + __init__.__doc__ = ( + _before_kwargs + + 'client : dask.distributed.Client or None, optional (default=None)\n' + + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.\n' + + ' ' * 8 + _kwargs + _after_kwargs + ) + + def __getstate__(self) -> Dict[Any, Any]: + return self._lgb_getstate() + def fit( self, X: _DaskMatrixLike, @@ -559,16 +724,10 @@ def fit( X=X, y=y, sample_weight=sample_weight, - client=client, **kwargs ) - _base_doc = LGBMRegressor.fit.__doc__ - _before_init_score, _init_score, _after_init_score = _base_doc.partition('init_score :') - fit.__doc__ = (_before_init_score - + 'client : dask.distributed.Client or None, optional (default=None)\n' - + ' ' * 12 + 'Dask client.\n' - + ' ' * 8 + _init_score + _after_init_score) + fit.__doc__ = LGBMRegressor.fit.__doc__ def predict(self, X: _DaskMatrixLike, **kwargs) -> dask_Array: """Docstring is inherited from the lightgbm.LGBMRegressor.predict.""" @@ -594,6 +753,70 @@ def to_local(self) -> LGBMRegressor: class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): """Distributed version of lightgbm.LGBMRanker.""" + def __init__( + self, + boosting_type: str = 'gbdt', + num_leaves: int = 31, + max_depth: int = -1, + learning_rate: float = 0.1, + n_estimators: int = 100, + subsample_for_bin: int = 200000, + objective: Optional[Union[Callable, str]] = None, + class_weight: Optional[Union[dict, str]] = None, + min_split_gain: float = 0., + min_child_weight: float = 1e-3, + min_child_samples: int = 20, + subsample: float = 1., + subsample_freq: int = 0, + colsample_bytree: float = 1., + reg_alpha: float = 0., + reg_lambda: float = 0., + random_state: Optional[Union[int, np.random.RandomState]] = None, + n_jobs: int = -1, + silent: bool = True, + importance_type: str = 'split', + client: Optional[Client] = None, + **kwargs: Any + ): + """Docstring is inherited from the lightgbm.LGBMRanker.__init__.""" + self._client = client + self.client = client + super().__init__( + boosting_type=boosting_type, + num_leaves=num_leaves, + max_depth=max_depth, + learning_rate=learning_rate, + n_estimators=n_estimators, + subsample_for_bin=subsample_for_bin, + objective=objective, + class_weight=class_weight, + min_split_gain=min_split_gain, + min_child_weight=min_child_weight, + min_child_samples=min_child_samples, + subsample=subsample, + subsample_freq=subsample_freq, + colsample_bytree=colsample_bytree, + reg_alpha=reg_alpha, + reg_lambda=reg_lambda, + random_state=random_state, + n_jobs=n_jobs, + silent=silent, + importance_type=importance_type, + **kwargs + ) + + _base_doc = LGBMRanker.__init__.__doc__ + _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') + __init__.__doc__ = ( + _before_kwargs + + 'client : dask.distributed.Client or None, optional (default=None)\n' + + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.\n' + + ' ' * 8 + _kwargs + _after_kwargs + ) + + def __getstate__(self) -> Dict[Any, Any]: + return self._lgb_getstate() + def fit( self, X: _DaskMatrixLike, @@ -601,7 +824,6 @@ def fit( sample_weight: Optional[_DaskCollection] = None, init_score: Optional[_DaskCollection] = None, group: Optional[_DaskCollection] = None, - client: Optional[Client] = None, **kwargs: Any ) -> "DaskLGBMRanker": """Docstring is inherited from the lightgbm.LGBMRanker.fit.""" @@ -614,16 +836,10 @@ def fit( y=y, sample_weight=sample_weight, group=group, - client=client, **kwargs ) - _base_doc = LGBMRanker.fit.__doc__ - _before_eval_set, _eval_set, _after_eval_set = _base_doc.partition('eval_set :') - fit.__doc__ = (_before_eval_set - + 'client : dask.distributed.Client or None, optional (default=None)\n' - + ' ' * 12 + 'Dask client.\n' - + ' ' * 8 + _eval_set + _after_eval_set) + fit.__doc__ = LGBMRanker.fit.__doc__ def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array: """Docstring is inherited from the lightgbm.LGBMRanker.predict.""" diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 1cb2704820ec..cc9fa3adb184 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -1,6 +1,9 @@ # coding: utf-8 """Tests for lightgbm.dask module""" +import inspect +import joblib +import pickle import socket from itertools import groupby from os import getenv @@ -13,13 +16,14 @@ if not lgb.compat.DASK_INSTALLED: pytest.skip('Dask is not installed', allow_module_level=True) +import cloudpickle import dask.array as da import dask.dataframe as dd import numpy as np import pandas as pd from scipy.stats import spearmanr from dask.array.utils import assert_eq -from dask.distributed import wait +from dask.distributed import default_client, Client, LocalCluster, wait from distributed.utils_test import client, cluster_fixture, gen_cluster, loop from scipy.sparse import csr_matrix from sklearn.datasets import make_blobs, make_regression @@ -137,6 +141,32 @@ def _accuracy_score(dy_true, dy_pred): return da.average(dy_true == dy_pred).compute() +def _pickle(obj, filepath, serializer): + if serializer == 'pickle': + with open(filepath, 'wb') as f: + pickle.dump(obj, f) + elif serializer == 'joblib': + joblib.dump(obj, filepath) + elif serializer == 'cloudpickle': + with open(filepath, 'wb') as f: + cloudpickle.dump(obj, f) + else: + raise ValueError(f'Unrecognized serializer type: {serializer}') + + +def _unpickle(filepath, serializer): + if serializer == 'pickle': + with open(filepath, 'rb') as f: + return pickle.load(f) + elif serializer == 'joblib': + return joblib.load(filepath) + elif serializer == 'cloudpickle': + with open(filepath, 'rb') as f: + return cloudpickle.load(f) + else: + raise ValueError(f'Unrecognized serializer type: {serializer}') + + @pytest.mark.parametrize('output', data_output) @pytest.mark.parametrize('centers', data_centers) def test_classifier(output, centers, client, listen_port): @@ -151,11 +181,12 @@ def test_classifier(output, centers, client, listen_port): "num_leaves": 10 } dask_classifier = lgb.DaskLGBMClassifier( + client=client, time_out=5, local_listen_port=listen_port, **params ) - dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client) + dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw) p1 = dask_classifier.predict(dX) p1_proba = dask_classifier.predict_proba(dX).compute() p1_local = dask_classifier.to_local().predict(X) @@ -193,12 +224,13 @@ def test_classifier_pred_contrib(output, centers, client, listen_port): "num_leaves": 10 } dask_classifier = lgb.DaskLGBMClassifier( + client=client, time_out=5, local_listen_port=listen_port, tree_learner='data', **params ) - dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client) + dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw) preds_with_contrib = dask_classifier.predict(dX, pred_contrib=True).compute() local_classifier = lgb.LGBMClassifier(**params) @@ -241,6 +273,7 @@ def test_training_does_not_fail_on_port_conflicts(client): s.bind(('127.0.0.1', 12400)) dask_classifier = lgb.DaskLGBMClassifier( + client=client, time_out=5, local_listen_port=12400, n_estimators=5, @@ -251,7 +284,6 @@ def test_training_does_not_fail_on_port_conflicts(client): X=dX, y=dy, sample_weight=dw, - client=client ) assert dask_classifier.booster_ @@ -270,12 +302,13 @@ def test_regressor(output, client, listen_port): "num_leaves": 10 } dask_regressor = lgb.DaskLGBMRegressor( + client=client, time_out=5, local_listen_port=listen_port, tree='data', **params ) - dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw) + dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw) p1 = dask_regressor.predict(dX) if output != 'dataframe': s1 = _r2_score(dy, p1) @@ -313,12 +346,13 @@ def test_regressor_pred_contrib(output, client, listen_port): "num_leaves": 10 } dask_regressor = lgb.DaskLGBMRegressor( + client=client, time_out=5, local_listen_port=listen_port, tree_learner='data', **params ) - dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw, client=client) + dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw) preds_with_contrib = dask_regressor.predict(dX, pred_contrib=True).compute() local_regressor = lgb.LGBMRegressor(**params) @@ -353,11 +387,12 @@ def test_regressor_quantile(output, client, listen_port, alpha): "num_leaves": 10 } dask_regressor = lgb.DaskLGBMRegressor( + client=client, local_listen_port=listen_port, tree_learner_type='data_parallel', **params ) - dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw) + dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw) p1 = dask_regressor.predict(dX).compute() q1 = np.count_nonzero(y < p1) / y.shape[0] @@ -400,12 +435,13 @@ def test_ranker(output, client, listen_port, group): "min_child_samples": 1 } dask_ranker = lgb.DaskLGBMRanker( + client=client, time_out=5, local_listen_port=listen_port, tree_learner_type='data_parallel', **params ) - dask_ranker = dask_ranker.fit(dX, dy, sample_weight=dw, group=dg, client=client) + dask_ranker = dask_ranker.fit(dX, dy, sample_weight=dw, group=dg) rnkvec_dask = dask_ranker.predict(dX) rnkvec_dask = rnkvec_dask.compute() rnkvec_dask_local = dask_ranker.to_local().predict(X) @@ -424,6 +460,288 @@ def test_ranker(output, client, listen_port, group): client.close(timeout=CLIENT_CLOSE_TIMEOUT) +@pytest.mark.parametrize('task', ['classification', 'regression', 'ranking']) +def test_training_works_if_client_not_provided_or_set_after_construction(task, listen_port, client): + if task == 'ranking': + _, _, _, _, dX, dy, _, dg = _create_ranking_data( + output='array', + group=None + ) + model_factory = lgb.DaskLGBMRanker + else: + _, _, _, dX, dy, _ = _create_data( + objective=task, + output='array', + ) + dg = None + if task == 'classification': + model_factory = lgb.DaskLGBMClassifier + elif task == 'regression': + model_factory = lgb.DaskLGBMRegressor + + params = { + "time_out": 5, + "local_listen_port": listen_port, + "n_estimators": 1, + "num_leaves": 2 + } + + # should be able to use the class without specifying a client + dask_model = model_factory(**params) + assert dask_model._client is None + assert dask_model.client is None + with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'): + dask_model.client_ + + dask_model.fit(dX, dy, group=dg) + assert dask_model.fitted_ + assert dask_model._client is None + assert dask_model.client is None + assert dask_model.client_ == client + + preds = dask_model.predict(dX) + assert isinstance(preds, da.Array) + assert dask_model.fitted_ + assert dask_model._client is None + assert dask_model.client is None + assert dask_model.client_ == client + + local_model = dask_model.to_local() + with pytest.raises(AttributeError): + local_model._client + local_model.client + local_model.client_ + + # should be able to set client after construction + dask_model = model_factory(**params) + dask_model.set_params(client=client) + assert dask_model._client == client + assert dask_model.client == client + + with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'): + dask_model.client_ + + dask_model.fit(dX, dy, group=dg) + assert dask_model.fitted_ + assert dask_model._client == client + assert dask_model.client == client + assert dask_model.client_ == client + + preds = dask_model.predict(dX) + assert isinstance(preds, da.Array) + assert dask_model.fitted_ + assert dask_model._client == client + assert dask_model.client == client + assert dask_model.client_ == client + + local_model = dask_model.to_local() + assert getattr(local_model, "_client", None) is None + with pytest.raises(AttributeError): + local_model._client + local_model.client + local_model.client_ + + client.close(timeout=CLIENT_CLOSE_TIMEOUT) + + +@pytest.mark.parametrize('serializer', ['pickle', 'joblib', 'cloudpickle']) +@pytest.mark.parametrize('task', ['classification', 'regression', 'ranking']) +@pytest.mark.parametrize('set_client', [True, False]) +def test_model_and_local_version_are_picklable_whether_or_not_client_set_explicitly(serializer, task, set_client, listen_port, tmp_path): + + with LocalCluster(n_workers=2, threads_per_worker=1) as cluster1: + with Client(cluster1) as client1: + + # data on cluster1 + if task == 'ranking': + X_1, _, _, _, dX_1, dy_1, _, dg_1 = _create_ranking_data( + output='array', + group=None + ) + else: + X_1, _, _, dX_1, dy_1, _ = _create_data( + objective=task, + output='array', + ) + dg_1 = None + + with LocalCluster(n_workers=2, threads_per_worker=1) as cluster2: + with Client(cluster2) as client2: + + # create identical data on cluster2 + if task == 'ranking': + X_2, _, _, _, dX_2, dy_2, _, dg_2 = _create_ranking_data( + output='array', + group=None + ) + else: + X_2, _, _, dX_2, dy_2, _ = _create_data( + objective=task, + output='array', + ) + dg_2 = None + + if task == 'ranking': + model_factory = lgb.DaskLGBMRanker + elif task == 'classification': + model_factory = lgb.DaskLGBMClassifier + elif task == 'regression': + model_factory = lgb.DaskLGBMRegressor + + params = { + "time_out": 5, + "local_listen_port": listen_port, + "n_estimators": 1, + "num_leaves": 2 + } + + # at this point, the result of default_client() is client2 since it was the most recently + # created. So setting client to client1 here to test that you can select a non-default client + assert default_client() == client2 + if set_client: + params.update({"client": client1}) + + # unfitted model should survive pickling round trip, and pickling + # shouldn't have side effects on the model object + dask_model = model_factory(**params) + local_model = dask_model.to_local() + if set_client: + assert dask_model._client == client1 + assert dask_model.client == client1 + else: + assert dask_model._client is None + assert dask_model.client is None + + with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'): + dask_model.client_ + + assert "client" not in local_model.get_params() + assert getattr(local_model, "client", None) is None + + tmp_file = str(tmp_path / "model-1.pkl") + _pickle( + obj=dask_model, + filepath=tmp_file, + serializer=serializer + ) + model_from_disk = _unpickle( + filepath=tmp_file, + serializer=serializer + ) + + local_tmp_file = str(tmp_path / "local-model-1.pkl") + _pickle( + obj=local_model, + filepath=local_tmp_file, + serializer=serializer + ) + local_model_from_disk = _unpickle( + filepath=local_tmp_file, + serializer=serializer + ) + + assert model_from_disk._client is None + assert model_from_disk.client is None + + if set_client: + assert dask_model._client == client1 + assert dask_model.client == client1 + else: + assert dask_model._client is None + assert dask_model.client is None + + with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'): + dask_model.client_ + + # client will always be None after unpickling + if set_client: + from_disk_params = model_from_disk.get_params() + from_disk_params.pop("client", None) + dask_params = dask_model.get_params() + dask_params.pop("client", None) + assert from_disk_params == dask_params + else: + assert model_from_disk.get_params() == dask_model.get_params() + assert local_model_from_disk.get_params() == local_model.get_params() + + # fitted model should survive pickling round trip, and pickling + # shouldn't have side effects on the model object + if set_client: + dask_model.fit(dX_1, dy_1, group=dg_1) + else: + dask_model.fit(dX_2, dy_2, group=dg_2) + local_model = dask_model.to_local() + + assert "client" not in local_model.get_params() + with pytest.raises(AttributeError): + local_model._client + local_model.client + local_model.client_ + + tmp_file2 = str(tmp_path / "model-2.pkl") + _pickle( + obj=dask_model, + filepath=tmp_file2, + serializer=serializer + ) + fitted_model_from_disk = _unpickle( + filepath=tmp_file2, + serializer=serializer + ) + + local_tmp_file2 = str(tmp_path / "local-model-2.pkl") + _pickle( + obj=local_model, + filepath=local_tmp_file2, + serializer=serializer + ) + local_fitted_model_from_disk = _unpickle( + filepath=local_tmp_file2, + serializer=serializer + ) + + if set_client: + assert dask_model._client == client1 + assert dask_model.client == client1 + assert dask_model.client_ == client1 + else: + assert dask_model._client is None + assert dask_model.client is None + assert dask_model.client_ == default_client() + assert dask_model.client_ == client2 + + assert isinstance(fitted_model_from_disk, model_factory) + assert fitted_model_from_disk._client is None + assert fitted_model_from_disk.client is None + assert fitted_model_from_disk.client_ == default_client() + assert fitted_model_from_disk.client_ == client2 + + # client will always be None after unpickling + if set_client: + from_disk_params = fitted_model_from_disk.get_params() + from_disk_params.pop("client", None) + dask_params = dask_model.get_params() + dask_params.pop("client", None) + assert from_disk_params == dask_params + else: + assert fitted_model_from_disk.get_params() == dask_model.get_params() + assert local_fitted_model_from_disk.get_params() == local_model.get_params() + + if set_client: + preds_orig = dask_model.predict(dX_1).compute() + preds_loaded_model = fitted_model_from_disk.predict(dX_1).compute() + preds_orig_local = local_model.predict(X_1) + preds_loaded_model_local = local_fitted_model_from_disk.predict(X_1) + else: + preds_orig = dask_model.predict(dX_2).compute() + preds_loaded_model = fitted_model_from_disk.predict(dX_2).compute() + preds_orig_local = local_model.predict(X_2) + preds_loaded_model_local = local_fitted_model_from_disk.predict(X_2) + + assert_eq(preds_orig, preds_loaded_model) + assert_eq(preds_orig_local, preds_loaded_model_local) + + def test_find_open_port_works(): worker_ip = '127.0.0.1' with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -451,6 +769,7 @@ def test_warns_and_continues_on_unrecognized_tree_learner(client): X = da.random.random((1e3, 10)) y = da.random.random((1e3, 1)) dask_regressor = lgb.DaskLGBMRegressor( + client=client, time_out=5, local_listen_port=1234, tree_learner='some-nonsense-value', @@ -458,7 +777,7 @@ def test_warns_and_continues_on_unrecognized_tree_learner(client): num_leaves=2 ) with pytest.warns(UserWarning, match='Parameter tree_learner set to some-nonsense-value'): - dask_regressor = dask_regressor.fit(X, y, client=client) + dask_regressor = dask_regressor.fit(X, y) assert dask_regressor.fitted_ @@ -470,6 +789,7 @@ def test_warns_but_makes_no_changes_for_feature_or_voting_tree_learner(client): y = da.random.random((1e3, 1)) for tree_learner in ['feature_parallel', 'voting']: dask_regressor = lgb.DaskLGBMRegressor( + client=client, time_out=5, local_listen_port=1234, tree_learner=tree_learner, @@ -477,7 +797,7 @@ def test_warns_but_makes_no_changes_for_feature_or_voting_tree_learner(client): num_leaves=2 ) with pytest.warns(UserWarning, match='Support for tree_learner %s in lightgbm' % tree_learner): - dask_regressor = dask_regressor.fit(X, y, client=client) + dask_regressor = dask_regressor.fit(X, y) assert dask_regressor.fitted_ assert dask_regressor.get_params()['tree_learner'] == tree_learner @@ -501,3 +821,26 @@ def f(part): model_factory=lgb.LGBMClassifier ) assert 'foo' in str(info.value) + + +@pytest.mark.parametrize( + "classes", + [ + (lgb.DaskLGBMClassifier, lgb.LGBMClassifier), + (lgb.DaskLGBMRegressor, lgb.LGBMRegressor), + (lgb.DaskLGBMRanker, lgb.LGBMRanker) + ] +) +def test_dask_classes_and_sklearn_equivalents_have_identical_constructors_except_client_arg(classes): + dask_spec = inspect.getfullargspec(classes[0]) + sklearn_spec = inspect.getfullargspec(classes[1]) + assert dask_spec.varargs == sklearn_spec.varargs + assert dask_spec.varkw == sklearn_spec.varkw + assert dask_spec.kwonlyargs == sklearn_spec.kwonlyargs + assert dask_spec.kwonlydefaults == sklearn_spec.kwonlydefaults + + # "client" should be the only different, and the final argument + assert dask_spec.args[:-1] == sklearn_spec.args + assert dask_spec.defaults[:-1] == sklearn_spec.defaults + assert dask_spec.args[-1] == 'client' + assert dask_spec.defaults[-1] is None