[dask] remove 'client' kwarg from fit() and predict() (fixes #3808) #3883

Merged
merged 21 commits into from
Feb 3, 2021
Changes from 7 commits
173 changes: 147 additions & 26 deletions python-package/lightgbm/dask.py
@@ -434,6 +434,25 @@ def _predict(

 class _DaskLGBMModel:

+    # self._client is set in the constructor of lightgbm.sklearn.LGBMModel
+    _client: Optional[Client] = None
+
+    @property
+    def client(self) -> Client:
+        """Dask client.
+
+        This property can be passed in the constructor or directly assigned
+        like ``model.client = client``.
+        """
+        if self._client is None:
Collaborator

sklearn requires that unfitted models raise NotFittedError, and it checks this by accessing "post-fitted" attributes, i.e. attributes with trailing underscores.
BTW, I think it would be great to set up sklearn integration tests in our CI for the Dask classes. They won't guarantee that our classes are fully compatible with sklearn, but they will at least check basic compatibility. WDYT?

def test_sklearn_integration(estimator, check):

Suggested change
-        if self._client is None:
+        if self._n_features is None:
+            raise LGBMNotFittedError('No client found. Need to call fit beforehand.')
+        if self._client is None:

Collaborator Author

Oh I see, ok. Why should the check be on ._n_features? That seems kind of indirect. Shouldn't it be on .fitted_ directly for a check of whether or not the model has been fit?

I'm also confused how you would like _fit() to work. I'm currently passing client=self.client_ into self._fit(), and relying on that to resolve whether there is a client stored on the object or get_default_client() should be used. This suggested change would break that behavior, because of course the model is probably not fitted yet at the time you call .fit(). What do you suggest I do for that?

> BTW, I think it would be great to set up sklearn integration tests in our CI for the Dask classes. They won't guarantee that our classes are fully compatible with sklearn, but they will at least check basic compatibility. WDYT?

Sure, I think that's fine. I think it's an ok fit for this PR and will add it here.

Collaborator

> Why should the check be on ._n_features?

For historical reasons. fitted_ was introduced quite recently in our sklearn wrapper. But now you can use it, indeed 👍 .

> I'm also confused how you would like _fit() to work.

Oh I see now! Can you pass self.client into _fit() and assign the result of the check to self._client there, while client_ always returns self._client or raises an error? Just like all the other properties in the sklearn wrapper.
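
The property pattern described in this comment, where a trailing-underscore attribute raises until the model has been fitted, might look roughly like the sketch below. It is a toy stand-in and not LightGBM's code: `ToyEstimator`, `ToyNotFittedError`, and the way `fit()` stores the client are assumptions made for illustration only.

```python
class ToyNotFittedError(ValueError, AttributeError):
    """Hypothetical stand-in for LGBMNotFittedError."""


class ToyEstimator:
    """Toy estimator used only to illustrate the property pattern."""

    def __init__(self, client=None):
        self.client = client   # constructor argument, stored as passed
        self._client = None    # filled in by fit()
        self.fitted_ = False

    @property
    def client_(self):
        # Post-fit attribute: refuse to answer until fit() has run.
        if not self.fitted_:
            raise ToyNotFittedError("Cannot access client_ before calling fit().")
        return self._client

    def fit(self):
        # A real fit() would train here; the sketch only records the client.
        self._client = self.client
        self.fitted_ = True
        return self
```

Accessing `client_` on a fresh `ToyEstimator` raises, while `ToyEstimator(client=c).fit().client_` returns `c`.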

Collaborator

> I think it's an ok fit for this PR and will add it here.

Maybe in a follow-up PR? I'm afraid it could fail now and block merging the client argument migration.

Collaborator Author

I want to be really sure that there is exactly one place where we resolve which client to use (#3883 (comment)).

So I don't want to put this code into the body of _fit().

if self.client is None:
    client = default_client()
else:
    client = self.client

I think it would work to pull that out into a small function like this:

def _choose_client(client: Optional[Client]) -> Client:
    # a free function has no self, so decide based on the argument
    if client is None:
        return default_client()
    else:
        return client

Then use that in both the client_ property (after the fitted check) and in _fit(). That would give us confidence that accessing .client_ returns the same client that will be used in training.
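
To make the "exactly one place" idea concrete, here is a minimal sketch that runs without Dask. `FakeClient`, the local `default_client()`, and `ToyDaskModel` are hypothetical stand-ins (the real code would use `dask.distributed.Client` and `distributed.default_client()`); only the shape of `_choose_client()` follows the comment above.

```python
from typing import Optional


class FakeClient:
    """Hypothetical stand-in for dask.distributed.Client."""

    def __init__(self, name: str):
        self.name = name


_DEFAULT = FakeClient("default")


def default_client() -> FakeClient:
    """Stand-in for distributed.default_client()."""
    return _DEFAULT


def _choose_client(client: Optional[FakeClient]) -> FakeClient:
    """The single place where 'which client?' is decided."""
    if client is None:
        return default_client()
    return client


class ToyDaskModel:
    def __init__(self, client: Optional[FakeClient] = None):
        self.client = client
        self.fitted_ = False

    @property
    def client_(self) -> FakeClient:
        if not self.fitted_:
            raise RuntimeError("call fit() first")
        return _choose_client(self.client)

    def fit(self) -> "ToyDaskModel":
        # Training resolves the client through the same helper as client_.
        self._training_client = _choose_client(self.client)
        self.fitted_ = True
        return self
```

Because both code paths call the same helper, `model.client_` and the client used in `fit()` cannot drift apart, whether or not a client was passed explicitly.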

Collaborator

OK, sounds good!

Collaborator Author

Added #3894 to capture the task of adding a scikit-learn compatibility test. I made this a "good first issue" because I think it could be done without deep knowledge of LightGBM, but I also think it's ok for you or me to pick up in the near future (we don't need to reserve it for new contributors).

Collaborator Author

Ok I think I've addressed these suggestions in 555a57a

  • added an LGBMNotFittedError if accessing .client_ for a not-yet-fitted model
  • added an internal function _get_dask_cluster() to hold the logic of using default_client() if client is None
  • changed the unit tests to use two different clusters and clients

Note that the change to have two clusters active in the same test will result in a lot of this warning in the logs:

test_dask.py::test_model_and_local_version_are_picklable_whether_or_not_client_set_explicitly[False-ranking-cloudpickle]
/home/jlamb/miniconda3/lib/python3.7/site-packages/distributed/node.py:155: UserWarning: Port 8787 is already in use.
Perhaps you already have a cluster running?
Hosting the HTTP server on port 40743 instead
http_address["port"], self.http_server.port

This is not something we have to worry about, and I'd rather leave it alone and let Dask do the right thing (picking a random port for the scheduler when the default one is unavailable) than add more complexity to the tests by adding our own logic to set the scheduler port.

+            return default_client()
+        else:
+            return self._client
+
+    @client.setter
+    def client(self, client: Client) -> None:
+        self._client = client
+
     def _fit(
         self,
         model_factory: Type[LGBMModel],
@@ -446,13 +465,11 @@ def _fit(
     ) -> "_DaskLGBMModel":
         if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
             raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask')
-        if client is None:
-            client = default_client()

         params = self.get_params(True)

         model = _train(
-            client=client,
+            client=self.client,
             data=X,
             label=y,
             params=params,
@@ -478,18 +495,58 @@ def _copy_extra_params(source: Union["_DaskLGBMModel", LGBMModel], dest: Union["
         attributes = source.__dict__
         extra_param_names = set(attributes.keys()).difference(params.keys())
         for name in extra_param_names:
-            setattr(dest, name, attributes[name])
+            if name != "_client":
+                setattr(dest, name, attributes[name])


 class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
     """Distributed version of lightgbm.LGBMClassifier."""

+    def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
Collaborator Author

This seems to have the desired effect! When I built the docs locally, I saw that __init__() docs for DaskLGBMClassifier, DaskLGBMRegressor, and DaskLGBMRanker have the client doc added, and the docs for their scikit-learn equivalents do not.

DaskLGBMClassifier.__init__()

[image]

LGBMClassifier.__init__()

[image]

Why I think copying is the best alternative

I tried several other ways to update the docstrings, but none of them quite worked. I'll explain in terms of *Classifier, but this applies for *Regressor and *Ranker as well.

  1. Editing DaskLGBMClassifier.__init__.__doc__
class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
    ...

_base_doc = LGBMClassifier.__init__.__doc__
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
DaskLGBMClassifier.__init__.__doc__ = (
    _before_kwargs
    + 'client : dask.distributed.Client or None, optional (default=None)\n'
    + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n'
    + ' ' * 8 + _kwargs + _after_kwargs
)

Doing this for each of the 3 model objects, the API docs show 3 copies of the doc for client. Those 3 copies also show up on the docs for lightgbm.sklearn.LGBMClassifier.

Screen Shot 2021-01-30 at 11 23 24 PM

  2. Setting up a pass-through __init__()
class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    _base_doc = __init__.__doc__
    _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
    __init__.__doc__ = (
        _before_kwargs
        + 'client : dask.distributed.Client or None, optional (default=None)\n'
        + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n'
        + ' ' * 8 + _kwargs + _after_kwargs
    )

This results in this error at runtime.

doing this to avoid:
RuntimeError: scikit-learn estimators should always specify their parameters in
the signature of their __init__ (no varargs).
<class 'lightgbm.dask.DaskLGBMClassifier'> with constructor (self, *args, **kwargs)
doesn't  follow this convention.
  3. Just copying __init__

Both of these variations:

class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
    __init__ = LGBMClassifier.__init__

    _base_doc = __init__.__doc__
    _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
    __init__.__doc__ = (
        _before_kwargs
        + 'client : dask.distributed.Client or None, optional (default=None)\n'
        + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n'
        + ' ' * 8 + _kwargs + _after_kwargs
    )

Doing this for each of the 3 model objects, the API docs show 3 copies of the doc for client. Those 3 copies also show up on the docs for lightgbm.sklearn.LGBMClassifier. (same as in the image above)
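
A likely reason the first and third approaches leak into the parent docs is that an inherited or aliased `__init__` is the same function object as the parent's, so editing its `__doc__` edits the parent's docstring too. The sketch below illustrates this and the `str.partition()` insertion with toy classes; `Base`, `Inherited`, `Child`, and `add_param_doc()` are hypothetical names, and the indent widths assume the docstring keeps its source indentation.

```python
def add_param_doc(base_doc: str, marker: str, new_entry: str) -> str:
    """Insert new_entry immediately before the first occurrence of marker."""
    before, found, after = base_doc.partition(marker)
    if not found:
        # Marker missing: return the docstring unchanged.
        return base_doc
    return before + new_entry + found + after


class Base:
    def __init__(self, alpha=1.0, **kwargs):
        """Construct a Base.

        Parameters
        ----------
        alpha : float
            Some setting.
        **kwargs
            Other parameters.
        """
        self.alpha = alpha


class Inherited(Base):
    """No __init__ of its own, so it shares Base's function object."""


class Child(Base):
    # An explicit __init__ creates a new function object, so changing its
    # __doc__ leaves Base.__init__.__doc__ untouched.
    def __init__(self, alpha=1.0, **kwargs):
        super().__init__(alpha=alpha, **kwargs)

    __init__.__doc__ = add_param_doc(
        Base.__init__.__doc__,
        "**kwargs",
        "client : object or None\n" + " " * 12 + "Hypothetical extra parameter.\n" + " " * 8,
    )
```

Here `Inherited.__init__ is Base.__init__` holds, which is why a docstring edit through the subclass shows up on the parent, while `Child.__init__` carries its own docstring.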

+                 learning_rate=0.1, n_estimators=100,
+                 subsample_for_bin=200000, objective=None, class_weight=None,
+                 min_split_gain=0., min_child_weight=1e-3, min_child_samples=20,
+                 subsample=1., subsample_freq=0, colsample_bytree=1.,
+                 reg_alpha=0., reg_lambda=0., random_state=None,
+                 n_jobs=-1, silent=True, importance_type='split', **kwargs):
+        super().__init__(
+            boosting_type=boosting_type,
+            num_leaves=num_leaves,
+            max_depth=max_depth,
+            learning_rate=learning_rate,
+            n_estimators=n_estimators,
+            subsample_for_bin=subsample_for_bin,
+            objective=objective,
+            class_weight=class_weight,
+            min_split_gain=min_split_gain,
+            min_child_weight=min_child_weight,
+            min_child_samples=min_child_samples,
+            subsample=subsample,
+            subsample_freq=subsample_freq,
+            colsample_bytree=colsample_bytree,
+            reg_alpha=reg_alpha,
+            reg_lambda=reg_lambda,
+            random_state=random_state,
+            n_jobs=n_jobs,
+            silent=silent,
+            importance_type=importance_type,
+            **kwargs
+        )
+
+    _base_doc = LGBMClassifier.__init__.__doc__
+    _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
+    __init__.__doc__ = (
+        _before_kwargs
+        + 'client : dask.distributed.Client or None, optional (default=None)\n'
+        + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n'
Collaborator

TBH, it is not clear enough that no client will be saved. "This" may refer to distributed.default_client(), I guess, and mislead users into thinking that a custom client will be saved...

Collaborator Author

Ok I can change it

Collaborator Author

updated in 555a57a to this

The Dask client used by this class will not be saved if the model object is pickled.

+        + ' ' * 8 + _kwargs + _after_kwargs
+    )
+
     def fit(
         self,
         X: _DaskMatrixLike,
         y: _DaskCollection,
         sample_weight: Optional[_DaskCollection] = None,
-        client: Optional[Client] = None,
         **kwargs: Any
     ) -> "DaskLGBMClassifier":
         """Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
@@ -498,16 +555,11 @@ def fit(
             X=X,
             y=y,
             sample_weight=sample_weight,
-            client=client,
+            client=self.client,
             **kwargs
         )

-    _base_doc = LGBMClassifier.fit.__doc__
-    _before_init_score, _init_score, _after_init_score = _base_doc.partition('init_score :')
-    fit.__doc__ = (_before_init_score
-                   + 'client : dask.distributed.Client or None, optional (default=None)\n'
-                   + ' ' * 12 + 'Dask client.\n'
-                   + ' ' * 8 + _init_score + _after_init_score)
+    fit.__doc__ = LGBMClassifier.fit.__doc__

     def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
         """Docstring is inherited from the lightgbm.LGBMClassifier.predict."""
@@ -545,6 +597,46 @@ def to_local(self) -> LGBMClassifier:
 class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
     """Distributed version of lightgbm.LGBMRegressor."""

+    def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
+                 learning_rate=0.1, n_estimators=100,
+                 subsample_for_bin=200000, objective=None, class_weight=None,
+                 min_split_gain=0., min_child_weight=1e-3, min_child_samples=20,
+                 subsample=1., subsample_freq=0, colsample_bytree=1.,
+                 reg_alpha=0., reg_lambda=0., random_state=None,
+                 n_jobs=-1, silent=True, importance_type='split', **kwargs):
+        super().__init__(
+            boosting_type=boosting_type,
+            num_leaves=num_leaves,
+            max_depth=max_depth,
+            learning_rate=learning_rate,
+            n_estimators=n_estimators,
+            subsample_for_bin=subsample_for_bin,
+            objective=objective,
+            class_weight=class_weight,
+            min_split_gain=min_split_gain,
+            min_child_weight=min_child_weight,
+            min_child_samples=min_child_samples,
+            subsample=subsample,
+            subsample_freq=subsample_freq,
+            colsample_bytree=colsample_bytree,
+            reg_alpha=reg_alpha,
+            reg_lambda=reg_lambda,
+            random_state=random_state,
+            n_jobs=n_jobs,
+            silent=silent,
+            importance_type=importance_type,
+            **kwargs
+        )
+
+    _base_doc = LGBMRegressor.__init__.__doc__
+    _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
+    __init__.__doc__ = (
+        _before_kwargs
+        + 'client : dask.distributed.Client or None, optional (default=None)\n'
+        + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n'
+        + ' ' * 8 + _kwargs + _after_kwargs
+    )
+
     def fit(
         self,
         X: _DaskMatrixLike,
@@ -563,12 +655,7 @@ def fit(
             **kwargs
         )

-    _base_doc = LGBMRegressor.fit.__doc__
-    _before_init_score, _init_score, _after_init_score = _base_doc.partition('init_score :')
-    fit.__doc__ = (_before_init_score
-                   + 'client : dask.distributed.Client or None, optional (default=None)\n'
-                   + ' ' * 12 + 'Dask client.\n'
-                   + ' ' * 8 + _init_score + _after_init_score)
+    fit.__doc__ = LGBMRegressor.fit.__doc__

     def predict(self, X: _DaskMatrixLike, **kwargs) -> dask_Array:
         """Docstring is inherited from the lightgbm.LGBMRegressor.predict."""
@@ -594,14 +681,53 @@ def to_local(self) -> LGBMRegressor:
 class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
     """Distributed version of lightgbm.LGBMRanker."""

+    def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
+                 learning_rate=0.1, n_estimators=100,
+                 subsample_for_bin=200000, objective=None, class_weight=None,
+                 min_split_gain=0., min_child_weight=1e-3, min_child_samples=20,
+                 subsample=1., subsample_freq=0, colsample_bytree=1.,
+                 reg_alpha=0., reg_lambda=0., random_state=None,
+                 n_jobs=-1, silent=True, importance_type='split', **kwargs):
+        super().__init__(
+            boosting_type=boosting_type,
+            num_leaves=num_leaves,
+            max_depth=max_depth,
+            learning_rate=learning_rate,
+            n_estimators=n_estimators,
+            subsample_for_bin=subsample_for_bin,
+            objective=objective,
+            class_weight=class_weight,
+            min_split_gain=min_split_gain,
+            min_child_weight=min_child_weight,
+            min_child_samples=min_child_samples,
+            subsample=subsample,
+            subsample_freq=subsample_freq,
+            colsample_bytree=colsample_bytree,
+            reg_alpha=reg_alpha,
+            reg_lambda=reg_lambda,
+            random_state=random_state,
+            n_jobs=n_jobs,
+            silent=silent,
+            importance_type=importance_type,
+            **kwargs
+        )
+
+    _base_doc = LGBMRanker.__init__.__doc__
+    _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
+    __init__.__doc__ = (
+        _before_kwargs
+        + 'client : dask.distributed.Client or None, optional (default=None)\n'
+        + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n'
+        + ' ' * 8 + _kwargs + _after_kwargs
+    )
+
     def fit(
         self,
         X: _DaskMatrixLike,
         y: _DaskCollection,
         sample_weight: Optional[_DaskCollection] = None,
         init_score: Optional[_DaskCollection] = None,
         group: Optional[_DaskCollection] = None,
-        client: Optional[Client] = None,
         **kwargs: Any
     ) -> "DaskLGBMRanker":
         """Docstring is inherited from the lightgbm.LGBMRanker.fit."""
@@ -614,16 +740,11 @@ def fit(
             y=y,
             sample_weight=sample_weight,
             group=group,
-            client=client,
+            client=self.client,
             **kwargs
         )

-    _base_doc = LGBMRanker.fit.__doc__
-    _before_eval_set, _eval_set, _after_eval_set = _base_doc.partition('eval_set :')
-    fit.__doc__ = (_before_eval_set
-                   + 'client : dask.distributed.Client or None, optional (default=None)\n'
-                   + ' ' * 12 + 'Dask client.\n'
-                   + ' ' * 8 + _eval_set + _after_eval_set)
+    fit.__doc__ = LGBMRanker.fit.__doc__

     def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
         """Docstring is inherited from the lightgbm.LGBMRanker.predict."""
10 changes: 10 additions & 0 deletions python-package/lightgbm/sklearn.py
@@ -291,6 +291,9 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
         if not SKLEARN_INSTALLED:
             raise LightGBMError('scikit-learn is required for lightgbm.sklearn')

+        # Dask estimators inherit from this and may pass an argument "client"
+        self._client = kwargs.pop("client", None)
+
         self.boosting_type = boosting_type
         self.objective = objective
         self.num_leaves = num_leaves
@@ -325,6 +328,13 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
         self._n_classes = None
         self.set_params(**kwargs)

+    def __getstate__(self):
+        """Remove un-picklable attributes before serialization."""
+        client = self.__dict__.pop("_client", None)
+        out = copy.deepcopy(self.__dict__)
+        self._client = client
+        return out
Collaborator

I believe this can be done in the Dask estimators, not in the parent class.

Collaborator Author

If we do it there, this code would have to be copied 3 times. Because _DaskLGBMModel comes second in MRO, it wouldn't be safe to just put this on that mixin. Are you ok with that duplication?

Collaborator @StrikerRUS, Jan 31, 2021

Yeah, __getstate__() is probably overridden somewhere above in the pure scikit-learn branch of the inheritance. But we can name this method _get_state() in _DaskLGBMModel and call it, so no duplication will be required.

def __getstate__(self):
    return self._get_state()

Just like we currently do for all other methods.

Collaborator Author

ok, good suggestion! did that in 344376b.

I called it _lgb_get_state because I'm nervous about using general-sounding method names in _DaskLGBMModel when there are so many classes upstream of LGBMClassifier. I'm nervous about some version of scikit-learn introducing a _get_state that would override this. We might not catch that in tests because we don't test against many different versions of scikit-learn.

I think I'll propose a PR to do that for other methods on the _DaskLGBMModel mixin as well.
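
The arrangement agreed on in this thread can be sketched with toy classes, as below. Everything here is a stand-in under stated assumptions: `ToyBase` plays the scikit-learn side of the hierarchy, `ToyDaskMixin` plays `_DaskLGBMModel`, and a `threading.Lock` plays an un-copyable client. Only the `_lgb_get_state` name comes from the discussion; the `try`/`finally` is an addition of this sketch.

```python
import copy
import threading


class ToyBase:
    """Stand-in for an upstream class that already defines __getstate__."""

    def __getstate__(self):
        return dict(self.__dict__)


class ToyEstimator(ToyBase):
    def __init__(self, n_estimators=100, client=None):
        self.n_estimators = n_estimators
        self._client = client


class ToyDaskMixin:
    # Distinctive name, to lower the risk of clashing with an upstream method.
    def _lgb_get_state(self):
        """Return a deep copy of the instance state without the client."""
        client = self.__dict__.pop("_client", None)
        try:
            return copy.deepcopy(self.__dict__)
        finally:
            # Put the client back on the live object.
            self.__dict__["_client"] = client


class ToyDaskEstimator(ToyEstimator, ToyDaskMixin):
    # Defined on the concrete class: the mixin comes after ToyBase in the
    # MRO, so a __getstate__ placed on the mixin alone would never be used.
    def __getstate__(self):
        return self._lgb_get_state()
```

With this layout the one-line `__getstate__` is repeated per concrete class, but the logic lives once in the mixin, and `ToyDaskEstimator(client=threading.Lock()).__getstate__()` returns a state dict with no `_client` key while the live object keeps its client.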


     def _more_tags(self):
         return {
             'allow_nan': True,