diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 9833c724356b..0ba89df2cee0 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -17,7 +17,7 @@ from .basic import _choose_param_value, _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError from .compat import (PANDAS_INSTALLED, pd_DataFrame, pd_Series, concat, - SKLEARN_INSTALLED, + SKLEARN_INSTALLED, LGBMNotFittedError, DASK_INSTALLED, dask_DataFrame, dask_Array, dask_Series, delayed, Client, default_client, get_worker, wait) from .sklearn import LGBMClassifier, LGBMModel, LGBMRegressor, LGBMRanker @@ -27,6 +27,25 @@ _PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]] +def _get_dask_client(client: Optional[Client]) -> Client: + """Choose a Dask client to use + + Parameters + ---------- + client : dask.distributed.Client or None + Dask client. + + Returns + ------- + client : dask.distributed.Client + A Dask client. + """ + if client is None: + return default_client() + else: + return client + + def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Iterable[int]) -> int: """Find an open port. @@ -444,10 +463,10 @@ def client_(self) -> Client: This property can be passed in the constructor or updated with ``model.set_params(client=client)``. """ - if self._client is None: - return default_client() - else: - return self._client + if not getattr(self, "fitted_", False): + raise LGBMNotFittedError('Cannot access property client_ before calling fit().') + + return _get_dask_client(client=self.client) def _lgb_getstate(self) -> Dict[Any, Any]: """Remove un-picklable attributes before serialization.""" @@ -476,7 +495,7 @@ def _fit( params.pop("client", None) model = _train( - client=self.client_, + client=_get_dask_client(self.client), data=X, label=y, params=params, @@ -569,7 +588,7 @@ def __init__( __init__.__doc__ = ( _before_kwargs + 'client : dask.distributed.Client or None, optional (default=None)\n' - + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n' + + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.\n' + ' ' * 8 + _kwargs + _after_kwargs ) @@ -687,7 +706,7 @@ def __init__( __init__.__doc__ = ( _before_kwargs + 'client : dask.distributed.Client or None, optional (default=None)\n' - + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n' + + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.\n' + ' ' * 8 + _kwargs + _after_kwargs ) @@ -794,7 +813,7 @@ def __init__( __init__.__doc__ = ( _before_kwargs + 'client : dask.distributed.Client or None, optional (default=None)\n' - + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n' + + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.\n' + ' ' * 8 + _kwargs + _after_kwargs ) diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 955c6e8ebf94..e004c9cf933d 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -23,7 +23,7 @@ import pandas as pd from scipy.stats import spearmanr from dask.array.utils import assert_eq -from dask.distributed import wait +from dask.distributed import default_client, Client, LocalCluster, wait from distributed.utils_test import client, cluster_fixture, gen_cluster, loop from scipy.sparse import csr_matrix from sklearn.datasets import make_blobs, make_regression @@ -486,11 +486,12 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, l "num_leaves": 2 } - # fit should work if client isn't provided + # should be able to use the class without specifying a client dask_model = model_factory(**params) assert dask_model._client is None assert dask_model.client is None - assert dask_model.client_ == client + with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'): + dask_model.client_ dask_model.fit(dX, dy, group=dg) assert dask_model.fitted_ @@ -516,7 +517,9 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, l dask_model.set_params(client=client) assert dask_model._client == client assert dask_model.client == client - assert dask_model.client_ == client + + with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'): + dask_model.client_ dask_model.fit(dX, dy, group=dg) assert dask_model.fitted_ @@ -544,153 +547,199 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, l @pytest.mark.parametrize('serializer', ['pickle', 'joblib', 'cloudpickle']) @pytest.mark.parametrize('task', ['classification', 'regression', 'ranking']) @pytest.mark.parametrize('set_client', [True, False]) -def test_model_and_local_version_are_picklable_whether_or_not_client_set_explicitly(serializer, task, set_client, listen_port, client, tmp_path): - if task == 'ranking': - X, _, _, _, dX, dy, _, dg = _create_ranking_data( - output='array', - group=None - ) - model_factory = lgb.DaskLGBMRanker - else: - X, _, _, dX, dy, _ = _create_data( - objective=task, - output='array', - ) - dg = None - if task == 'classification': - model_factory = lgb.DaskLGBMClassifier - elif task == 'regression': - model_factory = lgb.DaskLGBMRegressor - - params = { - "time_out": 5, - "local_listen_port": listen_port, - "n_estimators": 1, - "num_leaves": 2 - } - - if set_client: - params.update({"client": client}) - - # unfitted model should survive pickling round trip, and pickling - # shouldn't have side effects on the model object - dask_model = model_factory(**params) - local_model = dask_model.to_local() - if set_client: - assert dask_model._client == client - assert dask_model.client == client - else: - assert dask_model._client is None - assert dask_model.client is None - - assert dask_model.client_ == client - assert "client" not in local_model.get_params() - assert getattr(local_model, "client", None) is None - - tmp_file = str(tmp_path / "model-1.pkl") - _pickle( - obj=dask_model, - filepath=tmp_file, - serializer=serializer - ) - model_from_disk = _unpickle( - filepath=tmp_file, - serializer=serializer - ) - - local_tmp_file = str(tmp_path / "local-model-1.pkl") - _pickle( - obj=local_model, - filepath=local_tmp_file, - serializer=serializer - ) - local_model_from_disk = _unpickle( - filepath=local_tmp_file, - serializer=serializer - ) - - if set_client: - assert dask_model._client == client - assert dask_model.client == client - else: - assert dask_model._client is None - assert dask_model.client is None - assert model_from_disk._client is None - assert model_from_disk.client is None - assert model_from_disk.client_ == client - # client will always be None after unpickling - if set_client: - from_disk_params = model_from_disk.get_params() - from_disk_params.pop("client", None) - dask_params = dask_model.get_params() - dask_params.pop("client", None) - assert from_disk_params == dask_params - else: - assert model_from_disk.get_params() == dask_model.get_params() - assert local_model_from_disk.get_params() == local_model.get_params() - - # fitted model should survive pickling round trip, and pickling - # shouldn't have side effects on the model object - dask_model.fit(dX, dy, group=dg) - local_model = dask_model.to_local() - - assert "client" not in local_model.get_params() - with pytest.raises(AttributeError): - local_model._client - local_model.client - local_model.client_ - - tmp_file2 = str(tmp_path / "model-2.pkl") - _pickle( - obj=dask_model, - filepath=tmp_file2, - serializer=serializer - ) - fitted_model_from_disk = _unpickle( - filepath=tmp_file2, - serializer=serializer - ) - - local_tmp_file2 = str(tmp_path / "local-model-2.pkl") - _pickle( - obj=local_model, - filepath=local_tmp_file2, - serializer=serializer - ) - local_fitted_model_from_disk = _unpickle( - filepath=local_tmp_file2, - serializer=serializer - ) - - if set_client: - assert dask_model._client == client - assert dask_model.client == client - else: - assert dask_model._client is None - assert dask_model.client is None - assert isinstance(fitted_model_from_disk, model_factory) - assert fitted_model_from_disk._client is None - assert fitted_model_from_disk.client is None - assert fitted_model_from_disk.client_ == client - - # client will always be None after unpickling - if set_client: - from_disk_params = fitted_model_from_disk.get_params() - from_disk_params.pop("client", None) - dask_params = dask_model.get_params() - dask_params.pop("client", None) - assert from_disk_params == dask_params - else: - assert fitted_model_from_disk.get_params() == dask_model.get_params() - assert local_fitted_model_from_disk.get_params() == local_model.get_params() - - preds_orig = dask_model.predict(dX).compute() - preds_loaded_model = fitted_model_from_disk.predict(dX).compute() - assert_eq(preds_orig, preds_loaded_model) - - preds_orig_local = local_model.predict(X) - preds_loaded_model_local = local_fitted_model_from_disk.predict(X) - assert_eq(preds_orig_local, preds_loaded_model_local) +def test_model_and_local_version_are_picklable_whether_or_not_client_set_explicitly(serializer, task, set_client, listen_port, tmp_path): + + with LocalCluster(n_workers=2, threads_per_worker=1) as cluster1: + with Client(cluster1) as client1: + + # data on cluster1 + if task == 'ranking': + X_1, _, _, _, dX_1, dy_1, _, dg_1 = _create_ranking_data( + output='array', + group=None + ) + else: + X_1, _, _, dX_1, dy_1, _ = _create_data( + objective=task, + output='array', + ) + dg_1 = None + + with LocalCluster(n_workers=2, threads_per_worker=1) as cluster2: + with Client(cluster2) as client2: + + # create identical data on cluster2 + if task == 'ranking': + X_2, _, _, _, dX_2, dy_2, _, dg_2 = _create_ranking_data( + output='array', + group=None + ) + else: + X_2, _, _, dX_2, dy_2, _ = _create_data( + objective=task, + output='array', + ) + dg_2 = None + + if task == 'ranking': + model_factory = lgb.DaskLGBMRanker + elif task == 'classification': + model_factory = lgb.DaskLGBMClassifier + elif task == 'regression': + model_factory = lgb.DaskLGBMRegressor + + params = { + "time_out": 5, + "local_listen_port": listen_port, + "n_estimators": 1, + "num_leaves": 2 + } + + # at this point, the result of default_client() is client2 since it was the most recently + # created. So setting client to client1 here to test that you can select a non-default client + assert default_client() == client2 + if set_client: + params.update({"client": client1}) + + # unfitted model should survive pickling round trip, and pickling + # shouldn't have side effects on the model object + dask_model = model_factory(**params) + local_model = dask_model.to_local() + if set_client: + assert dask_model._client == client1 + assert dask_model.client == client1 + else: + assert dask_model._client is None + assert dask_model.client is None + + with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'): + dask_model.client_ + + assert "client" not in local_model.get_params() + assert getattr(local_model, "client", None) is None + + tmp_file = str(tmp_path / "model-1.pkl") + _pickle( + obj=dask_model, + filepath=tmp_file, + serializer=serializer + ) + model_from_disk = _unpickle( + filepath=tmp_file, + serializer=serializer + ) + + local_tmp_file = str(tmp_path / "local-model-1.pkl") + _pickle( + obj=local_model, + filepath=local_tmp_file, + serializer=serializer + ) + local_model_from_disk = _unpickle( + filepath=local_tmp_file, + serializer=serializer + ) + + assert model_from_disk._client is None + assert model_from_disk.client is None + + if set_client: + assert dask_model._client == client1 + assert dask_model.client == client1 + else: + assert dask_model._client is None + assert dask_model.client is None + + with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'): + dask_model.client_ + + # client will always be None after unpickling + if set_client: + from_disk_params = model_from_disk.get_params() + from_disk_params.pop("client", None) + dask_params = dask_model.get_params() + dask_params.pop("client", None) + assert from_disk_params == dask_params + else: + assert model_from_disk.get_params() == dask_model.get_params() + assert local_model_from_disk.get_params() == local_model.get_params() + + # fitted model should survive pickling round trip, and pickling + # shouldn't have side effects on the model object + if set_client: + dask_model.fit(dX_1, dy_1, group=dg_1) + else: + dask_model.fit(dX_2, dy_2, group=dg_2) + local_model = dask_model.to_local() + + assert "client" not in local_model.get_params() + with pytest.raises(AttributeError): + local_model._client + local_model.client + local_model.client_ + + tmp_file2 = str(tmp_path / "model-2.pkl") + _pickle( + obj=dask_model, + filepath=tmp_file2, + serializer=serializer + ) + fitted_model_from_disk = _unpickle( + filepath=tmp_file2, + serializer=serializer + ) + + local_tmp_file2 = str(tmp_path / "local-model-2.pkl") + _pickle( + obj=local_model, + filepath=local_tmp_file2, + serializer=serializer + ) + local_fitted_model_from_disk = _unpickle( + filepath=local_tmp_file2, + serializer=serializer + ) + + if set_client: + assert dask_model._client == client1 + assert dask_model.client == client1 + assert dask_model.client_ == client1 + else: + assert dask_model._client is None + assert dask_model.client is None + assert dask_model.client_ == default_client() + assert dask_model.client_ == client2 + + assert isinstance(fitted_model_from_disk, model_factory) + assert fitted_model_from_disk._client is None + assert fitted_model_from_disk.client is None + assert fitted_model_from_disk.client_ == default_client() + assert fitted_model_from_disk.client_ == client2 + + # client will always be None after unpickling + if set_client: + from_disk_params = fitted_model_from_disk.get_params() + from_disk_params.pop("client", None) + dask_params = dask_model.get_params() + dask_params.pop("client", None) + assert from_disk_params == dask_params + else: + assert fitted_model_from_disk.get_params() == dask_model.get_params() + assert local_fitted_model_from_disk.get_params() == local_model.get_params() + + if set_client: + preds_orig = dask_model.predict(dX_1).compute() + preds_loaded_model = fitted_model_from_disk.predict(dX_1).compute() + preds_orig_local = local_model.predict(X_1) + preds_loaded_model_local = local_fitted_model_from_disk.predict(X_1) + else: + preds_orig = dask_model.predict(dX_2).compute() + preds_loaded_model = fitted_model_from_disk.predict(dX_2).compute() + preds_orig_local = local_model.predict(X_2) + preds_loaded_model_local = local_fitted_model_from_disk.predict(X_2) + + assert_eq(preds_orig, preds_loaded_model) + assert_eq(preds_orig_local, preds_loaded_model_local) def test_find_open_port_works(): @@ -774,7 +823,15 @@ def f(part): assert 'foo' in str(info.value) -def test_dask_classes_and_sklearn_equivalents_have_identical_constructors_except_client_arg(): +@pytest.mark.parametrize( + "classes", + [ + (lgb.DaskLGBMClassifier, lgb.LGBMClassifier), + (lgb.DaskLGBMRegressor, lgb.LGBMRegressor), + (lgb.DaskLGBMRanker, lgb.LGBMRanker) + ] +) +def test_dask_classes_and_sklearn_equivalents_have_identical_constructors_except_client_arg(classes): def _compare_spec(dask_cls, sklearn_cls): dask_spec = inspect.getfullargspec(dask_cls) sklearn_spec = inspect.getfullargspec(sklearn_cls) @@ -789,6 +846,4 @@ def _compare_spec(dask_cls, sklearn_cls): assert dask_spec.args[-1] == 'client' assert dask_spec.defaults[-1] is None - _compare_spec(lgb.DaskLGBMClassifier, lgb.LGBMClassifier) - _compare_spec(lgb.DaskLGBMRegressor, lgb.LGBMRegressor) - _compare_spec(lgb.DaskLGBMRanker, lgb.LGBMRanker) + _compare_spec(classes[0], classes[1])