Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Feb 2, 2021
1 parent 8cd6101 commit 555a57a
Show file tree
Hide file tree
Showing 2 changed files with 238 additions and 164 deletions.
37 changes: 28 additions & 9 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from .basic import _choose_param_value, _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError
from .compat import (PANDAS_INSTALLED, pd_DataFrame, pd_Series, concat,
SKLEARN_INSTALLED,
SKLEARN_INSTALLED, LGBMNotFittedError,
DASK_INSTALLED, dask_DataFrame, dask_Array, dask_Series, delayed, Client, default_client, get_worker, wait)
from .sklearn import LGBMClassifier, LGBMModel, LGBMRegressor, LGBMRanker

Expand All @@ -27,6 +27,25 @@
_PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]]


def _get_dask_client(client: Optional[Client]) -> Client:
"""Choose a Dask client to use
Parameters
----------
client : dask.distributed.Client or None
Dask client.
Returns
-------
client : dask.distributed.Client
A Dask client.
"""
if client is None:
return default_client()
else:
return client


def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Iterable[int]) -> int:
"""Find an open port.
Expand Down Expand Up @@ -444,10 +463,10 @@ def client_(self) -> Client:
This property can be passed in the constructor or updated
with ``model.set_params(client=client)``.
"""
if self._client is None:
return default_client()
else:
return self._client
if not getattr(self, "fitted_", False):
raise LGBMNotFittedError('Cannot access property client_ before calling fit().')

return _get_dask_client(client=self.client)

def _lgb_getstate(self) -> Dict[Any, Any]:
"""Remove un-picklable attributes before serialization."""
Expand Down Expand Up @@ -476,7 +495,7 @@ def _fit(
params.pop("client", None)

model = _train(
client=self.client_,
client=_get_dask_client(self.client),
data=X,
label=y,
params=params,
Expand Down Expand Up @@ -569,7 +588,7 @@ def __init__(
__init__.__doc__ = (
_before_kwargs
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n'
+ ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.\n'
+ ' ' * 8 + _kwargs + _after_kwargs
)

Expand Down Expand Up @@ -687,7 +706,7 @@ def __init__(
__init__.__doc__ = (
_before_kwargs
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n'
+ ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.\n'
+ ' ' * 8 + _kwargs + _after_kwargs
)

Expand Down Expand Up @@ -794,7 +813,7 @@ def __init__(
__init__.__doc__ = (
_before_kwargs
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n'
+ ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.\n'
+ ' ' * 8 + _kwargs + _after_kwargs
)

Expand Down
Loading

0 comments on commit 555a57a

Please sign in to comment.