[dask] remove 'client' kwarg from fit() and predict() (fixes #3808) #3883
Conversation
python-package/lightgbm/dask.py (outdated)

```python
class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
    """Distributed version of lightgbm.LGBMClassifier."""

    def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
```
This seems to have the desired effect! When I built the docs locally, I saw that the `__init__()` docs for `DaskLGBMClassifier`, `DaskLGBMRegressor`, and `DaskLGBMRanker` have the `client` doc added, and the docs for their scikit-learn equivalents do not.

(screenshots: the rendered docs for `DaskLGBMClassifier.__init__()` and `LGBMClassifier.__init__()`)
Why I think copying is the best alternative

I tried several other ways to update the docstrings, but none of them quite worked. I'll explain in terms of `*Classifier`, but this applies to `*Regressor` and `*Ranker` as well.
- Editing `DaskLGBMClassifier.__init__.__doc__`

```python
class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
    ...

_base_doc = LGBMClassifier.__init__.__doc__
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
DaskLGBMClassifier.__init__.__doc__ = (
    _before_kwargs
    + 'client : dask.distributed.Client or None, optional (default=None)\n'
    + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n'
    + ' ' * 8 + _kwargs + _after_kwargs
)
```
Doing this for each of the 3 model objects, the API docs show 3 copies of the doc for `client`. Those 3 copies also show up on the docs for `lightgbm.sklearn.LGBMClassifier`.
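Why the copies leak into the parent's docs can be seen with a small self-contained sketch (the toy names `Base` and `Child` are mine, not LightGBM's): assigning to `__doc__` through the subclass mutates the inherited function object, which the parent shares.

```python
class Base:
    def __init__(self, **kwargs):
        """Construct the model.

        **kwargs
            Other parameters.
        """

class Child(Base):
    pass  # inherits Base.__init__, like DaskLGBMClassifier inherits from LGBMClassifier

_before_kwargs, _kwargs, _after_kwargs = Base.__init__.__doc__.partition('**kwargs')
Child.__init__.__doc__ = (
    _before_kwargs
    + 'client : object or None, optional (default=None)\n    (hypothetical doc)\n'
    + _kwargs
    + _after_kwargs
)

# Child never defined its own __init__, so Child.__init__ IS Base.__init__,
# and the edit shows up on the parent's docs too:
print('client' in Base.__init__.__doc__)   # True
```

The same function-object sharing is why all three Dask classes' edits pile up on `LGBMClassifier`'s docstring.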
- Setting up a pass-through `__init__()`

```python
class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    _base_doc = __init__.__doc__
    _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
    __init__.__doc__ = (
        _before_kwargs
        + 'client : dask.distributed.Client or None, optional (default=None)\n'
        + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n'
        + ' ' * 8 + _kwargs + _after_kwargs
    )
```
This results in this error at runtime:

```
RuntimeError: scikit-learn estimators should always specify their parameters in
the signature of their __init__ (no varargs).
<class 'lightgbm.dask.DaskLGBMClassifier'> with constructor (self, *args, **kwargs)
doesn't follow this convention.
```
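scikit-learn raises that error from a signature check on `__init__`; a rough, stdlib-only sketch of the check (class and variable names here are mine, not sklearn's actual code) is:

```python
import inspect

class PassThrough:
    # mirrors the pass-through constructor that scikit-learn rejects
    def __init__(self, *args, **kwargs):
        pass

# roughly what sklearn does: inspect __init__ and refuse VAR_POSITIONAL params
init_signature = inspect.signature(PassThrough.__init__)
varargs = [
    p for p in init_signature.parameters.values()
    if p.kind == inspect.Parameter.VAR_POSITIONAL
]
if varargs:
    print("would raise RuntimeError: no varargs allowed in __init__")
```

sklearn needs the explicit parameter names in the signature to implement `get_params()`/`set_params()`, which is why `*args` cannot work.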
- Just copying `__init__`

Both of these variations:

```python
class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
    __init__ = LGBMClassifier.__init__

    _base_doc = __init__.__doc__
    _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
    __init__.__doc__ = (
        _before_kwargs
        + 'client : dask.distributed.Client or None, optional (default=None)\n'
        + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n'
        + ' ' * 8 + _kwargs + _after_kwargs
    )
```
Doing this for each of the 3 model objects, the API docs show 3 copies of the doc for `client`. Those 3 copies also show up on the docs for `lightgbm.sklearn.LGBMClassifier` (same as in the image above).
@jameslamb I'm afraid this is an antipattern in the scikit-learn world! Our classes will be incompatible with the scikit-learn/Dask ecosystem.

This is why you got:

```
RuntimeError: scikit-learn estimators should always specify their parameters in
the signature of their __init__ (no varargs).
```

I believe the best option will be passing `client` in `__init__` as a normal argument with `None` as the default value. I mean the following:
```python
class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
    def __init__(self, ..., client=None, **kwargs):
        self.client = client
        super().__init__(...)
```

Also, do not suggest users set `client` directly via attribute; `set_params(client=new_client)` should be used.
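The mechanics of that suggestion can be sketched without Dask or scikit-learn installed. `ToyDaskModel` below is a hypothetical stand-in whose `get_params()`/`set_params()` mimic what sklearn's `BaseEstimator` derives from the `__init__` signature; because `client` is an ordinary constructor argument, both methods see it:

```python
import inspect

class ToyDaskModel:
    # hypothetical toy class, not LightGBM code
    def __init__(self, num_leaves=31, client=None):
        self.num_leaves = num_leaves
        self.client = client

    def get_params(self):
        # sklearn derives parameter names from the __init__ signature
        names = inspect.signature(type(self).__init__).parameters
        return {n: getattr(self, n) for n in names if n != 'self'}

    def set_params(self, **params):
        for name, value in params.items():
            if name not in self.get_params():
                raise ValueError(f"unknown parameter {name!r}")
            setattr(self, name, value)
        return self

model = ToyDaskModel()
print(model.get_params())                 # {'num_leaves': 31, 'client': None}
model.set_params(client="some-client")    # update after construction
print(model.client)                       # some-client
```

This is why hyperparameter tooling (`clone()`, grid search) keeps working: every parameter, `client` included, is discoverable and settable through the same machinery.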
python-package/lightgbm/sklearn.py (outdated)

```python
def __getstate__(self):
    """Remove un-picklable attributes before serialization."""
    client = self.__dict__.pop("_client", None)
    out = copy.deepcopy(self.__dict__)
    self._client = client
    return out
```
I believe this can be done in the Dask estimators, not in the parent class.
If we do it there, this code would have to be copied 3 times. Because `_DaskLGBMModel` comes second in the MRO, it wouldn't be safe to just put this on that mixin. Are you ok with that duplication?
Yeah, probably `__getstate__()` is overridden somewhere above in the pure scikit-learn branch of the inheritance. But we can name this method `_getstate()` in `_DaskLGBMModel` and call it, so no duplication will be required:

```python
def __getstate__(self):
    return self._getstate()
```

Just like we currently do for all other methods.
ok, good suggestion! did that in 344376b.

I called it `_lgb_get_state` because I'm nervous about using general-sounding method names in `_DaskLGBMModel` when there are so many classes upstream of `LGBMClassifier`. I'm nervous about some version of scikit-learn introducing a `_get_state` that would override this. We might not catch that in tests because we don't test against many different versions of scikit-learn.

I think I'll propose a PR to do that for other methods on the `_DaskLGBMModel` mixin as well.
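A minimal sketch of the delegation (class names and the `_client` attribute here are illustrative, not the actual LightGBM code): the dunder stays on each concrete class, while the logic lives once on the mixin under a defensively chosen name.

```python
import pickle

class _DaskMixin:
    # one copy of the logic, named so upstream sklearn classes
    # are unlikely to ever define the same method
    def _lgb_get_state(self):
        state = self.__dict__.copy()
        state.pop('_client', None)   # drop the live (unpicklable) client
        return state

class SklearnModel:
    def __init__(self):
        self._client = object()      # stand-in for a dask.distributed.Client
        self.n_features = 5

class DaskModel(SklearnModel, _DaskMixin):
    def __getstate__(self):
        return self._lgb_get_state()

restored = pickle.loads(pickle.dumps(DaskModel()))
print(restored.n_features)           # 5
print(hasattr(restored, '_client'))  # False
```

`pickle` calls `__getstate__()` to get the instance state, so the client never reaches the byte stream, and the unpickled object simply has no `_client` attribute.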
I'm really struggling to adopt this pattern. https://github.com/scikit-learn/scikit-learn/blob/b3ea3ed6a/sklearn/base.py#L152

Probably because of `_train()` (python-package/lightgbm/dask.py, line 452 in 198a151). The client is in it and not pickleable.

I pushed 344376b so you can see what I tried. I understand that you said my original proposal is a scikit-learn antipattern, but given this, I really don't think the pattern I proposed violates the intention from scikit-learn.
Scikit-learn is a really complex system with many undocumented "features", rules, and assumptions about inherited classes. I guess you already know this. So it is very likely that violating some of their rules can break a lot of integrations. Of course, the inability to use hyperparameter search would be one of those breakages. Also, I'm strongly against interfering with `__getstate__()` in the parent class; I believe we can achieve the goal without it.
Don't worry about it; I know you've mentioned you can't test Dask features in your local dev environment. I'll continue down the path you've suggested and we can look at the diff at the end.
But I could use our Linux CI job 😉. Not as comfortable as my own local setup, but better than nothing. Only one thing is puzzling me.
Given this code (python-package/lightgbm/basic.py, lines 2284 to 2303 in fd33199) ...
Please don't spend time on it; I want to be respectful of your time, and I think I can go faster testing locally. For your question about pickling... it's acceptable and expected that the model loaded from a file doesn't have `.client` set. A client is a live connection to a specific Dask scheduler, and you have no guarantee that the exact same scheduler will still exist or be accessible from wherever you load the model. This has no impact on the correctness of the model.
Yeah, you'll definitely be faster implementing this, but I'd like to help. So I'm listing here some initial ideas I wanted to try first, for your consideration:
For the last two points in my list of initial ideas, I meant finding something close to the following:
@jameslamb Seems the pickling tests are still failing. Please check it.
Agree, makes sense! But please rename this property to follow sklearn's naming conventions: https://scikit-learn.org/stable/developers/develop.html#parameters-and-init

Please check the following example against these places in the sklearn wrapper:

- python-package/lightgbm/sklearn.py, lines 211 to 214 in 763b5f3
- python-package/lightgbm/sklearn.py, line 295 in 763b5f3
- python-package/lightgbm/sklearn.py, line 318 in 763b5f3
- python-package/lightgbm/sklearn.py, lines 494 to 502 in 763b5f3
- python-package/lightgbm/sklearn.py, lines 735 to 740 in 763b5f3
Ok I think I've made the requested changes in af269b5.
@jameslamb Awesome work! Thanks a lot for digging deep into sklearn internals and finding a way not to make a parent aware of its children. I left some comments for your consideration.
python-package/lightgbm/dask.py (outdated)

```python
        This property can be passed in the constructor or updated
        with ``model.set_params(client=client)``.
        """
        if self._client is None:
```
sklearn requires that unfitted models raise `NotFittedError`, and checks it via accessing "post-fitted" attributes: attributes with trailing underscores.
BTW, I think it would be great to set up sklearn integration tests in our CI for the Dask classes. It will not allow us to be sure that our classes are fully compatible with sklearn, but at least it will check basic compatibility. WDYT?
```python
def test_sklearn_integration(estimator, check):
```
Suggested change:

```diff
-if self._client is None:
+if self._n_features is None:
+    raise LGBMNotFittedError('No client found. Need to call fit beforehand.')
+if self._client is None:
```
Oh I see, ok. Why should the check be on `._n_features`? That seems kind of indirect. Shouldn't it be on `.fitted_` directly for a check of whether or not the model has been fit?

I'm also confused how you would like `_fit()` to work. I'm currently passing `client=self.client_` into `self._fit()`, and relying on that to resolve whether there is a client stored on the object or `get_default_client()` should be used. This suggested change would break that behavior, because of course the model is probably not fitted yet at the time you call `.fit()`. What do you suggest I do for that?

> BTW, I think it will be great to setup sklearn integration tests at our CI for Dask classes. It will not allow us to be sure that our classes fully compatible with sklearn but at least will check basic compatibility. WDYT

Sure, I think that's fine. I think it's an ok fit for this PR and will add it here.
> Why should the check be on `._n_features`?

For historical reasons. `fitted_` was introduced quite recently in our sklearn wrapper. But now you can use it, indeed 👍.

> I'm also confused how you would like `_fit()` to work.

Oh I see now! Can you pass `self.client` into `_fit()` and there assign the result of a check to `self._client`, while `client_` will always return `self._client` or raise an error? Just like all other properties in the sklearn wrapper.
> I think it's an ok fit for this PR and will add it here.

Maybe in a follow-up PR? I'm afraid that it can fail now and stop merging of the `client` argument migration.
I want to be really sure that there is exactly one place where we resolve which client to use (#3883 (comment)). So I don't want to put this code into the body of `_fit()`:

```python
if self.client is None:
    client = default_client()
else:
    client = self.client
```

I think it would work to pull that out into a small function like this:

```python
def _choose_client(client: Optional[Client]) -> Client:
    if client is None:
        return default_client()
    else:
        return client
```

Then use that in both the `client_` property (after the fitted check) and in `_fit()`. That would give us confidence that accessing `.client_` returns the same client that will be used in training.
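Putting those pieces together, here's a dependency-free sketch of the agreed behavior (class and helper names are illustrative; the string `"default-client"` stands in for `distributed.default_client()`): one helper resolves the client, `client_` raises for unfitted models, and `fit()` goes through the same helper.

```python
from typing import Optional

def _choose_client(client: Optional[object]) -> object:
    # single place that resolves which client to use
    return client if client is not None else "default-client"

class ToyEstimator:
    # illustrative stand-in for the Dask model objects
    def __init__(self, client=None):
        self.client = client
        self._n_features = None               # populated by fit()

    @property
    def client_(self):
        if self._n_features is None:          # "has fit() been called?" check
            raise RuntimeError('No client found. Need to call fit beforehand.')
        return _choose_client(self.client)

    def fit(self, X):
        client = _choose_client(self.client)  # same resolution as client_
        self._n_features = len(X[0])
        return self

print(ToyEstimator().fit([[1, 2, 3]]).client_)            # default-client
print(ToyEstimator(client="custom").fit([[1]]).client_)   # custom
```

Because both call sites go through the one helper, `.client_` after fitting is guaranteed to report the same client that training used.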
OK, sounds good!
Added #3894 to capture the task of adding a scikit-learn compatibility test. I made this a "good first issue" because I think it could be done without deep knowledge of LightGBM, but I also think it's ok for you or me to pick up in the near future (we don't need to reserve it for new contributors).
Ok I think I've addressed these suggestions in 555a57a:

- added an `LGBMNotFittedError` if accessing `.client_` for a not-yet-fitted model
- added an internal function `_get_dask_cluster()` to hold the logic of using `default_client()` if `client` is None
- changed the unit tests to use two different clusters and clients

Note that the change to have two clusters active in the same test will result in a lot of this warning in the logs:

```
test_dask.py::test_model_and_local_version_are_picklable_whether_or_not_client_set_explicitly[False-ranking-cloudpickle]
  /home/jlamb/miniconda3/lib/python3.7/site-packages/distributed/node.py:155: UserWarning: Port 8787 is already in use.
  Perhaps you already have a cluster running?
  Hosting the HTTP server on port 40743 instead
    http_address["port"], self.http_server.port
```

This is not something we have to worry about, and I'd rather leave it alone and let Dask do the right thing (picking a random port for the scheduler when the default one is unavailable) than add more complexity to the tests by adding our own logic to set the scheduler port.
python-package/lightgbm/dask.py (outdated)

```python
__init__.__doc__ = (
    _before_kwargs
    + 'client : dask.distributed.Client or None, optional (default=None)\n'
    + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n'
```
TBH, it was not clear enough that *any* client will not be saved. "This" may refer to `distributed.default_client()`, I guess, and confuse users into thinking that a custom client will be saved...
Ok I can change it
updated in 555a57a to this:

> The Dask client used by this class will not be saved if the model object is pickled.
```python
dask_model = model_factory(**params)
assert dask_model._client is None
assert dask_model.client is None
assert dask_model.client_ == client
```
This should raise `NotFittedError` according to sklearn policy.
```python
# should be able to set client after construction
dask_model = model_factory(**params)
dask_model.set_params(client=client)
```
Can we set some custom client here to ensure that the `else` branch of the condition `if self._client is None:` works?
```python
}

if set_client:
    params.update({"client": client})
```
Same here about a custom client, to check the difference between `default_client()` and one provided by the user.
```python
_compare_spec(lgb.DaskLGBMClassifier, lgb.LGBMClassifier)
_compare_spec(lgb.DaskLGBMRegressor, lgb.LGBMRegressor)
_compare_spec(lgb.DaskLGBMRanker, lgb.LGBMRanker)
```
Maybe `@pytest.mark.parametrize`?
sure, updated in 555a57a
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
@jameslamb Great job! And I think this API is quite intuitive.
Just two nits.
python-package/lightgbm/dask.py (outdated)

```python
# self._client is set in the constructor of classes that use this mixin
_client: Optional[Client] = None
```
Looks like `_client` was replaced with the `_get_dask_client()` function in the latest commit and is not needed anymore.
oh good point! Removed in 876bfe5
```python
def _compare_spec(dask_cls, sklearn_cls):
    dask_spec = inspect.getfullargspec(dask_cls)
    sklearn_spec = inspect.getfullargspec(sklearn_cls)
```
I think it is possible to simplify here by removing the inner function and setting directly:

```python
dask_spec = inspect.getfullargspec(classes[0])
sklearn_spec = inspect.getfullargspec(classes[1])
```
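With `parametrize`, each test invocation receives one `(dask_cls, sklearn_cls)` pair, and the body shrinks to the two `getfullargspec` calls plus assertions. A self-contained sketch with toy classes (pytest omitted; the pairs are iterated in a plain loop, and the class names are mine):

```python
import inspect

class ToySklearnModel:
    # illustrative stand-in for lgb.LGBMClassifier
    def __init__(self, num_leaves=31, learning_rate=0.1):
        pass

class ToyDaskModel(ToySklearnModel):
    # inherits __init__ unchanged, so the specs must be identical
    pass

for dask_cls, sklearn_cls in [(ToyDaskModel, ToySklearnModel)]:
    dask_spec = inspect.getfullargspec(dask_cls)
    sklearn_spec = inspect.getfullargspec(sklearn_cls)
    assert dask_spec.args == sklearn_spec.args
    assert dask_spec.defaults == sklearn_spec.defaults
```

`inspect.getfullargspec()` accepts a class and resolves its `__init__`, which is what makes the one-line comparison possible.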
oh right! I didn't think about how using `parametrize` meant this function was useless.
removed this function in 876bfe5
Thanks for talking me through it! I learned a lot more about scikit-learn doing this; you're a good teacher. I've addressed the last two comments in 876bfe5. Will merge after I see if CI + readthedocs passes, and after checking the docs site.
readthedocs looks ok to me! https://lightgbm.readthedocs.io/en/docs-jlamb/ The Dask docs have `client` in them (I checked classifier, regressor, and ranker) with the expected docs, and there is no change to the sklearn docs.
@jameslamb Thank you! Sorry my help came only in the form of words, not actual code 🙁.
This pull request has been automatically locked since there has not been any recent activity since it was closed. To start a new related discussion, open a new issue at https://github.com/microsoft/LightGBM/issues including a reference to this.
This PR proposes a different pattern for how users can provide a Dask client to model objects in the Dask module.
It removes the keyword argument `client` from `fit()` and `predict()`. As of this PR, users can optionally pass a client into the model's constructor. If they do that, it is saved on the model object and used by `fit()` and `predict()`. Otherwise, `distributed.default_client()` is used at runtime.

Given that you've set up a cluster and client...
...any of the following would work as of this PR.
Why prefer this pattern?

Based on feedback from @jsignell and @martindurant in #3808 (comment) and #3808 (comment), who felt comfortable with the use of `default_client()` when none is given, and who made these excellent points.

This pattern of an optional keyword-only argument in the constructor makes the signatures of `.predict()` and `.fit()` identical to those in the scikit-learn model objects.

I was also able to work around not being able to pickle a model object that had a Dask client as an attribute, which was my main concern with that approach.
This is very close to what `xgboost` chose (https://github.com/dmlc/xgboost/blob/d8ec7aad5a9a3eb580c55680aee8ad1a975cba20/python-package/xgboost/dask.py#L1386-L1394), except that today I believe they require users to set `.client` after constructing a model object. `cuml` also allows for passing `client` into the model constructor (https://github.com/rapidsai/cuml/blob/a22681c89630926c48f98ac3cb54d39bd2b91026/python/cuml/dask/common/base.py#L41-L45).

Changes in this PR
- adds `client` to the constructor for Dask model objects, and the ability to update `client` after construction
- removes `client` from `.fit()` and `.predict()` for Dask model objects
- adds a `__getstate__()` method on `lightgbm.sklearn.LGBMModel`, to handle removing a stored Dask client from a model before serializing it (see the `pickle` docs for information on how this works)
- adds tests that model objects can be serialized with `pickle`, `joblib`, and `cloudpickle`
- copies `__init__()` signatures from the `lightgbm.sklearn` equivalents into the Dask module, so that customizing documentation works correctly (see comment on the relevant part of the diff), and adds a unit test that the signatures for Dask model objects and their sklearn equivalents are identical

Notes for Reviewers
- opened from the `docs/jlamb` branch so we can see readthedocs builds
- not labeled `breaking` because there has never been a `lightgbm` release with the Dask module. This change does make `lightgbm`'s API different from the one in `dask-lightgbm`, but that shouldn't be considered a breaking change from LightGBM's perspective

I know this is a small PR but with a lot of explanation and considerations. Thank you for your time and energy reviewing it!