diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py
index 17719b5666f3..9833c724356b 100644
--- a/python-package/lightgbm/dask.py
+++ b/python-package/lightgbm/dask.py
@@ -438,26 +438,28 @@ class _DaskLGBMModel:
     _client: Optional[Client] = None
 
     @property
-    def client(self) -> Client:
+    def client_(self) -> Client:
         """Dask client.
 
-        This property can be passed in the constructor or directly assigned
-        like ``model.set_params(client=client)``.
+        Returns the ``client`` given to the constructor or ``set_params``,
+        or ``default_client()`` when none was provided.
         """
         if self._client is None:
             return default_client()
         else:
             return self._client
 
-    @client.setter
-    def client(self, client: Client) -> None:
-        self._client = client
-
     def _lgb_getstate(self) -> Dict[Any, Any]:
         """Remove un-picklable attributes before serialization."""
-        out = deepcopy(self.__dict__)
-        client = out.pop("_client", None)
+        client = self.__dict__.pop("client", None)
+        self.__dict__.pop("_client", None)
         self._other_params.pop("client", None)
-        self.client = client
+        try:
+            out = deepcopy(self.__dict__)
+        finally:
+            # always restore the live client, even if the copy fails
+            self._client = client
+            self.client = client
+        out.update({"_client": None, "client": None})
         return out
 
@@ -477,7 +479,7 @@ def _fit(
         params.pop("client", None)
 
         model = _train(
-            client=self.client,
+            client=self.client_,
             data=X,
             label=y,
             params=params,
@@ -539,6 +541,7 @@ def __init__(
         **kwargs: Any
     ):
         """Docstring is inherited from the lightgbm.LGBMClassifier.__init__."""
+        self._client = client
         self.client = client
         super().__init__(
             boosting_type=boosting_type,
@@ -656,6 +659,7 @@ def __init__(
         **kwargs: Any
     ):
         """Docstring is inherited from the lightgbm.LGBMRegressor.__init__."""
+        self._client = client
         self.client = client
         super().__init__(
             boosting_type=boosting_type,
@@ -762,6 +766,7 @@ def __init__(
         **kwargs: Any
     ):
         """Docstring is inherited from the lightgbm.LGBMRanker.__init__."""
+        self._client = client
         self.client = client
         super().__init__(
             boosting_type=boosting_type,
diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py
index 591cd34808eb..1cd5489140e1 100644
--- a/tests/python_package_test/test_dask.py
+++ b/tests/python_package_test/test_dask.py
@@ -480,45 +480,58 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, l
     # fit should work if client isn't provided
     dask_model = model_factory(**params)
     assert dask_model._client is None
-    assert dask_model.client == client
+    assert dask_model.client is None
+    assert dask_model.client_ == client
 
     dask_model.fit(dX, dy, group=dg)
     assert dask_model.fitted_
     assert dask_model._client is None
-    assert dask_model.client == client
+    assert dask_model.client is None
+    assert dask_model.client_ == client
 
     preds = dask_model.predict(dX)
     assert isinstance(preds, da.Array)
     assert dask_model.fitted_
     assert dask_model._client is None
-    assert dask_model.client == client
+    assert dask_model.client is None
+    assert dask_model.client_ == client
 
     local_model = dask_model.to_local()
-    assert getattr(local_model, "_client", None) is None
     with pytest.raises(AttributeError):
+        local_model._client
+    with pytest.raises(AttributeError):
         local_model.client
+    with pytest.raises(AttributeError):
+        local_model.client_
 
     # should be able to set client after construction
     dask_model = model_factory(**params)
     dask_model.set_params(client=client)
     assert dask_model._client == client
     assert dask_model.client == client
+    assert dask_model.client_ == client
 
     dask_model.fit(dX, dy, group=dg)
     assert dask_model.fitted_
     assert dask_model._client == client
     assert dask_model.client == client
+    assert dask_model.client_ == client
 
     preds = dask_model.predict(dX)
     assert isinstance(preds, da.Array)
     assert dask_model.fitted_
     assert dask_model._client == client
     assert dask_model.client == client
+    assert dask_model.client_ == client
 
     local_model = dask_model.to_local()
     assert getattr(local_model, "_client", None) is None
     with pytest.raises(AttributeError):
+        local_model._client
+    with pytest.raises(AttributeError):
         local_model.client
+    with pytest.raises(AttributeError):
+        local_model.client_
 
     client.close(timeout=CLIENT_CLOSE_TIMEOUT)
 
@@ -526,9 +539,6 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, l
 @pytest.mark.parametrize('serializer', ['pickle', 'joblib', 'cloudpickle'])
 @pytest.mark.parametrize('task', ['classification', 'regression', 'ranking'])
 @pytest.mark.parametrize('set_client', [True, False])
-# @pytest.mark.parametrize('serializer', ['pickle'])
-# @pytest.mark.parametrize('task', ['classification', 'regression', 'ranking'])
-# @pytest.mark.parametrize('set_client', [True])
 def test_model_and_local_version_are_picklable_whether_or_not_client_set_explicitly(serializer, task, set_client, listen_port, client, tmp_path):
     if task == 'ranking':
         X, _, _, _, dX, dy, _, dg = _create_ranking_data(
@@ -563,10 +573,12 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
     local_model = dask_model.to_local()
     if set_client:
         assert dask_model._client == client
+        assert dask_model.client == client
     else:
         assert dask_model._client is None
+        assert dask_model.client is None
 
-    assert dask_model.client == client
+    assert dask_model.client_ == client
 
     assert "client" not in local_model.get_params()
     assert getattr(local_model, "client", None) is None
@@ -594,11 +606,22 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
 
     if set_client:
         assert dask_model._client == client
+        assert dask_model.client == client
     else:
         assert dask_model._client is None
+        assert dask_model.client is None
     assert model_from_disk._client is None
-    assert model_from_disk.client == client
-    assert model_from_disk.get_params() == dask_model.get_params()
+    assert model_from_disk.client is None
+    assert model_from_disk.client_ == client
+    # client will always be None after unpickling
+    if set_client:
+        from_disk_params = model_from_disk.get_params()
+        from_disk_params.pop("client", None)
+        dask_params = dask_model.get_params()
+        dask_params.pop("client", None)
+        assert from_disk_params == dask_params
+    else:
+        assert model_from_disk.get_params() == dask_model.get_params()
     assert local_model_from_disk.get_params() == local_model.get_params()
 
     # fitted model should survive pickling round trip, and pickling
@@ -607,7 +630,12 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
     local_model = dask_model.to_local()
 
     assert "client" not in local_model.get_params()
-    assert getattr(local_model, "client", None) is None
+    with pytest.raises(AttributeError):
+        local_model._client
+    with pytest.raises(AttributeError):
+        local_model.client
+    with pytest.raises(AttributeError):
+        local_model.client_
 
     tmp_file2 = str(tmp_path / "model-2.pkl")
     _pickle(
@@ -633,12 +661,25 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
 
     if set_client:
         assert dask_model._client == client
+        assert dask_model.client == client
     else:
         assert dask_model._client is None
+        assert dask_model.client is None
     assert isinstance(fitted_model_from_disk, model_factory)
     assert fitted_model_from_disk._client is None
-    assert fitted_model_from_disk.client == client
-    assert fitted_model_from_disk.get_params() == dask_model.get_params()
+    assert fitted_model_from_disk.client is None
+    assert fitted_model_from_disk.client_ == client
+
+    # client will always be None after unpickling
+    if set_client:
+        from_disk_params = fitted_model_from_disk.get_params()
+        from_disk_params.pop("client", None)
+        dask_params = dask_model.get_params()
+        dask_params.pop("client", None)
+        assert from_disk_params == dask_params
+    else:
+        assert fitted_model_from_disk.get_params() == dask_model.get_params()
+    assert local_fitted_model_from_disk.get_params() == local_model.get_params()
 
     preds_orig = dask_model.predict(dX).compute()
     preds_loaded_model = fitted_model_from_disk.predict(dX).compute()