Skip to content

Commit

Permalink
use client kwarg
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Jan 31, 2021
1 parent fb51493 commit 344376b
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .ci/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ if [[ $TASK == "swig" ]]; then
exit 0
fi

conda install -q -y -n $CONDA_ENV dask distributed joblib matplotlib numpy pandas psutil pytest scikit-learn scipy
conda install -q -y -n $CONDA_ENV cloudpickle dask distributed joblib matplotlib numpy pandas psutil pytest scikit-learn scipy

# graphviz must come from conda-forge to avoid this on some linux distros:
# https://github.com/conda-forge/graphviz-feedstock/issues/18
Expand Down
27 changes: 23 additions & 4 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def client(self) -> Client:
"""Dask client.
This property can be passed in the constructor or directly assigned
like ``model.client = client``.
like ``model.set_params(client=client)``.
"""
if self._client is None:
return default_client()
Expand All @@ -453,6 +453,13 @@ def client(self) -> Client:
def client(self, client: Client) -> None:
    # Setter half of the ``client`` property: store the Dask client on a
    # private attribute so the getter can fall back to ``default_client()``
    # when this was never assigned (i.e. ``_client`` is None).
    self._client = client

def _lgb_getstate(self) -> Dict[Any, Any]:
    """Remove un-picklable attributes before serialization.

    A Dask ``Client`` holds live network resources (sockets, event loop)
    and cannot be pickled, so it must not appear in the state dict that
    ``__getstate__`` hands to pickle.

    Returns
    -------
    out : dict
        A deep copy of ``self.__dict__`` with the ``_client`` entry removed.
    """
    # Detach the client first so deepcopy never touches it.
    # ``pop`` with a default keeps this safe if ``_client`` was never set.
    client = self.__dict__.pop("_client", None)
    out = copy.deepcopy(self.__dict__)
    # Restore via set_params (not plain attribute assignment) so the
    # instance is left exactly as the constructor/set_params path leaves it;
    # the live object keeps working after being pickled.
    self.set_params(client=client)
    return out

def _fit(
self,
model_factory: Type[LGBMModel],
Expand Down Expand Up @@ -508,7 +515,7 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
min_split_gain=0., min_child_weight=1e-3, min_child_samples=20,
subsample=1., subsample_freq=0, colsample_bytree=1.,
reg_alpha=0., reg_lambda=0., random_state=None,
n_jobs=-1, silent=True, importance_type='split', **kwargs):
n_jobs=-1, silent=True, importance_type='split', client=None, **kwargs):
super().__init__(
boosting_type=boosting_type,
num_leaves=num_leaves,
Expand All @@ -532,6 +539,7 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
importance_type=importance_type,
**kwargs
)
self.set_params(client=client)

_base_doc = LGBMClassifier.__init__.__doc__
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
Expand All @@ -542,6 +550,9 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
+ ' ' * 8 + _kwargs + _after_kwargs
)

def __getstate__(self) -> Dict[Any, Any]:
    """Return picklable state, excluding the un-picklable Dask client."""
    return self._lgb_getstate()

def fit(
self,
X: _DaskMatrixLike,
Expand Down Expand Up @@ -603,7 +614,7 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
min_split_gain=0., min_child_weight=1e-3, min_child_samples=20,
subsample=1., subsample_freq=0, colsample_bytree=1.,
reg_alpha=0., reg_lambda=0., random_state=None,
n_jobs=-1, silent=True, importance_type='split', **kwargs):
n_jobs=-1, silent=True, importance_type='split', client=None, **kwargs):
super().__init__(
boosting_type=boosting_type,
num_leaves=num_leaves,
Expand All @@ -627,6 +638,7 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
importance_type=importance_type,
**kwargs
)
self.set_params(client=client)

_base_doc = LGBMRegressor.__init__.__doc__
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
Expand All @@ -637,6 +649,9 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
+ ' ' * 8 + _kwargs + _after_kwargs
)

def __getstate__(self) -> Dict[Any, Any]:
    """Return picklable state, excluding the un-picklable Dask client."""
    return self._lgb_getstate()

def fit(
self,
X: _DaskMatrixLike,
Expand Down Expand Up @@ -687,7 +702,7 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
min_split_gain=0., min_child_weight=1e-3, min_child_samples=20,
subsample=1., subsample_freq=0, colsample_bytree=1.,
reg_alpha=0., reg_lambda=0., random_state=None,
n_jobs=-1, silent=True, importance_type='split', **kwargs):
n_jobs=-1, silent=True, importance_type='split', client=None, **kwargs):
super().__init__(
boosting_type=boosting_type,
num_leaves=num_leaves,
Expand All @@ -711,6 +726,7 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
importance_type=importance_type,
**kwargs
)
self.set_params(client=client)

_base_doc = LGBMRanker.__init__.__doc__
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
Expand All @@ -721,6 +737,9 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
+ ' ' * 8 + _kwargs + _after_kwargs
)

def __getstate__(self) -> Dict[Any, Any]:
    """Return picklable state, excluding the un-picklable Dask client."""
    return self._lgb_getstate()

def fit(
self,
X: _DaskMatrixLike,
Expand Down
13 changes: 2 additions & 11 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,6 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
if not SKLEARN_INSTALLED:
raise LightGBMError('scikit-learn is required for lightgbm.sklearn')

# Dask estimators inherit from this and may pass an argument "client"
self._client = kwargs.pop("client", None)

self.boosting_type = boosting_type
self.objective = objective
self.num_leaves = num_leaves
Expand Down Expand Up @@ -328,13 +325,6 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
self._n_classes = None
self.set_params(**kwargs)

def __getstate__(self):
"""Remove un-picklable attributes before serialization."""
client = self.__dict__.pop("_client", None)
out = copy.deepcopy(self.__dict__)
self._client = client
return out

def _more_tags(self):
return {
'allow_nan': True,
Expand Down Expand Up @@ -382,7 +372,8 @@ def set_params(self, **params):
setattr(self, key, value)
if hasattr(self, '_' + key):
setattr(self, '_' + key, value)
self._other_params[key] = value
if key != "client":
self._other_params[key] = value
return self

def fit(self, X, y,
Expand Down

0 comments on commit 344376b

Please sign in to comment.