Skip to content

Commit

Permalink
[python][scikit-learn] change MRO (#3192)
Browse files Browse the repository at this point in the history
* chanche MRO

* fix MRO resolution
  • Loading branch information
StrikerRUS authored Apr 26, 2021
1 parent eb7a1b7 commit b6c71e5
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
26 changes: 22 additions & 4 deletions python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,22 @@ def _check_sample_weight(sample_weight, X, dtype=None):
_LGBMComputeSampleWeight = compute_sample_weight
except ImportError:
SKLEARN_INSTALLED = False
_LGBMModelBase = object
_LGBMClassifierBase = object
_LGBMRegressorBase = object

class _LGBMModelBase: # type: ignore
"""Dummy class for sklearn.base.BaseEstimator."""

pass

class _LGBMClassifierBase: # type: ignore
"""Dummy class for sklearn.base.ClassifierMixin."""

pass

class _LGBMRegressorBase: # type: ignore
"""Dummy class for sklearn.base.RegressorMixin."""

pass

_LGBMLabelEncoder = None
LGBMNotFittedError = ValueError
_LGBMStratifiedKFold = None
Expand All @@ -118,11 +131,16 @@ def _check_sample_weight(sample_weight, X, dtype=None):
DASK_INSTALLED = True
except ImportError:
DASK_INSTALLED = False

delayed = None
Client = object
default_client = None
wait = None

class Client: # type: ignore
"""Dummy class for dask.distributed.Client."""

pass

class dask_Array: # type: ignore
"""Dummy class for dask.array.Array."""

Expand Down
4 changes: 2 additions & 2 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,7 @@ def feature_name_(self):
return self._Booster.feature_name()


class LGBMRegressor(LGBMModel, _LGBMRegressorBase):
class LGBMRegressor(_LGBMRegressorBase, LGBMModel):
"""LightGBM regressor."""

def fit(self, X, y,
Expand All @@ -830,7 +830,7 @@ def fit(self, X, y,
+ _base_doc[_base_doc.find('eval_metric :'):])


class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
"""LightGBM classifier."""

def fit(self, X, y,
Expand Down

0 comments on commit b6c71e5

Please sign in to comment.