Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python] improved sklearn interface #870

Merged
merged 15 commits into from
Sep 5, 2017
2 changes: 1 addition & 1 deletion .travis/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ if [[ ${TASK} == "if-else" ]]; then
exit 0
fi

conda install --yes numpy scipy scikit-learn pandas matplotlib
conda install --yes numpy nose scipy scikit-learn pandas matplotlib
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do you install nose?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sklearn's checks require nose. Without it new integration test fails.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sklearn's checks require nose

oh..., I got it.

pip install pytest

if [[ ${TASK} == "sdist" ]]; then
Expand Down
2 changes: 1 addition & 1 deletion appveyor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ test_script:
- conda config --set always_yes yes --set changeps1 no
- conda update -q conda
- conda info -a
- conda install --yes numpy scipy scikit-learn pandas matplotlib
- conda install --yes numpy nose scipy scikit-learn pandas matplotlib
- pip install pep8 pytest
- pytest tests/c_api_test/test_.py
- "set /p LGB_VER=< VERSION.txt"
Expand Down
2 changes: 1 addition & 1 deletion examples/python-guide/sklearn_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

print('Start predicting...')
# predict
y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration)
y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration_)
# eval
print('The rmse of prediction is:', mean_squared_error(y_test, y_pred) ** 0.5)

Expand Down
41 changes: 28 additions & 13 deletions python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,23 +64,38 @@ class DataFrame(object):
from sklearn.base import RegressorMixin, ClassifierMixin
from sklearn.preprocessing import LabelEncoder
from sklearn.utils import deprecated
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import check_X_y, check_array, check_consistent_length
try:
from sklearn.model_selection import StratifiedKFold, GroupKFold
from sklearn.exceptions import NotFittedError
except ImportError:
from sklearn.cross_validation import StratifiedKFold, GroupKFold
from sklearn.utils.validation import NotFittedError
SKLEARN_INSTALLED = True
LGBMModelBase = BaseEstimator
LGBMRegressorBase = RegressorMixin
LGBMClassifierBase = ClassifierMixin
LGBMLabelEncoder = LabelEncoder
_LGBMModelBase = BaseEstimator
_LGBMRegressorBase = RegressorMixin
_LGBMClassifierBase = ClassifierMixin
_LGBMLabelEncoder = LabelEncoder
LGBMDeprecated = deprecated
LGBMStratifiedKFold = StratifiedKFold
LGBMGroupKFold = GroupKFold
LGBMNotFittedError = NotFittedError
_LGBMStratifiedKFold = StratifiedKFold
_LGBMGroupKFold = GroupKFold
_LGBMCheckXY = check_X_y
_LGBMCheckArray = check_array
_LGBMCheckConsistentLength = check_consistent_length
_LGBMCheckClassificationTargets = check_classification_targets
except ImportError:
SKLEARN_INSTALLED = False
LGBMModelBase = object
LGBMClassifierBase = object
LGBMRegressorBase = object
LGBMLabelEncoder = None
LGBMStratifiedKFold = None
LGBMGroupKFold = None
_SKLEARN_INSTALLED = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

look like a typo? all other places are SKLEARN_INSTALLED

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, sorry.

_LGBMModelBase = object
_LGBMClassifierBase = object
_LGBMRegressorBase = object
_LGBMLabelEncoder = None
LGBMDeprecated = None
Copy link
Contributor

@wxchan wxchan Sep 3, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember this will raise an error if sklearn not installed, check #221. You can uninstall scikit-learn and take a try.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought that it's just forgotten object. Sorry.
btw, in two words why these dummy variables are needed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe some of these I don't need to specify too?

LGBMNotFittedError = ValueError 
_LGBMCheckXY = None
_LGBMCheckArray = None
_LGBMCheckConsistentLength = None
_LGBMCheckClassificationTargets = None 

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I uninstalled skikit-learn, removed the line
LGBMDeprecated = None
and
import lightgbm as lgb
didn't raise an error. Does it indicate that everything is OK?

LGBMNotFittedError = ValueError
_LGBMStratifiedKFold = None
_LGBMGroupKFold = None
_LGBMCheckXY = None
_LGBMCheckArray = None
_LGBMCheckConsistentLength = None
_LGBMCheckClassificationTargets = None
6 changes: 3 additions & 3 deletions python-package/lightgbm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from . import callback
from .basic import Booster, Dataset, LightGBMError, _InnerPredictor
from .compat import (SKLEARN_INSTALLED, LGBMGroupKFold, LGBMStratifiedKFold,
from .compat import (SKLEARN_INSTALLED, _LGBMGroupKFold, _LGBMStratifiedKFold,
integer_types, range_, string_type)


Expand Down Expand Up @@ -264,12 +264,12 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi
# lambdarank task, split according to groups
group_info = full_data.get_group().astype(int)
flatted_group = np.repeat(range(len(group_info)), repeats=group_info)
group_kfold = LGBMGroupKFold(n_splits=nfold)
group_kfold = _LGBMGroupKFold(n_splits=nfold)
folds = group_kfold.split(X=np.zeros(num_data), groups=flatted_group)
elif stratified:
if not SKLEARN_INSTALLED:
raise LightGBMError('Scikit-learn is required for stratified cv.')
skf = LGBMStratifiedKFold(n_splits=nfold, shuffle=shuffle, random_state=seed)
skf = _LGBMStratifiedKFold(n_splits=nfold, shuffle=shuffle, random_state=seed)
folds = skf.split(X=np.zeros(num_data), y=full_data.get_label())
else:
if shuffle:
Expand Down
Loading