Skip to content

Commit

Permalink
fix tests, add client_ property
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Feb 1, 2021
1 parent 80dc6b9 commit af269b5
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 23 deletions.
22 changes: 12 additions & 10 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,26 +438,25 @@ class _DaskLGBMModel:
_client: Optional[Client] = None

@property
def client(self) -> Client:
def client_(self) -> Client:
"""Dask client.
This property can be passed in the constructor or directly assigned
like ``model.set_params(client=client)``.
This property can be passed in the constructor or updated
with ``model.set_params(client=client)``.
"""
if self._client is None:
return default_client()
else:
return self._client

@client.setter
def client(self, client: Client) -> None:
self._client = client

def _lgb_getstate(self) -> Dict[Any, Any]:
"""Remove un-picklable attributes before serialization."""
out = deepcopy(self.__dict__)
client = out.pop("_client", None)
client = self.__dict__.pop("client", None)
self.__dict__.pop("_client", None)
self._other_params.pop("client", None)
out = deepcopy(self.__dict__)
out.update({"_client": None, "client": None})
self._client = client
self.client = client
return out

Expand All @@ -477,7 +476,7 @@ def _fit(
params.pop("client", None)

model = _train(
client=self.client,
client=self.client_,
data=X,
label=y,
params=params,
Expand Down Expand Up @@ -539,6 +538,7 @@ def __init__(
**kwargs: Any
):
"""Docstring is inherited from the lightgbm.LGBMClassifier.__init__."""
self._client = client
self.client = client
super().__init__(
boosting_type=boosting_type,
Expand Down Expand Up @@ -656,6 +656,7 @@ def __init__(
**kwargs: Any
):
"""Docstring is inherited from the lightgbm.LGBMRegressor.__init__."""
self._client = client
self.client = client
super().__init__(
boosting_type=boosting_type,
Expand Down Expand Up @@ -762,6 +763,7 @@ def __init__(
**kwargs: Any
):
"""Docstring is inherited from the lightgbm.LGBMRanker.__init__."""
self._client = client
self.client = client
super().__init__(
boosting_type=boosting_type,
Expand Down
61 changes: 48 additions & 13 deletions tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,55 +480,61 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, l
# fit should work if client isn't provided
dask_model = model_factory(**params)
assert dask_model._client is None
assert dask_model.client == client
assert dask_model.client is None
assert dask_model.client_ == client

dask_model.fit(dX, dy, group=dg)
assert dask_model.fitted_
assert dask_model._client is None
assert dask_model.client == client
assert dask_model.client is None
assert dask_model.client_ == client

preds = dask_model.predict(dX)
assert isinstance(preds, da.Array)
assert dask_model.fitted_
assert dask_model._client is None
assert dask_model.client == client
assert dask_model.client is None
assert dask_model.client_ == client

local_model = dask_model.to_local()
assert getattr(local_model, "_client", None) is None
with pytest.raises(AttributeError):
local_model._client
local_model.client
local_model.client_

# should be able to set client after construction
dask_model = model_factory(**params)
dask_model.set_params(client=client)
assert dask_model._client == client
assert dask_model.client == client
assert dask_model.client_ == client

dask_model.fit(dX, dy, group=dg)
assert dask_model.fitted_
assert dask_model._client == client
assert dask_model.client == client
assert dask_model.client_ == client

preds = dask_model.predict(dX)
assert isinstance(preds, da.Array)
assert dask_model.fitted_
assert dask_model._client == client
assert dask_model.client == client
assert dask_model.client_ == client

local_model = dask_model.to_local()
assert getattr(local_model, "_client", None) is None
with pytest.raises(AttributeError):
local_model._client
local_model.client
local_model.client_

client.close(timeout=CLIENT_CLOSE_TIMEOUT)


@pytest.mark.parametrize('serializer', ['pickle', 'joblib', 'cloudpickle'])
@pytest.mark.parametrize('task', ['classification', 'regression', 'ranking'])
@pytest.mark.parametrize('set_client', [True, False])
# @pytest.mark.parametrize('serializer', ['pickle'])
# @pytest.mark.parametrize('task', ['classification', 'regression', 'ranking'])
# @pytest.mark.parametrize('set_client', [True])
def test_model_and_local_version_are_picklable_whether_or_not_client_set_explicitly(serializer, task, set_client, listen_port, client, tmp_path):
if task == 'ranking':
X, _, _, _, dX, dy, _, dg = _create_ranking_data(
Expand Down Expand Up @@ -563,10 +569,12 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
local_model = dask_model.to_local()
if set_client:
assert dask_model._client == client
assert dask_model.client == client
else:
assert dask_model._client is None
assert dask_model.client is None

assert dask_model.client == client
assert dask_model.client_ == client
assert "client" not in local_model.get_params()
assert getattr(local_model, "client", None) is None

Expand Down Expand Up @@ -594,11 +602,22 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici

if set_client:
assert dask_model._client == client
assert dask_model.client == client
else:
assert dask_model._client is None
assert dask_model.client is None
assert model_from_disk._client is None
assert model_from_disk.client == client
assert model_from_disk.get_params() == dask_model.get_params()
assert model_from_disk.client is None
assert model_from_disk.client_ == client
# client will always be None after unpickling
if set_client:
from_disk_params = model_from_disk.get_params()
from_disk_params.pop("client", None)
dask_params = dask_model.get_params()
dask_params.pop("client", None)
assert from_disk_params == dask_params
else:
assert model_from_disk.get_params() == dask_model.get_params()
assert local_model_from_disk.get_params() == local_model.get_params()

# fitted model should survive pickling round trip, and pickling
Expand All @@ -607,7 +626,10 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
local_model = dask_model.to_local()

assert "client" not in local_model.get_params()
assert getattr(local_model, "client", None) is None
with pytest.raises(AttributeError):
local_model._client
local_model.client
local_model.client_

tmp_file2 = str(tmp_path / "model-2.pkl")
_pickle(
Expand All @@ -633,12 +655,25 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici

if set_client:
assert dask_model._client == client
assert dask_model.client == client
else:
assert dask_model._client is None
assert dask_model.client is None
assert isinstance(fitted_model_from_disk, model_factory)
assert fitted_model_from_disk._client is None
assert fitted_model_from_disk.client == client
assert fitted_model_from_disk.get_params() == dask_model.get_params()
assert fitted_model_from_disk.client is None
assert fitted_model_from_disk.client_ == client

# client will always be None after unpickling
if set_client:
from_disk_params = fitted_model_from_disk.get_params()
from_disk_params.pop("client", None)
dask_params = dask_model.get_params()
dask_params.pop("client", None)
assert from_disk_params == dask_params
else:
assert fitted_model_from_disk.get_params() == dask_model.get_params()
assert local_fitted_model_from_disk.get_params() == local_model.get_params()

preds_orig = dask_model.predict(dX).compute()
preds_loaded_model = fitted_model_from_disk.predict(dX).compute()
Expand Down

0 comments on commit af269b5

Please sign in to comment.