[python] improved sklearn interface #870

Merged: 15 commits, Sep 5, 2017
2 changes: 1 addition & 1 deletion .travis/test.sh
@@ -40,7 +40,7 @@ if [[ ${TASK} == "if-else" ]]; then
exit 0
fi

-conda install --yes numpy scipy scikit-learn pandas matplotlib
+conda install --yes numpy nose scipy scikit-learn pandas matplotlib
Contributor

Why do you install nose?

Collaborator Author

Sklearn's checks require nose. Without it, the new integration test fails.

Contributor

> Sklearn's checks require nose

Oh, I got it.

pip install pytest

if [[ ${TASK} == "sdist" ]]; then
2 changes: 1 addition & 1 deletion appveyor.yml
@@ -21,7 +21,7 @@ test_script:
   - conda config --set always_yes yes --set changeps1 no
   - conda update -q conda
   - conda info -a
-  - conda install --yes numpy scipy scikit-learn pandas matplotlib
+  - conda install --yes numpy nose scipy scikit-learn pandas matplotlib
   - pip install pep8 pytest
   - pytest tests/c_api_test/test_.py
   - "set /p LGB_VER=< VERSION.txt"
15 changes: 15 additions & 0 deletions python-package/lightgbm/compat.py
@@ -64,23 +64,38 @@ class DataFrame(object):
     from sklearn.base import RegressorMixin, ClassifierMixin
     from sklearn.preprocessing import LabelEncoder
     from sklearn.utils import deprecated
+    from sklearn.utils.multiclass import check_classification_targets
+    from sklearn.utils.validation import check_X_y, check_array, check_consistent_length
     try:
         from sklearn.model_selection import StratifiedKFold, GroupKFold
+        from sklearn.exceptions import NotFittedError
     except ImportError:
         from sklearn.cross_validation import StratifiedKFold, GroupKFold
+        from sklearn.utils.validation import NotFittedError
     SKLEARN_INSTALLED = True
     LGBMModelBase = BaseEstimator
     LGBMRegressorBase = RegressorMixin
     LGBMClassifierBase = ClassifierMixin
     LGBMLabelEncoder = LabelEncoder
     LGBMDeprecated = deprecated
+    LGBMNotFittedError = NotFittedError
     LGBMStratifiedKFold = StratifiedKFold
     LGBMGroupKFold = GroupKFold
+    LGBMCheckXY = check_X_y
+    LGBMCheckArray = check_array
+    LGBMCheckConsistentLength = check_consistent_length
+    LGBMCheckClassificationTargets = check_classification_targets
 except ImportError:
     SKLEARN_INSTALLED = False
     LGBMModelBase = object
     LGBMClassifierBase = object
     LGBMRegressorBase = object
     LGBMLabelEncoder = None
+    LGBMDeprecated = None
Contributor

@wxchan Sep 3, 2017

I remember this will raise an error if sklearn is not installed; see #221. You can uninstall scikit-learn and give it a try.

Collaborator Author

I thought it was just a forgotten object. Sorry.
By the way, in a few words, why are these dummy variables needed?
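
On the question of why the placeholder names exist, here is a minimal sketch of the general optional-dependency pattern. It is not the actual LightGBM code: the module name `optional_backend_not_installed` and the names `ModelBase`, `ModelNotFittedError`, and `Model` are hypothetical and chosen so that the `ImportError` branch is the one that runs. The point it illustrates is that a class statement and a `raise` both look up these names, so each one has to be bound to something whether or not the import succeeded.

```python
# Hypothetical optional dependency; the module name is made up so that the
# ImportError branch below is the one that actually executes.
try:
    from optional_backend_not_installed import BaseEstimator, NotFittedError
    BACKEND_INSTALLED = True
    ModelBase = BaseEstimator
    ModelNotFittedError = NotFittedError
except ImportError:
    BACKEND_INSTALLED = False
    # Placeholders keep the names below resolvable without the dependency.
    ModelBase = object
    ModelNotFittedError = ValueError


class Model(ModelBase):  # evaluated at import time, so ModelBase must exist
    def __init__(self):
        self.fitted_ = False

    def fit(self):
        self.fitted_ = True
        return self

    def predict(self):
        if not self.fitted_:
            # Must be a real exception class even when the dependency is absent.
            raise ModelNotFittedError("call fit before predict")
        return "ok"
```

Names used only inside code paths that already require the dependency could, under this reading, be left as `None`, while anything used as a base class or raised as an exception needs a working fallback.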

Collaborator Author

Maybe I don't need to specify some of these either?

LGBMNotFittedError = ValueError 
_LGBMCheckXY = None
_LGBMCheckArray = None
_LGBMCheckConsistentLength = None
_LGBMCheckClassificationTargets = None 

Collaborator Author

I uninstalled scikit-learn, removed the line
LGBMDeprecated = None
and
import lightgbm as lgb
didn't raise an error. Does that indicate that everything is OK?
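
The manual check described above (uninstall the package, then import) can also be approximated without touching the environment. The following is a rough sketch, not part of this PR: it relies on the fact that a `None` entry in `sys.modules` makes the import machinery raise `ImportError` for that name. The standard-library module `json` stands in for scikit-learn so the snippet runs anywhere, and `is_importable` is a hypothetical helper.

```python
import importlib
import sys
from unittest import mock


def is_importable(name):
    """Return True if `name` can be imported right now, False otherwise."""
    try:
        importlib.import_module(name)
        return True
    except ImportError:
        return False


# A None entry in sys.modules makes any import of that name raise ImportError,
# which simulates the package being absent for the duration of the block.
with mock.patch.dict(sys.modules, {"json": None}):
    blocked_result = is_importable("json")

# patch.dict restores sys.modules on exit, so the import works again.
restored_result = is_importable("json")
```

For the situation in this thread, the analogous experiment would presumably block `"sklearn"` and then import the package under test in a fresh interpreter, since a package that was already imported keeps whatever names it bound the first time.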

+    LGBMNotFittedError = ValueError
     LGBMStratifiedKFold = None
     LGBMGroupKFold = None
+    LGBMCheckXY = None
+    LGBMCheckArray = None
+    LGBMCheckConsistentLength = None
+    LGBMCheckClassificationTargets = None