Skip to content

Commit

Permalink
[python] improved sklearn interface (#870)
Browse files Browse the repository at this point in the history
* improved sklearn interface; added sklearns' tests

* moved best_score into the if statement

* improved docstrings; simplified LGBMCheckConsistentLength

* fixed typo

* pylint

* updated example

* fixed Ranker interface

* added missed boosting_type

* fixed more comfortable autocomplete without unused objects

* removed check for None of eval_at

* fixed according to review

* fixed typo

* added description of fit return type

* dictionary->dict for short

* markdown cleanup
  • Loading branch information
StrikerRUS authored and guolinke committed Sep 5, 2017
1 parent 898c88d commit 015c8ff
Show file tree
Hide file tree
Showing 7 changed files with 387 additions and 246 deletions.
2 changes: 1 addition & 1 deletion .travis/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ if [[ ${TASK} == "if-else" ]]; then
exit 0
fi

conda install --yes numpy scipy scikit-learn pandas matplotlib
conda install --yes numpy nose scipy scikit-learn pandas matplotlib
pip install pytest

if [[ ${TASK} == "sdist" ]]; then
Expand Down
2 changes: 1 addition & 1 deletion appveyor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ test_script:
- conda config --set always_yes yes --set changeps1 no
- conda update -q conda
- conda info -a
- conda install --yes numpy scipy scikit-learn pandas matplotlib
- conda install --yes numpy nose scipy scikit-learn pandas matplotlib
- pip install pep8 pytest
- pytest tests/c_api_test/test_.py
- "set /p LGB_VER=< VERSION.txt"
Expand Down
2 changes: 1 addition & 1 deletion examples/python-guide/sklearn_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

print('Start predicting...')
# predict
y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration)
y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration_)
# eval
print('The rmse of prediction is:', mean_squared_error(y_test, y_pred) ** 0.5)

Expand Down
39 changes: 27 additions & 12 deletions python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,23 +64,38 @@ class DataFrame(object):
from sklearn.base import RegressorMixin, ClassifierMixin
from sklearn.preprocessing import LabelEncoder
from sklearn.utils import deprecated
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import check_X_y, check_array, check_consistent_length
try:
from sklearn.model_selection import StratifiedKFold, GroupKFold
from sklearn.exceptions import NotFittedError
except ImportError:
from sklearn.cross_validation import StratifiedKFold, GroupKFold
from sklearn.utils.validation import NotFittedError
SKLEARN_INSTALLED = True
LGBMModelBase = BaseEstimator
LGBMRegressorBase = RegressorMixin
LGBMClassifierBase = ClassifierMixin
LGBMLabelEncoder = LabelEncoder
_LGBMModelBase = BaseEstimator
_LGBMRegressorBase = RegressorMixin
_LGBMClassifierBase = ClassifierMixin
_LGBMLabelEncoder = LabelEncoder
LGBMDeprecated = deprecated
LGBMStratifiedKFold = StratifiedKFold
LGBMGroupKFold = GroupKFold
LGBMNotFittedError = NotFittedError
_LGBMStratifiedKFold = StratifiedKFold
_LGBMGroupKFold = GroupKFold
_LGBMCheckXY = check_X_y
_LGBMCheckArray = check_array
_LGBMCheckConsistentLength = check_consistent_length
_LGBMCheckClassificationTargets = check_classification_targets
except ImportError:
SKLEARN_INSTALLED = False
LGBMModelBase = object
LGBMClassifierBase = object
LGBMRegressorBase = object
LGBMLabelEncoder = None
LGBMStratifiedKFold = None
LGBMGroupKFold = None
_LGBMModelBase = object
_LGBMClassifierBase = object
_LGBMRegressorBase = object
_LGBMLabelEncoder = None
# LGBMDeprecated = None Don't uncomment it because it causes error without installed sklearn
LGBMNotFittedError = ValueError
_LGBMStratifiedKFold = None
_LGBMGroupKFold = None
_LGBMCheckXY = None
_LGBMCheckArray = None
_LGBMCheckConsistentLength = None
_LGBMCheckClassificationTargets = None
6 changes: 3 additions & 3 deletions python-package/lightgbm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from . import callback
from .basic import Booster, Dataset, LightGBMError, _InnerPredictor
from .compat import (SKLEARN_INSTALLED, LGBMGroupKFold, LGBMStratifiedKFold,
from .compat import (SKLEARN_INSTALLED, _LGBMGroupKFold, _LGBMStratifiedKFold,
integer_types, range_, string_type)


Expand Down Expand Up @@ -264,12 +264,12 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi
# lambdarank task, split according to groups
group_info = full_data.get_group().astype(int)
flatted_group = np.repeat(range(len(group_info)), repeats=group_info)
group_kfold = LGBMGroupKFold(n_splits=nfold)
group_kfold = _LGBMGroupKFold(n_splits=nfold)
folds = group_kfold.split(X=np.zeros(num_data), groups=flatted_group)
elif stratified:
if not SKLEARN_INSTALLED:
raise LightGBMError('Scikit-learn is required for stratified cv.')
skf = LGBMStratifiedKFold(n_splits=nfold, shuffle=shuffle, random_state=seed)
skf = _LGBMStratifiedKFold(n_splits=nfold, shuffle=shuffle, random_state=seed)
folds = skf.split(X=np.zeros(num_data), y=full_data.get_label())
else:
if shuffle:
Expand Down
Loading

0 comments on commit 015c8ff

Please sign in to comment.