diff --git a/.ci/test.sh b/.ci/test.sh
index a150d33ffab0..da93d1b0f97f 100755
--- a/.ci/test.sh
+++ b/.ci/test.sh
@@ -90,7 +90,7 @@ if [[ $TASK == "swig" ]]; then
     exit 0
 fi
 
-conda install -q -y -n $CONDA_ENV dask distributed joblib matplotlib numpy pandas psutil pytest scikit-learn scipy
+conda install -q -y -n $CONDA_ENV cloudpickle dask distributed joblib matplotlib numpy pandas psutil pytest scikit-learn scipy
 
 # graphviz must come from conda-forge to avoid this on some linux distros:
 # https://github.com/conda-forge/graphviz-feedstock/issues/18
diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py
index 35e401ea538c..c65e14187567 100644
--- a/python-package/lightgbm/dask.py
+++ b/python-package/lightgbm/dask.py
@@ -442,7 +442,7 @@ def client(self) -> Client:
         """Dask client.
 
         This property can be passed in the constructor or directly assigned
-        like ``model.client = client``.
+        like ``model.set_params(client=client)``.
         """
         if self._client is None:
             return default_client()
@@ -453,6 +453,13 @@
     def client(self, client: Client) -> None:
         self._client = client
 
+    def _lgb_getstate(self) -> Dict[Any, Any]:
+        """Remove un-picklable attributes before serialization."""
+        client = self.__dict__.pop("_client", None)
+        out = copy.deepcopy(self.__dict__)
+        self.set_params(client=client)
+        return out
+
     def _fit(
         self,
         model_factory: Type[LGBMModel],
@@ -508,7 +515,7 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
                  min_split_gain=0., min_child_weight=1e-3, min_child_samples=20,
                  subsample=1., subsample_freq=0, colsample_bytree=1.,
                  reg_alpha=0., reg_lambda=0., random_state=None,
-                 n_jobs=-1, silent=True, importance_type='split', **kwargs):
+                 n_jobs=-1, silent=True, importance_type='split', client=None, **kwargs):
         super().__init__(
             boosting_type=boosting_type,
             num_leaves=num_leaves,
@@ -532,6 +539,7 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
             importance_type=importance_type,
             **kwargs
         )
+        self.set_params(client=client)
 
         _base_doc = LGBMClassifier.__init__.__doc__
         _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
@@ -542,6 +550,9 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
             + ' ' * 8 + _kwargs + _after_kwargs
         )
 
+    def __getstate__(self) -> Dict[Any, Any]:
+        return self._lgb_getstate()
+
     def fit(
         self,
         X: _DaskMatrixLike,
@@ -603,7 +614,7 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
                  min_split_gain=0., min_child_weight=1e-3, min_child_samples=20,
                  subsample=1., subsample_freq=0, colsample_bytree=1.,
                  reg_alpha=0., reg_lambda=0., random_state=None,
-                 n_jobs=-1, silent=True, importance_type='split', **kwargs):
+                 n_jobs=-1, silent=True, importance_type='split', client=None, **kwargs):
         super().__init__(
             boosting_type=boosting_type,
             num_leaves=num_leaves,
@@ -627,6 +638,7 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
             importance_type=importance_type,
             **kwargs
         )
+        self.set_params(client=client)
 
         _base_doc = LGBMRegressor.__init__.__doc__
         _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
@@ -637,6 +649,9 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
             + ' ' * 8 + _kwargs + _after_kwargs
         )
 
+    def __getstate__(self) -> Dict[Any, Any]:
+        return self._lgb_getstate()
+
     def fit(
         self,
         X: _DaskMatrixLike,
@@ -687,7 +702,7 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
                  min_split_gain=0., min_child_weight=1e-3, min_child_samples=20,
                  subsample=1., subsample_freq=0, colsample_bytree=1.,
                  reg_alpha=0., reg_lambda=0., random_state=None,
-                 n_jobs=-1, silent=True, importance_type='split', **kwargs):
+                 n_jobs=-1, silent=True, importance_type='split', client=None, **kwargs):
         super().__init__(
             boosting_type=boosting_type,
             num_leaves=num_leaves,
@@ -711,6 +726,7 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
             importance_type=importance_type,
             **kwargs
         )
+        self.set_params(client=client)
 
         _base_doc = LGBMRanker.__init__.__doc__
         _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
@@ -721,6 +737,9 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
             + ' ' * 8 + _kwargs + _after_kwargs
         )
 
+    def __getstate__(self) -> Dict[Any, Any]:
+        return self._lgb_getstate()
+
     def fit(
         self,
         X: _DaskMatrixLike,
diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py
index 50e2c0d8c085..26c28a458512 100644
--- a/python-package/lightgbm/sklearn.py
+++ b/python-package/lightgbm/sklearn.py
@@ -291,9 +291,6 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
         if not SKLEARN_INSTALLED:
             raise LightGBMError('scikit-learn is required for lightgbm.sklearn')
 
-        # Dask estimators inherit from this and may pass an argument "client"
-        self._client = kwargs.pop("client", None)
-
         self.boosting_type = boosting_type
         self.objective = objective
         self.num_leaves = num_leaves
@@ -328,13 +325,6 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
         self._n_classes = None
         self.set_params(**kwargs)
 
-    def __getstate__(self):
-        """Remove un-picklable attributes before serialization."""
-        client = self.__dict__.pop("_client", None)
-        out = copy.deepcopy(self.__dict__)
-        self._client = client
-        return out
-
     def _more_tags(self):
         return {
             'allow_nan': True,
@@ -382,7 +372,8 @@ def set_params(self, **params):
             setattr(self, key, value)
             if hasattr(self, '_' + key):
                 setattr(self, '_' + key, value)
-            self._other_params[key] = value
+            if key != "client":
+                self._other_params[key] = value
         return self
 
     def fit(self, X, y,