[dask] remove 'client' kwarg from fit() and predict() (fixes #3808) #3883
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -434,6 +434,25 @@ def _predict( | |
|
||
class _DaskLGBMModel: | ||
|
||
# self._client is set in the constructor of lightgbm.sklearn.LGBMModel | ||
_client: Optional[Client] = None | ||
|
||
@property | ||
def client(self) -> Client: | ||
"""Dask client. | ||
|
||
This property can be passed in the constructor or directly assigned | ||
like ``model.client = client``. | ||
""" | ||
if self._client is None: | ||
return default_client() | ||
else: | ||
return self._client | ||
|
||
@client.setter | ||
def client(self, client: Client) -> None: | ||
self._client = client | ||
|
||
def _fit( | ||
self, | ||
model_factory: Type[LGBMModel], | ||
|
@@ -446,13 +465,11 @@ def _fit( | |
) -> "_DaskLGBMModel": | ||
if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)): | ||
raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask') | ||
if client is None: | ||
client = default_client() | ||
|
||
params = self.get_params(True) | ||
|
||
model = _train( | ||
client=client, | ||
client=self.client, | ||
data=X, | ||
label=y, | ||
params=params, | ||
|
@@ -478,18 +495,58 @@ def _copy_extra_params(source: Union["_DaskLGBMModel", LGBMModel], dest: Union[" | |
attributes = source.__dict__ | ||
extra_param_names = set(attributes.keys()).difference(params.keys()) | ||
for name in extra_param_names: | ||
setattr(dest, name, attributes[name]) | ||
if name != "_client": | ||
setattr(dest, name, attributes[name]) | ||
|
||
|
||
class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): | ||
"""Distributed version of lightgbm.LGBMClassifier.""" | ||
|
||
def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, | ||
This seems to have the desired effect! When I built the docs locally, I saw that
`DaskLGBMClassifier.__init__()`
`LGBMClassifier.__init__()`

Why I think copying is the best alternative

I tried several other ways to update the docstrings, but none of them quite worked. I'll explain in terms of

class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
    ...

_base_doc = LGBMClassifier.__init__.__doc__
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
DaskLGBMClassifier.__init__.__doc__ = (
    _before_kwargs
    + 'client : dask.distributed.Client or None, optional (default=None)\n'
    + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n'
    + ' ' * 8 + _kwargs + _after_kwargs
)

Doing this for each of the 3 model objects, the API docs show 3 copies of the doc for

class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    _base_doc = __init__.__doc__
    _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
    __init__.__doc__ = (
        _before_kwargs
        + 'client : dask.distributed.Client or None, optional (default=None)\n'
        + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n'
        + ' ' * 8 + _kwargs + _after_kwargs
    )

This results in this error at runtime.

Both of these variations:

class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
    __init__ = LGBMClassifier.__init__

    _base_doc = __init__.__doc__
    _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
    __init__.__doc__ = (
        _before_kwargs
        + 'client : dask.distributed.Client or None, optional (default=None)\n'
        + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n'
        + ' ' * 8 + _kwargs + _after_kwargs
    )

Doing this for each of the 3 model objects, the API docs show 3 copies of the doc for
||
learning_rate=0.1, n_estimators=100, | ||
subsample_for_bin=200000, objective=None, class_weight=None, | ||
min_split_gain=0., min_child_weight=1e-3, min_child_samples=20, | ||
subsample=1., subsample_freq=0, colsample_bytree=1., | ||
reg_alpha=0., reg_lambda=0., random_state=None, | ||
n_jobs=-1, silent=True, importance_type='split', **kwargs): | ||
super().__init__( | ||
boosting_type=boosting_type, | ||
num_leaves=num_leaves, | ||
max_depth=max_depth, | ||
learning_rate=learning_rate, | ||
n_estimators=n_estimators, | ||
subsample_for_bin=subsample_for_bin, | ||
objective=objective, | ||
class_weight=class_weight, | ||
min_split_gain=min_split_gain, | ||
min_child_weight=min_child_weight, | ||
min_child_samples=min_child_samples, | ||
subsample=subsample, | ||
subsample_freq=subsample_freq, | ||
colsample_bytree=colsample_bytree, | ||
reg_alpha=reg_alpha, | ||
reg_lambda=reg_lambda, | ||
random_state=random_state, | ||
n_jobs=n_jobs, | ||
silent=silent, | ||
importance_type=importance_type, | ||
**kwargs | ||
) | ||
|
||
_base_doc = LGBMClassifier.__init__.__doc__ | ||
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') | ||
__init__.__doc__ = ( | ||
_before_kwargs | ||
+ 'client : dask.distributed.Client or None, optional (default=None)\n' | ||
+ ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n' | ||
TBH, it was not clear enough that any client will not be saved. "This" may refer to

Ok I can change it

updated in 555a57a to this
|
||
+ ' ' * 8 + _kwargs + _after_kwargs | ||
) | ||
|
||
def fit( | ||
self, | ||
X: _DaskMatrixLike, | ||
y: _DaskCollection, | ||
sample_weight: Optional[_DaskCollection] = None, | ||
client: Optional[Client] = None, | ||
**kwargs: Any | ||
) -> "DaskLGBMClassifier": | ||
"""Docstring is inherited from the lightgbm.LGBMClassifier.fit.""" | ||
|
@@ -498,16 +555,11 @@ def fit( | |
X=X, | ||
y=y, | ||
sample_weight=sample_weight, | ||
client=client, | ||
client=self.client, | ||
**kwargs | ||
) | ||
|
||
_base_doc = LGBMClassifier.fit.__doc__ | ||
_before_init_score, _init_score, _after_init_score = _base_doc.partition('init_score :') | ||
fit.__doc__ = (_before_init_score | ||
+ 'client : dask.distributed.Client or None, optional (default=None)\n' | ||
+ ' ' * 12 + 'Dask client.\n' | ||
+ ' ' * 8 + _init_score + _after_init_score) | ||
fit.__doc__ = LGBMClassifier.fit.__doc__ | ||
|
||
def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array: | ||
"""Docstring is inherited from the lightgbm.LGBMClassifier.predict.""" | ||
|
@@ -545,6 +597,46 @@ def to_local(self) -> LGBMClassifier: | |
class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): | ||
"""Distributed version of lightgbm.LGBMRegressor.""" | ||
|
||
def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, | ||
learning_rate=0.1, n_estimators=100, | ||
subsample_for_bin=200000, objective=None, class_weight=None, | ||
min_split_gain=0., min_child_weight=1e-3, min_child_samples=20, | ||
subsample=1., subsample_freq=0, colsample_bytree=1., | ||
reg_alpha=0., reg_lambda=0., random_state=None, | ||
n_jobs=-1, silent=True, importance_type='split', **kwargs): | ||
super().__init__( | ||
boosting_type=boosting_type, | ||
num_leaves=num_leaves, | ||
max_depth=max_depth, | ||
learning_rate=learning_rate, | ||
n_estimators=n_estimators, | ||
subsample_for_bin=subsample_for_bin, | ||
objective=objective, | ||
class_weight=class_weight, | ||
min_split_gain=min_split_gain, | ||
min_child_weight=min_child_weight, | ||
min_child_samples=min_child_samples, | ||
subsample=subsample, | ||
subsample_freq=subsample_freq, | ||
colsample_bytree=colsample_bytree, | ||
reg_alpha=reg_alpha, | ||
reg_lambda=reg_lambda, | ||
random_state=random_state, | ||
n_jobs=n_jobs, | ||
silent=silent, | ||
importance_type=importance_type, | ||
**kwargs | ||
) | ||
|
||
_base_doc = LGBMRegressor.__init__.__doc__ | ||
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') | ||
__init__.__doc__ = ( | ||
_before_kwargs | ||
+ 'client : dask.distributed.Client or None, optional (default=None)\n' | ||
+ ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n' | ||
+ ' ' * 8 + _kwargs + _after_kwargs | ||
) | ||
|
||
def fit( | ||
self, | ||
X: _DaskMatrixLike, | ||
|
@@ -563,12 +655,7 @@ def fit( | |
**kwargs | ||
) | ||
|
||
_base_doc = LGBMRegressor.fit.__doc__ | ||
_before_init_score, _init_score, _after_init_score = _base_doc.partition('init_score :') | ||
fit.__doc__ = (_before_init_score | ||
+ 'client : dask.distributed.Client or None, optional (default=None)\n' | ||
+ ' ' * 12 + 'Dask client.\n' | ||
+ ' ' * 8 + _init_score + _after_init_score) | ||
fit.__doc__ = LGBMRegressor.fit.__doc__ | ||
|
||
def predict(self, X: _DaskMatrixLike, **kwargs) -> dask_Array: | ||
"""Docstring is inherited from the lightgbm.LGBMRegressor.predict.""" | ||
|
@@ -594,14 +681,53 @@ def to_local(self) -> LGBMRegressor: | |
class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): | ||
"""Distributed version of lightgbm.LGBMRanker.""" | ||
|
||
def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, | ||
learning_rate=0.1, n_estimators=100, | ||
subsample_for_bin=200000, objective=None, class_weight=None, | ||
min_split_gain=0., min_child_weight=1e-3, min_child_samples=20, | ||
subsample=1., subsample_freq=0, colsample_bytree=1., | ||
reg_alpha=0., reg_lambda=0., random_state=None, | ||
n_jobs=-1, silent=True, importance_type='split', **kwargs): | ||
super().__init__( | ||
boosting_type=boosting_type, | ||
num_leaves=num_leaves, | ||
max_depth=max_depth, | ||
learning_rate=learning_rate, | ||
n_estimators=n_estimators, | ||
subsample_for_bin=subsample_for_bin, | ||
objective=objective, | ||
class_weight=class_weight, | ||
min_split_gain=min_split_gain, | ||
min_child_weight=min_child_weight, | ||
min_child_samples=min_child_samples, | ||
subsample=subsample, | ||
subsample_freq=subsample_freq, | ||
colsample_bytree=colsample_bytree, | ||
reg_alpha=reg_alpha, | ||
reg_lambda=reg_lambda, | ||
random_state=random_state, | ||
n_jobs=n_jobs, | ||
silent=silent, | ||
importance_type=importance_type, | ||
**kwargs | ||
) | ||
|
||
_base_doc = LGBMRanker.__init__.__doc__ | ||
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') | ||
__init__.__doc__ = ( | ||
_before_kwargs | ||
+ 'client : dask.distributed.Client or None, optional (default=None)\n' | ||
+ ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n' | ||
+ ' ' * 8 + _kwargs + _after_kwargs | ||
) | ||
|
||
def fit( | ||
self, | ||
X: _DaskMatrixLike, | ||
y: _DaskCollection, | ||
sample_weight: Optional[_DaskCollection] = None, | ||
init_score: Optional[_DaskCollection] = None, | ||
group: Optional[_DaskCollection] = None, | ||
client: Optional[Client] = None, | ||
**kwargs: Any | ||
) -> "DaskLGBMRanker": | ||
"""Docstring is inherited from the lightgbm.LGBMRanker.fit.""" | ||
|
@@ -614,16 +740,11 @@ def fit( | |
y=y, | ||
sample_weight=sample_weight, | ||
group=group, | ||
client=client, | ||
client=self.client, | ||
**kwargs | ||
) | ||
|
||
_base_doc = LGBMRanker.fit.__doc__ | ||
_before_eval_set, _eval_set, _after_eval_set = _base_doc.partition('eval_set :') | ||
fit.__doc__ = (_before_eval_set | ||
+ 'client : dask.distributed.Client or None, optional (default=None)\n' | ||
+ ' ' * 12 + 'Dask client.\n' | ||
+ ' ' * 8 + _eval_set + _after_eval_set) | ||
fit.__doc__ = LGBMRanker.fit.__doc__ | ||
|
||
def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array: | ||
"""Docstring is inherited from the lightgbm.LGBMRanker.predict.""" | ||
|
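The docstring approach that the diff above settles on (define `__init__` on the subclass itself, then splice a new parameter entry in front of `**kwargs`) can be tried in isolation. The following is a minimal sketch, not LightGBM code: `Base`, `Child`, and the parameter text are hypothetical stand-ins. Because `Child` has its own `__init__` function object, editing its `__doc__` leaves the parent's docstring untouched.

```python
class Base:
    def __init__(self, alpha=1.0, **kwargs):
        """Construct a model.

        Parameters
        ----------
        alpha : float, optional (default=1.0)
            Example parameter.
        **kwargs
            Other parameters.
        """
        self.alpha = alpha
        self.kwargs = kwargs


class Child(Base):
    # Defining __init__ here gives Child its own function object,
    # so its docstring can be edited without touching Base.__init__.
    def __init__(self, alpha=1.0, client=None, **kwargs):
        self.client = client
        super().__init__(alpha=alpha, **kwargs)

    # Split the parent docstring at the '**kwargs' entry and insert
    # a new parameter description just before it.
    _base_doc = Base.__init__.__doc__
    _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
    __init__.__doc__ = (
        _before_kwargs
        + 'client : object or None, optional (default=None)\n'
        + ' ' * 12 + 'Hypothetical extra parameter.\n'
        + ' ' * 8 + _kwargs + _after_kwargs
    )
```

If the subclass merely inherited `__init__`, the same assignment would modify the shared parent function, which is consistent with the duplicated entries described in the review comment.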
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -291,6 +291,9 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, | |
if not SKLEARN_INSTALLED: | ||
raise LightGBMError('scikit-learn is required for lightgbm.sklearn') | ||
|
||
# Dask estimators inherit from this and may pass an argument "client" | ||
self._client = kwargs.pop("client", None) | ||
|
||
self.boosting_type = boosting_type | ||
self.objective = objective | ||
self.num_leaves = num_leaves | ||
|
@@ -325,6 +328,13 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, | |
self._n_classes = None | ||
self.set_params(**kwargs) | ||
|
||
def __getstate__(self): | ||
"""Remove un-picklable attributes before serialization.""" | ||
client = self.__dict__.pop("_client", None) | ||
out = copy.deepcopy(self.__dict__) | ||
self._client = client | ||
return out | ||
I believe this can be done in Dask estimators, not in parent class.

If we do it there, this code would have to be copied 3 times. Because

Yeah, probably
Just like we currently do for all other methods.

ok, good suggestion! did that in 344376b. I called it
I think I'll propose a PR to do that for other methods on the
||
|
||
def _more_tags(self): | ||
return { | ||
'allow_nan': True, | ||
|
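The `__getstate__` idea in the diff above (leave the client out of the serialized state) can be sketched without Dask. This is an illustration under assumptions, not the PR's implementation: `SketchModel` is hypothetical, and a `threading.Lock` stands in for a client object that cannot be pickled. Unlike the diff, this version builds a filtered dictionary instead of popping from and restoring `self.__dict__`.

```python
import copy
import threading


class SketchModel:
    # Class-level default, so a restored object still has a _client attribute.
    _client = None

    def __init__(self, n_estimators=100, client=None):
        self.n_estimators = n_estimators
        self._client = client

    def __getstate__(self):
        """Return a copy of the instance state without the client handle."""
        state = {k: v for k, v in self.__dict__.items() if k != "_client"}
        return copy.deepcopy(state)


# A lock is used only as a stand-in for an un-picklable client.
model = SketchModel(n_estimators=50, client=threading.Lock())
state = model.__getstate__()
```

`pickle.dumps(model)` and `copy.deepcopy(model)` both call `__getstate__`, so the restored object falls back to the class-level `_client = None` while the original instance keeps its client.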
sklearn requires that unfitted models raise `NotFittedError` and checks this by accessing "post-fitted" attributes: attributes with trailing underscores.

BTW, I think it would be great to set up sklearn integration tests in our CI for the Dask classes. They won't guarantee that our classes are fully compatible with sklearn, but they will at least check basic compatibility. WDYT?

LightGBM/tests/python_package_test/test_sklearn.py
Line 1169 in a4cae37

Oh I see, ok. Why should the check be on `._n_features`? That seems kind of indirect. Shouldn't it be on `.fitted_` directly for a check of whether or not the model has been fit?

I'm also confused how you would like `_fit()` to work. I'm currently passing `client=self.client_` into `self._fit()`, and relying on that to resolve whether there is a client stored on the object or `get_default_client()` should be used. This suggested change would break that behavior, because of course the model is probably not fitted yet at the time you call `.fit()`. What do you suggest I do for that?

Sure, I think that's fine. I think it's an ok fit for this PR and will add it here.

For historical reasons. `fitted_` was introduced quite recently in our sklearn wrapper. But now you can use it, indeed 👍.

Oh I see now! Can you pass `self.client` into `_fit()` and there assign the result of a check to `self._client`? While `client_` will always return `self._client` or raise an error? Just like all other properties in the sklearn wrapper.
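The pattern being discussed here (a trailing-underscore property that raises until the model is fitted) might look like the sketch below. Everything in it is a hypothetical stand-in: `NotFittedError` is defined locally rather than imported, and `SketchEstimator` is not a LightGBM class.

```python
class NotFittedError(ValueError):
    """Local stand-in for a not-fitted exception type."""


class SketchEstimator:
    def __init__(self, client=None):
        self.client = client      # constructor argument, stored as given
        self._client = None       # filled in by fit()
        self.fitted_ = False

    def fit(self):
        # A real estimator would train here; the sketch only records
        # which client was used and marks the model as fitted.
        self._client = self.client
        self.fitted_ = True
        return self

    @property
    def client_(self):
        # Post-fit attribute: only meaningful after fit() has run.
        if not self.fitted_:
            raise NotFittedError("Cannot access client_ before calling fit().")
        return self._client
```

Accessing `client_` before `fit()` raises, which is the behavior a scikit-learn style not-fitted check relies on.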
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe in a follow-up PR? I'm afraid that it can fail now and stop merging
client
argument migration.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I want to be really sure that there is exactly one place where resolve which client to use (#3883 (comment)).
So I don't want to put this code into the body of
_fit()
.I think it would work to pull that out into a small function like this:
Then use that in both the
client_
property (after the fitted check) and in_fit()
. That would give us confidence that accessing.client_
returns the same client that will be used in training.There was a problem hiding this comment.
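The snippet from this comment did not survive in the page text. As a rough sketch of the idea only (one function that decides which client to use), the following uses hypothetical names: `resolve_client` and `FakeClient` are illustrative, and the local `default_client()` is a stand-in for `distributed.default_client()` so the example runs without Dask.

```python
from typing import Optional


class FakeClient:
    """Stand-in for a Dask client, for illustration only."""

    def __init__(self, name: str) -> None:
        self.name = name


_DEFAULT = FakeClient("default")


def default_client() -> FakeClient:
    # Stand-in for distributed.default_client().
    return _DEFAULT


def resolve_client(client: Optional[FakeClient]) -> FakeClient:
    """Return the given client, or fall back to the default one."""
    if client is None:
        return default_client()
    return client
```

Calling the same function from both the training path and the `client_` property means the two cannot disagree about which client is in use.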
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, sounds good!
Added #3894 to capture the task of adding a scikit-learn compatibility test. I made this a "good first issue" because I think it could be done without deep knowledge of LightGBM, but I also think it's ok for you or me to pick up in the near future (we don't need to reserve it for new contributors).

Ok I think I've addressed these suggestions in 555a57a

- `LGBMNotFittedError` if accessing `.client_` for a not-yet-fitted model
- `_get_dask_cluster()` to hold the logic of using `default_client()` if `client` is None

Note that the change to have two clusters active in the same test will result in a lot of this warning in the logs:

This is not something we have to worry about, and I'd rather leave it alone and let Dask do the right thing (picking a random port for the scheduler when the default one is unavailable) than add more complexity to the tests by adding our own logic to set the scheduler port.