[dask] remove 'client' kwarg from fit() and predict() (fixes #3808) #3883
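In short, after this change the Dask client is supplied through the estimator constructor (or later via `set_params(client=...)`) instead of being passed to `fit()` or `predict()`. A minimal usage sketch (illustrative only; the cluster and data setup are made up and not part of the diff):

    import dask.array as da
    from distributed import Client, LocalCluster
    from lightgbm.dask import DaskLGBMClassifier

    client = Client(LocalCluster(n_workers=2))

    X = da.random.random((1_000, 20), chunks=(100, 20))
    y = (da.random.random((1_000,), chunks=(100,)) > 0.5).astype(int)

    clf = DaskLGBMClassifier(client=client, n_estimators=10)
    clf.fit(X, y)           # no client= argument anymore
    preds = clf.predict(X)  # predict() likewise takes no client kwarg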
@@ -9,7 +9,7 @@
import socket
from collections import defaultdict
from copy import deepcopy
from typing import Any, Dict, Iterable, List, Optional, Type, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
from urllib.parse import urlparse

import numpy as np

@@ -434,25 +434,49 @@ def _predict(

class _DaskLGBMModel:

    # self._client is set in the constructor of classes that use this mixin
    _client: Optional[Client] = None

    @property
    def client_(self) -> Client:
        """Dask client.

        This property can be passed in the constructor or updated
        with ``model.set_params(client=client)``.
        """
        if self._client is None:
sklearn requires that unfitted models raise ...

Oh I see, ok. Why should the check be on ... I'm also confused how you would like ...

Sure, I think that's fine. I think it's an ok fit for this PR and will add it here.

For historical reasons.

Oh I see now! Can you pass ...

Maybe in a follow-up PR? I'm afraid that it could fail now and block merging.

I want to be really sure that there is exactly one place where we resolve which client to use (#3883 (comment)). So I don't want to put this code into the body of ...:

    if self.client is None:
        client = default_client()
    else:
        client = self.client

I think it would work to pull that out into a small function like this:

    def _choose_client(client: Optional[Client]):
        if client is None:
            return default_client()
        else:
            return client

Then use that in both the ...

OK, sounds good!

Added #3894 to capture the task of adding a scikit-learn compatibility test. I made this a "good first issue" because I think it could be done without deep knowledge of LightGBM, but I also think it's ok for you or me to pick it up in the near future (we don't need to reserve it for new contributors).

Ok, I think I've addressed these suggestions in 555a57a.

Note that the change to have two clusters active in the same test will result in a lot of warnings in the logs about the default scheduler port being unavailable. This is not something we have to worry about, and I'd rather leave it alone and let Dask do the right thing (picking a random port for the scheduler when the default one is unavailable) than add more complexity to the tests by adding our own logic to set the scheduler port.
            return default_client()
        else:
            return self._client

    def _lgb_getstate(self) -> Dict[Any, Any]:
        """Remove un-picklable attributes before serialization."""
        client = self.__dict__.pop("client", None)
        self.__dict__.pop("_client", None)
        self._other_params.pop("client", None)
        out = deepcopy(self.__dict__)
        out.update({"_client": None, "client": None})
        self._client = client
        self.client = client
        return out

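As an aside, a minimal usage sketch (not part of this diff; the cluster and data setup are made up for illustration) of the serialization behavior that `_lgb_getstate()` implements: the client is dropped from the pickled state and resolved again afterwards, either via `distributed.default_client()` or by attaching a new client with `set_params`.

    import pickle

    import dask.array as da
    from distributed import Client, LocalCluster
    from lightgbm.dask import DaskLGBMRegressor

    cluster = LocalCluster(n_workers=2)
    client = Client(cluster)

    X = da.random.random((1_000, 10), chunks=(100, 10))
    y = da.random.random((1_000,), chunks=(100,))

    model = DaskLGBMRegressor(client=client, n_estimators=10)
    model.fit(X, y)

    # __getstate__() -> _lgb_getstate() replaces the client entries with None,
    # so the pickled payload carries no live connection.
    restored = pickle.loads(pickle.dumps(model))

    # After loading, client_ falls back to distributed.default_client(),
    # or a new client can be attached explicitly:
    restored.set_params(client=client)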
    def _fit(
        self,
        model_factory: Type[LGBMModel],
        X: _DaskMatrixLike,
        y: _DaskCollection,
        sample_weight: Optional[_DaskCollection] = None,
        group: Optional[_DaskCollection] = None,
        client: Optional[Client] = None,
        **kwargs: Any
    ) -> "_DaskLGBMModel":
        if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
            raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask')
        if client is None:
            client = default_client()

        params = self.get_params(True)
        params.pop("client", None)

        model = _train(
            client=client,
            client=self.client_,
            data=X,
            label=y,
            params=params,

@@ -468,8 +492,11 @@ def _fit(
        return self

    def _to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel:
        model = model_factory(**self.get_params())
        params = self.get_params()
        params.pop("client", None)
        model = model_factory(**params)
        self._copy_extra_params(self, model)
        model._other_params.pop("client", None)
        return model

    @staticmethod

@@ -478,18 +505,82 @@ def _copy_extra_params(source: Union["_DaskLGBMModel", LGBMModel], dest: Union["
        attributes = source.__dict__
        extra_param_names = set(attributes.keys()).difference(params.keys())
        for name in extra_param_names:
            setattr(dest, name, attributes[name])
            if name != "_client":
                setattr(dest, name, attributes[name])


class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
    """Distributed version of lightgbm.LGBMClassifier."""

    def __init__(
        self,
        boosting_type: str = 'gbdt',
        num_leaves: int = 31,
        max_depth: int = -1,
        learning_rate: float = 0.1,
        n_estimators: int = 100,
        subsample_for_bin: int = 200000,
        objective: Optional[Union[Callable, str]] = None,
        class_weight: Optional[Union[dict, str]] = None,
        min_split_gain: float = 0.,
        min_child_weight: float = 1e-3,
        min_child_samples: int = 20,
        subsample: float = 1.,
        subsample_freq: int = 0,
        colsample_bytree: float = 1.,
        reg_alpha: float = 0.,
        reg_lambda: float = 0.,
        random_state: Optional[Union[int, np.random.RandomState]] = None,
        n_jobs: int = -1,
        silent: bool = True,
        importance_type: str = 'split',
        client: Optional[Client] = None,
        **kwargs: Any
    ):
        """Docstring is inherited from the lightgbm.LGBMClassifier.__init__."""
        self._client = client
        self.client = client
        super().__init__(
            boosting_type=boosting_type,
            num_leaves=num_leaves,
            max_depth=max_depth,
            learning_rate=learning_rate,
            n_estimators=n_estimators,
            subsample_for_bin=subsample_for_bin,
            objective=objective,
            class_weight=class_weight,
            min_split_gain=min_split_gain,
            min_child_weight=min_child_weight,
            min_child_samples=min_child_samples,
            subsample=subsample,
            subsample_freq=subsample_freq,
            colsample_bytree=colsample_bytree,
            reg_alpha=reg_alpha,
            reg_lambda=reg_lambda,
            random_state=random_state,
            n_jobs=n_jobs,
            silent=silent,
            importance_type=importance_type,
            **kwargs
        )

    _base_doc = LGBMClassifier.__init__.__doc__
    _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
    __init__.__doc__ = (
        _before_kwargs
        + 'client : dask.distributed.Client or None, optional (default=None)\n'
        + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n'
TBH, it was not clear enough that any client will not be saved. "This" may refer to ...

Ok, I can change it.

Updated in 555a57a to this ...

        + ' ' * 8 + _kwargs + _after_kwargs
    )

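As a side note, a small self-contained sketch (illustrative only; the docstring text below is made up) of the splicing technique used above: `str.partition('**kwargs')` splits the inherited docstring so the extra `client` parameter description can be inserted just before the `**kwargs` entry.

    _inherited_doc = (
        "Parameters\n"
        "----------\n"
        "boosting_type : str, optional (default='gbdt')\n"
        "    ...\n"
        "**kwargs\n"
        "    Other parameters for the model.\n"
    )

    # Split once at '**kwargs' and rebuild the docstring with the extra entry.
    _before, _kwargs, _after = _inherited_doc.partition('**kwargs')
    _patched_doc = (
        _before
        + 'client : dask.distributed.Client or None, optional (default=None)\n'
        + '    Dask client.\n'
        + _kwargs
        + _after
    )
    print(_patched_doc)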
    def __getstate__(self) -> Dict[Any, Any]:
        return self._lgb_getstate()

    def fit(
        self,
        X: _DaskMatrixLike,
        y: _DaskCollection,
        sample_weight: Optional[_DaskCollection] = None,
        client: Optional[Client] = None,
        **kwargs: Any
    ) -> "DaskLGBMClassifier":
        """Docstring is inherited from the lightgbm.LGBMClassifier.fit."""

@@ -498,16 +589,10 @@ def fit(
            X=X,
            y=y,
            sample_weight=sample_weight,
            client=client,
            **kwargs
        )

    _base_doc = LGBMClassifier.fit.__doc__
    _before_init_score, _init_score, _after_init_score = _base_doc.partition('init_score :')
    fit.__doc__ = (_before_init_score
                   + 'client : dask.distributed.Client or None, optional (default=None)\n'
                   + ' ' * 12 + 'Dask client.\n'
                   + ' ' * 8 + _init_score + _after_init_score)
    fit.__doc__ = LGBMClassifier.fit.__doc__

    def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
        """Docstring is inherited from the lightgbm.LGBMClassifier.predict."""

@@ -545,6 +630,70 @@ def to_local(self) -> LGBMClassifier:
class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
    """Distributed version of lightgbm.LGBMRegressor."""

    def __init__(
        self,
        boosting_type: str = 'gbdt',
        num_leaves: int = 31,
        max_depth: int = -1,
        learning_rate: float = 0.1,
        n_estimators: int = 100,
        subsample_for_bin: int = 200000,
        objective: Optional[Union[Callable, str]] = None,
        class_weight: Optional[Union[dict, str]] = None,
        min_split_gain: float = 0.,
        min_child_weight: float = 1e-3,
        min_child_samples: int = 20,
        subsample: float = 1.,
        subsample_freq: int = 0,
        colsample_bytree: float = 1.,
        reg_alpha: float = 0.,
        reg_lambda: float = 0.,
        random_state: Optional[Union[int, np.random.RandomState]] = None,
        n_jobs: int = -1,
        silent: bool = True,
        importance_type: str = 'split',
        client: Optional[Client] = None,
        **kwargs: Any
    ):
        """Docstring is inherited from the lightgbm.LGBMRegressor.__init__."""
        self._client = client
        self.client = client
        super().__init__(
            boosting_type=boosting_type,
            num_leaves=num_leaves,
            max_depth=max_depth,
            learning_rate=learning_rate,
            n_estimators=n_estimators,
            subsample_for_bin=subsample_for_bin,
            objective=objective,
            class_weight=class_weight,
            min_split_gain=min_split_gain,
            min_child_weight=min_child_weight,
            min_child_samples=min_child_samples,
            subsample=subsample,
            subsample_freq=subsample_freq,
            colsample_bytree=colsample_bytree,
            reg_alpha=reg_alpha,
            reg_lambda=reg_lambda,
            random_state=random_state,
            n_jobs=n_jobs,
            silent=silent,
            importance_type=importance_type,
            **kwargs
        )

    _base_doc = LGBMRegressor.__init__.__doc__
    _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
    __init__.__doc__ = (
        _before_kwargs
        + 'client : dask.distributed.Client or None, optional (default=None)\n'
        + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n'
        + ' ' * 8 + _kwargs + _after_kwargs
    )

    def __getstate__(self) -> Dict[Any, Any]:
        return self._lgb_getstate()

    def fit(
        self,
        X: _DaskMatrixLike,

@@ -559,16 +708,10 @@ def fit(
            X=X,
            y=y,
            sample_weight=sample_weight,
            client=client,
            **kwargs
        )

    _base_doc = LGBMRegressor.fit.__doc__
    _before_init_score, _init_score, _after_init_score = _base_doc.partition('init_score :')
    fit.__doc__ = (_before_init_score
                   + 'client : dask.distributed.Client or None, optional (default=None)\n'
                   + ' ' * 12 + 'Dask client.\n'
                   + ' ' * 8 + _init_score + _after_init_score)
    fit.__doc__ = LGBMRegressor.fit.__doc__

    def predict(self, X: _DaskMatrixLike, **kwargs) -> dask_Array:
        """Docstring is inherited from the lightgbm.LGBMRegressor.predict."""

@@ -594,14 +737,77 @@ def to_local(self) -> LGBMRegressor:
class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
    """Distributed version of lightgbm.LGBMRanker."""

    def __init__(
        self,
        boosting_type: str = 'gbdt',
        num_leaves: int = 31,
        max_depth: int = -1,
        learning_rate: float = 0.1,
        n_estimators: int = 100,
        subsample_for_bin: int = 200000,
        objective: Optional[Union[Callable, str]] = None,
        class_weight: Optional[Union[dict, str]] = None,
        min_split_gain: float = 0.,
        min_child_weight: float = 1e-3,
        min_child_samples: int = 20,
        subsample: float = 1.,
        subsample_freq: int = 0,
        colsample_bytree: float = 1.,
        reg_alpha: float = 0.,
        reg_lambda: float = 0.,
        random_state: Optional[Union[int, np.random.RandomState]] = None,
        n_jobs: int = -1,
        silent: bool = True,
        importance_type: str = 'split',
        client: Optional[Client] = None,
        **kwargs: Any
    ):
        """Docstring is inherited from the lightgbm.LGBMRanker.__init__."""
        self._client = client
        self.client = client
        super().__init__(
            boosting_type=boosting_type,
            num_leaves=num_leaves,
            max_depth=max_depth,
            learning_rate=learning_rate,
            n_estimators=n_estimators,
            subsample_for_bin=subsample_for_bin,
            objective=objective,
            class_weight=class_weight,
            min_split_gain=min_split_gain,
            min_child_weight=min_child_weight,
            min_child_samples=min_child_samples,
            subsample=subsample,
            subsample_freq=subsample_freq,
            colsample_bytree=colsample_bytree,
            reg_alpha=reg_alpha,
            reg_lambda=reg_lambda,
            random_state=random_state,
            n_jobs=n_jobs,
            silent=silent,
            importance_type=importance_type,
            **kwargs
        )

    _base_doc = LGBMRanker.__init__.__doc__
    _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
    __init__.__doc__ = (
        _before_kwargs
        + 'client : dask.distributed.Client or None, optional (default=None)\n'
        + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n'
        + ' ' * 8 + _kwargs + _after_kwargs
    )

    def __getstate__(self) -> Dict[Any, Any]:
        return self._lgb_getstate()

    def fit(
        self,
        X: _DaskMatrixLike,
        y: _DaskCollection,
        sample_weight: Optional[_DaskCollection] = None,
        init_score: Optional[_DaskCollection] = None,
        group: Optional[_DaskCollection] = None,
        client: Optional[Client] = None,
        **kwargs: Any
    ) -> "DaskLGBMRanker":
        """Docstring is inherited from the lightgbm.LGBMRanker.fit."""

@@ -614,16 +820,10 @@ def fit(
            y=y,
            sample_weight=sample_weight,
            group=group,
            client=client,
            **kwargs
        )

    _base_doc = LGBMRanker.fit.__doc__
    _before_eval_set, _eval_set, _after_eval_set = _base_doc.partition('eval_set :')
    fit.__doc__ = (_before_eval_set
                   + 'client : dask.distributed.Client or None, optional (default=None)\n'
                   + ' ' * 12 + 'Dask client.\n'
                   + ' ' * 8 + _eval_set + _after_eval_set)
    fit.__doc__ = LGBMRanker.fit.__doc__

    def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
        """Docstring is inherited from the lightgbm.LGBMRanker.predict."""

Looks like `_client` was replaced with the `_get_dask_client()` function in the latest commit and is not needed anymore.

oh good point! Removed in 876bfe5
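For reference, a minimal sketch of what such a module-level helper could look like (an assumption about its shape, not the exact code from that later commit): it centralizes the "explicit client or `default_client()`" decision in one place, as discussed above.

    from typing import Optional

    from dask.distributed import Client, default_client


    def _get_dask_client(client: Optional[Client]) -> Client:
        """Return the given client, or fall back to distributed.default_client()."""
        if client is None:
            return default_client()
        else:
            return client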