diff --git a/.travis/test.sh b/.travis/test.sh index a56dc232348a..605f77731e5c 100644 --- a/.travis/test.sh +++ b/.travis/test.sh @@ -40,7 +40,7 @@ if [[ ${TASK} == "if-else" ]]; then exit 0 fi -conda install --yes numpy scipy scikit-learn pandas matplotlib +conda install --yes numpy nose scipy scikit-learn pandas matplotlib pip install pytest if [[ ${TASK} == "sdist" ]]; then diff --git a/appveyor.yml b/appveyor.yml index 9bc3eefeeba9..b7e8eb1956f4 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -21,7 +21,7 @@ test_script: - conda config --set always_yes yes --set changeps1 no - conda update -q conda - conda info -a - - conda install --yes numpy scipy scikit-learn pandas matplotlib + - conda install --yes numpy nose scipy scikit-learn pandas matplotlib - pip install pep8 pytest - pytest tests/c_api_test/test_.py - "set /p LGB_VER=< VERSION.txt" diff --git a/examples/python-guide/sklearn_example.py b/examples/python-guide/sklearn_example.py index 3700fcc0d067..8c2f8aa1b5fc 100644 --- a/examples/python-guide/sklearn_example.py +++ b/examples/python-guide/sklearn_example.py @@ -28,7 +28,7 @@ print('Start predicting...') # predict -y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration) +y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration_) # eval print('The rmse of prediction is:', mean_squared_error(y_test, y_pred) ** 0.5) diff --git a/python-package/lightgbm/compat.py b/python-package/lightgbm/compat.py index 3143e5231b61..e03343ca4cd2 100644 --- a/python-package/lightgbm/compat.py +++ b/python-package/lightgbm/compat.py @@ -64,23 +64,38 @@ class DataFrame(object): from sklearn.base import RegressorMixin, ClassifierMixin from sklearn.preprocessing import LabelEncoder from sklearn.utils import deprecated + from sklearn.utils.multiclass import check_classification_targets + from sklearn.utils.validation import check_X_y, check_array, check_consistent_length try: from sklearn.model_selection import StratifiedKFold, GroupKFold + from sklearn.exceptions import NotFittedError except ImportError: from sklearn.cross_validation import StratifiedKFold, GroupKFold + from sklearn.utils.validation import NotFittedError SKLEARN_INSTALLED = True - LGBMModelBase = BaseEstimator - LGBMRegressorBase = RegressorMixin - LGBMClassifierBase = ClassifierMixin - LGBMLabelEncoder = LabelEncoder + _LGBMModelBase = BaseEstimator + _LGBMRegressorBase = RegressorMixin + _LGBMClassifierBase = ClassifierMixin + _LGBMLabelEncoder = LabelEncoder LGBMDeprecated = deprecated - LGBMStratifiedKFold = StratifiedKFold - LGBMGroupKFold = GroupKFold + LGBMNotFittedError = NotFittedError + _LGBMStratifiedKFold = StratifiedKFold + _LGBMGroupKFold = GroupKFold + _LGBMCheckXY = check_X_y + _LGBMCheckArray = check_array + _LGBMCheckConsistentLength = check_consistent_length + _LGBMCheckClassificationTargets = check_classification_targets except ImportError: SKLEARN_INSTALLED = False - LGBMModelBase = object - LGBMClassifierBase = object - LGBMRegressorBase = object - LGBMLabelEncoder = None - LGBMStratifiedKFold = None - LGBMGroupKFold = None + _LGBMModelBase = object + _LGBMClassifierBase = object + _LGBMRegressorBase = object + _LGBMLabelEncoder = None +# LGBMDeprecated = None Don't uncomment it because it causes error without installed sklearn + LGBMNotFittedError = ValueError + _LGBMStratifiedKFold = None + _LGBMGroupKFold = None + _LGBMCheckXY = None + _LGBMCheckArray = None + _LGBMCheckConsistentLength = None + _LGBMCheckClassificationTargets = None diff --git a/python-package/lightgbm/engine.py 
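The compat.py hunk above renames the public scikit-learn aliases to underscore-prefixed private ones, pulls in the validation helpers, and introduces an LGBMNotFittedError that degrades to ValueError when scikit-learn is absent. A minimal sketch of that optional-dependency pattern (names trimmed for brevity, not the full module):

```python
# Sketch of the optional-dependency shim pattern used in compat.py
# (names trimmed; the real hunk aliases many more helpers).
try:
    from sklearn.base import BaseEstimator
    from sklearn.exceptions import NotFittedError
    from sklearn.utils.validation import check_X_y, check_array

    SKLEARN_INSTALLED = True
    _LGBMModelBase = BaseEstimator       # real base class when sklearn is present
    LGBMNotFittedError = NotFittedError  # sklearn's own "not fitted" exception
    _LGBMCheckXY = check_X_y
    _LGBMCheckArray = check_array
except ImportError:
    SKLEARN_INSTALLED = False
    _LGBMModelBase = object              # plain object keeps LGBMModel importable
    LGBMNotFittedError = ValueError      # NotFittedError subclasses ValueError,
                                         # so ValueError is a compatible fallback
    _LGBMCheckXY = None
    _LGBMCheckArray = None
```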
b/python-package/lightgbm/engine.py index f9e44641c89a..482427b41fb6 100644 --- a/python-package/lightgbm/engine.py +++ b/python-package/lightgbm/engine.py @@ -11,7 +11,7 @@ from . import callback from .basic import Booster, Dataset, LightGBMError, _InnerPredictor -from .compat import (SKLEARN_INSTALLED, LGBMGroupKFold, LGBMStratifiedKFold, +from .compat import (SKLEARN_INSTALLED, _LGBMGroupKFold, _LGBMStratifiedKFold, integer_types, range_, string_type) @@ -264,12 +264,12 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi # lambdarank task, split according to groups group_info = full_data.get_group().astype(int) flatted_group = np.repeat(range(len(group_info)), repeats=group_info) - group_kfold = LGBMGroupKFold(n_splits=nfold) + group_kfold = _LGBMGroupKFold(n_splits=nfold) folds = group_kfold.split(X=np.zeros(num_data), groups=flatted_group) elif stratified: if not SKLEARN_INSTALLED: raise LightGBMError('Scikit-learn is required for stratified cv.') - skf = LGBMStratifiedKFold(n_splits=nfold, shuffle=shuffle, random_state=seed) + skf = _LGBMStratifiedKFold(n_splits=nfold, shuffle=shuffle, random_state=seed) folds = skf.split(X=np.zeros(num_data), y=full_data.get_label()) else: if shuffle: diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index c9ef7b17562d..f462fae03a36 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -7,9 +7,10 @@ import warnings from .basic import Dataset, LightGBMError -from .compat import (SKLEARN_INSTALLED, LGBMClassifierBase, LGBMDeprecated, - LGBMLabelEncoder, LGBMModelBase, LGBMRegressorBase, argc_, - range_) +from .compat import (SKLEARN_INSTALLED, _LGBMClassifierBase, LGBMDeprecated, + LGBMNotFittedError, _LGBMLabelEncoder, _LGBMModelBase, + _LGBMRegressorBase, _LGBMCheckXY, _LGBMCheckArray, _LGBMCheckConsistentLength, + _LGBMCheckClassificationTargets, argc_, range_) from .engine import train @@ -20,19 +21,20 @@ class LGBMDeprecationWarning(UserWarning): def _objective_function_wrapper(func): """Decorate an objective function - Note: for multi-class task, the y_pred is group by class_id first, then group by row_id - if you want to get i-th row y_pred in j-th class, the access way is y_pred[j*num_data+i] - and you should group grad and hess in this way as well + Note: for multi-class task, the y_pred is group by class_id first, then group by row_id. + If you want to get i-th row y_pred in j-th class, the access way is y_pred[j * num_data + i] + and you should group grad and hess in this way as well. + Parameters ---------- func: callable Expects a callable with signature ``func(y_true, y_pred)`` or ``func(y_true, y_pred, group): - y_true: array_like of shape [n_samples] - The target values - y_pred: array_like of shape [n_samples] or shape[n_samples * n_class] (for multi-class) - The predicted values - group: array_like - group/query data, used for ranking task + y_true: array-like of shape = [n_samples] + The target values. + y_pred: array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class) + The predicted values. + group: array-like + Group/query data, used for ranking task. Returns ------- @@ -40,11 +42,11 @@ def _objective_function_wrapper(func): The new objective function as expected by ``lightgbm.engine.train``. 
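For the lambdarank branch of `_make_n_folds` shown in the engine.py hunk above, the per-query group sizes from `Dataset.get_group()` are expanded with `np.repeat` into one group id per row before being handed to `GroupKFold`, so rows of the same query never straddle a fold boundary. A small standalone sketch of that expansion with synthetic group sizes:

```python
import numpy as np
from sklearn.model_selection import GroupKFold

group_sizes = np.array([3, 2, 4])                # e.g. Dataset.get_group(): 3 queries
row_groups = np.repeat(range(len(group_sizes)),  # -> [0 0 0 1 1 2 2 2 2]
                       repeats=group_sizes)
num_data = group_sizes.sum()

gkf = GroupKFold(n_splits=3)
for train_idx, test_idx in gkf.split(X=np.zeros(num_data), groups=row_groups):
    print(train_idx, test_idx)                   # each query stays on one side
```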
The signature is ``new_func(preds, dataset)``: - preds: array_like, shape [n_samples] or shape[n_samples * n_class] - The predicted values + preds: array-like of shape = [n_samples] or shape = [n_samples * n_classes] + The predicted values. dataset: ``dataset`` The training set from which the labels will be extracted using - ``dataset.get_label()`` + ``dataset.get_label()``. """ def inner(preds, dataset): """internal function""" @@ -79,8 +81,9 @@ def inner(preds, dataset): def _eval_function_wrapper(func): """Decorate an eval function - Note: for multi-class task, the y_pred is group by class_id first, then group by row_id - if you want to get i-th row y_pred in j-th class, the access way is y_pred[j*num_data+i] + Note: for multi-class task, the y_pred is group by class_id first, then group by row_id. + If you want to get i-th row y_pred in j-th class, the access way is y_pred[j * num_data + i]. + Parameters ---------- func: callable @@ -90,14 +93,14 @@ def _eval_function_wrapper(func): or ``func(y_true, y_pred, weight, group)`` and return (eval_name->str, eval_result->float, is_bigger_better->Bool): - y_true: array_like of shape [n_samples] - The target values - y_pred: array_like of shape [n_samples] or shape[n_samples * n_class] (for multi-class) - The predicted values - weight: array_like of shape [n_samples] - The weight of samples - group: array_like - group/query data, used for ranking task + y_true: array-like of shape = [n_samples] + The target values. + y_pred: array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class) + The predicted values. + weight: array_like of shape = [n_samples] + The weight of samples. + group: array-like + Group/query data, used for ranking task. Returns ------- @@ -105,11 +108,11 @@ def _eval_function_wrapper(func): The new eval function as expected by ``lightgbm.engine.train``. The signature is ``new_func(preds, dataset)``: - preds: array_like, shape [n_samples] or shape[n_samples * n_class] - The predicted values + preds: array-like of shape = [n_samples] or shape = [n_samples * n_classes] + The predicted values. dataset: ``dataset`` The training set from which the labels will be extracted using - ``dataset.get_label()`` + ``dataset.get_label()``. """ def inner(preds, dataset): """internal function""" @@ -126,102 +129,115 @@ def inner(preds, dataset): return inner -class LGBMModel(LGBMModelBase): +class LGBMModel(_LGBMModelBase): + """Implementation of the scikit-learn API for LightGBM.""" def __init__(self, boosting_type="gbdt", num_leaves=31, max_depth=-1, learning_rate=0.1, n_estimators=10, max_bin=255, subsample_for_bin=50000, objective=None, - min_split_gain=0, min_child_weight=5, min_child_samples=10, - subsample=1, subsample_freq=1, colsample_bytree=1, - reg_alpha=0, reg_lambda=0, random_state=0, + min_split_gain=0., min_child_weight=5, min_child_samples=10, + subsample=1., subsample_freq=1, colsample_bytree=1., + reg_alpha=0., reg_lambda=0., random_state=0, n_jobs=-1, silent=True, **kwargs): - """ - Implementation of the Scikit-Learn API for LightGBM. + """Construct a gradient boosting model. Parameters ---------- - boosting_type : string - gbdt, traditional Gradient Boosting Decision Tree. - dart, Dropouts meet Multiple Additive Regression Trees. - num_leaves : int + boosting_type : string, optional (default="gbdt") + 'gbdt', traditional Gradient Boosting Decision Tree. + 'dart', Dropouts meet Multiple Additive Regression Trees. + 'goss', Gradient-based One-Side Sampling. + 'rf', Random Forest. 
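The class-major layout described in the wrapper docstrings above is easiest to see with a toy array: for `num_class` classes and `num_data` rows, the score of sample `i` under class `j` sits at index `j * num_data + i`, and a custom objective has to hand `grad`/`hess` back flattened the same way. A sketch of just the indexing, with no LightGBM call:

```python
import numpy as np

num_data, num_class = 4, 3
preds = np.arange(num_data * num_class, dtype=float)  # stand-in for y_pred

# row j of this matrix is the per-sample score vector for class j
scores = preds.reshape(num_class, num_data)
i, j = 2, 1
assert scores[j, i] == preds[j * num_data + i]

# grad and hess computed per (class, sample) must be flattened back
grad = np.zeros((num_class, num_data)).ravel()
hess = np.ones((num_class, num_data)).ravel()
```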
+ num_leaves : int, optional (default=31) Maximum tree leaves for base learners. - max_depth : int + max_depth : int, optional (default=-1) Maximum tree depth for base learners, -1 means no limit. - learning_rate : float + learning_rate : float, optional (default=0.1) Boosting learning rate. - n_estimators : int + n_estimators : int, optional (default=10) Number of boosted trees to fit. - max_bin : int + max_bin : int, optional (default=255) Number of bucketed bin for feature values. - subsample_for_bin : int + subsample_for_bin : int, optional (default=50000) Number of samples for constructing bins. - objective : string or callable + objective : string, callable or None, optional (default=None) Specify the learning task and the corresponding learning objective or a custom objective function to be used (see note below). - default: binary for LGBMClassifier, lambdarank for LGBMRanker. - min_split_gain : float + default: 'binary' for LGBMClassifier, 'lambdarank' for LGBMRanker. + min_split_gain : float, optional (default=0.) Minimum loss reduction required to make a further partition on a leaf node of the tree. - min_child_weight : int + min_child_weight : int, optional (default=5) Minimum sum of instance weight(hessian) needed in a child(leaf). - min_child_samples : int + min_child_samples : int, optional (default=10) Minimum number of data need in a child(leaf). - subsample : float + subsample : float, optional (default=1.) Subsample ratio of the training instance. - subsample_freq : int - frequence of subsample, <=0 means no enable. - colsample_bytree : float + subsample_freq : int, optional (default=1) + Frequence of subsample, <=0 means no enable. + colsample_bytree : float, optional (default=1.) Subsample ratio of columns when constructing each tree. - reg_alpha : float + reg_alpha : float, optional (default=0.) L1 regularization term on weights. - reg_lambda : float + reg_lambda : float, optional (default=0.) L2 regularization term on weights. - random_state : int + random_state : int, optional (default=0) Random number seed. - n_jobs : int + n_jobs : int, optional (default=-1) Number of parallel threads. - silent : boolean + silent : bool, optional (default=True) Whether to print messages while running boosting. **kwargs : other parameters Check http://lightgbm.readthedocs.io/en/latest/Parameters.html for more parameters. Note: **kwargs is not supported in sklearn, it may cause unexpected issues. + Attributes + ---------- + n_features_ : int + The number of features of fitted model. + classes_ : array of shape = [n_classes] + The class label array (only for classification problem). + n_classes_ : int + The number of classes (only for classification problem). + best_score_ : dict or None + The best score of fitted model if `early_stopping_rounds` has been specified. + best_iteration_ : int or None + The best iteration of fitted model if `early_stopping_rounds` has been specified. + objective_ : string or callable + The concrete objective used while fitting this model. + booster_ : Booster + The underlying Booster of this model. + evals_result_ : dict or None + The evaluation results if `early_stopping_rounds` has been specified. + feature_importances_ : array of shape = [n_features] + The feature importances (the higher, the more important the feature). + Note ---- A custom objective function can be provided for the ``objective`` parameter. 
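The underscore-suffixed attributes listed in the Attributes block above only exist after `fit`. A sketch of the typical flow with synthetic data (assuming this patched wrapper), including the `best_iteration_` spelling that the sklearn_example.py hunk switches to:

```python
import numpy as np
import lightgbm as lgb

rng = np.random.RandomState(42)
X = rng.rand(200, 5)
y = X[:, 0] + 0.1 * rng.rand(200)
X_train, X_test, y_train, y_test = X[:150], X[150:], y[:150], y[150:]

gbm = lgb.LGBMRegressor(n_estimators=50)
gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)],
        early_stopping_rounds=5, verbose=False)

print(gbm.n_features_)           # 5
print(gbm.best_iteration_)       # set because early_stopping_rounds was given
print(gbm.feature_importances_)  # same scale as Booster.feature_importance()
y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration_)
```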
In this case, it should have the signature - ``objective(y_true, y_pred) -> grad, hess`` - or ``objective(y_true, y_pred, group) -> grad, hess``: + ``objective(y_true, y_pred) -> grad, hess`` or + ``objective(y_true, y_pred, group) -> grad, hess``: - y_true: array_like of shape [n_samples] + y_true: array-like of shape = [n_samples] The target values. - y_pred: array_like of shape [n_samples] or shape[n_samples * n_class] + y_pred: array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task) The predicted values. - group: array_like - group/query data, used for ranking task. - grad: array_like of shape [n_samples] or shape[n_samples * n_class] + group: array-like + Group/query data, used for ranking task. + grad: array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task) The value of the gradient for each sample point. - hess: array_like of shape [n_samples] or shape[n_samples * n_class] + hess: array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task) The value of the second derivative for each sample point. - for multi-class task, the y_pred is group by class_id first, then group by row_id - if you want to get i-th row y_pred in j-th class, the access way is y_pred[j*num_data+i] - and you should group grad and hess in this way as well + For multi-class task, the y_pred is group by class_id first, then group by row_id. + If you want to get i-th row y_pred in j-th class, the access way is y_pred[j * num_data + i] + and you should group grad and hess in this way as well. """ if not SKLEARN_INSTALLED: raise LightGBMError('Scikit-learn is required for this module') self.boosting_type = boosting_type - if objective is None: - if isinstance(self, LGBMRegressor): - self.objective = "regression" - elif isinstance(self, LGBMClassifier): - self.objective = "binary" - elif isinstance(self, LGBMRanker): - self.objective = "lambdarank" - else: - raise TypeError("Unknown LGBMModel type.") - else: - self.objective = objective + self.objective = objective self.num_leaves = num_leaves self.max_depth = max_depth self.learning_rate = learning_rate @@ -240,19 +256,19 @@ def __init__(self, boosting_type="gbdt", num_leaves=31, max_depth=-1, self.n_jobs = n_jobs self.silent = silent self._Booster = None - self.evals_result = None - self.best_iteration = -1 - self.best_score = {} - if callable(self.objective): - self.fobj = _objective_function_wrapper(self.objective) - else: - self.fobj = None - self.other_params = {} + self._evals_result = None + self._best_score = None + self._best_iteration = None + self._other_params = {} + self._objective = None + self._n_features = None + self._classes = None + self._n_classes = None self.set_params(**kwargs) def get_params(self, deep=True): params = super(LGBMModel, self).get_params(deep=deep) - params.update(self.other_params) + params.update(self._other_params) if 'seed' in params: warnings.warn('The `seed` parameter is deprecated and will be removed in next version. 
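Dropping the objective-resolution logic from `__init__` (it moves into `fit`, further down in this file) is what makes the estimator cloneable: scikit-learn's `clone` rebuilds an estimator from `get_params()` and expects every constructor argument to come back unchanged. A sketch of that contract, assuming scikit-learn is installed:

```python
from sklearn.base import clone
import lightgbm as lgb

est = lgb.LGBMClassifier(objective=None, learning_rate=0.05)
assert est.get_params()['objective'] is None   # stored verbatim by __init__

cloned = clone(est)                            # works because params round-trip
assert cloned.get_params() == est.get_params()

# the concrete objective ("binary", "multiclass", ...) is only chosen inside
# fit() and is exposed afterwards through the objective_ property
```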
' 'Please use `random_state` instead.', LGBMDeprecationWarning) @@ -265,85 +281,105 @@ def get_params(self, deep=True): def set_params(self, **params): for key, value in params.items(): setattr(self, key, value) - self.other_params[key] = value + self._other_params[key] = value return self def fit(self, X, y, sample_weight=None, init_score=None, group=None, eval_set=None, eval_names=None, eval_sample_weight=None, - eval_init_score=None, eval_group=None, - eval_metric=None, - early_stopping_rounds=None, verbose=True, - feature_name='auto', categorical_feature='auto', - callbacks=None): - """ - Fit the gradient boosting model + eval_init_score=None, eval_group=None, eval_metric=None, + early_stopping_rounds=None, verbose=True, feature_name='auto', + categorical_feature='auto', callbacks=None): + """Build a gradient boosting model from the training set (X, y). Parameters ---------- - X : array_like - Feature matrix - y : array_like - Labels - sample_weight : array_like - weight of training data - init_score : array_like - init score of training data - group : array_like - group data of training data - eval_set : list, optional - A list of (X, y) tuple pairs to use as a validation set for early-stopping - eval_names: list of string - Names of eval_set - eval_sample_weight : List of array - weight of eval data - eval_init_score : List of array - init score of eval data - eval_group : List of array - group data of eval data - eval_metric : str, list of str, callable, optional - If a str, should be a built-in evaluation metric to use. - If callable, a custom evaluation metric, see note for more details. - early_stopping_rounds : int - verbose : bool - If `verbose` and an evaluation set is used, writes the evaluation - feature_name : list of str, or 'auto' - Feature names - If 'auto' and data is pandas DataFrame, use data columns name - categorical_feature : list of str or int, or 'auto' - Categorical features, - type int represents index, - type str represents feature names (need to specify feature_name as well) - If 'auto' and data is pandas DataFrame, use pandas categorical columns - callbacks : list of callback functions + X : array-like or sparse matrix of shape = [n_samples, n_features] + Input feature matrix. + y : array-like of shape = [n_samples] + The target values (class labels in classification, real numbers in regression). + sample_weight : array-like of shape = [n_samples] or None, optional (default=None) + Weights of training data. + init_score : array-like of shape = [n_samples] or None, optional (default=None) + Init score of training data. + group : array-like of shape = [n_samples] or None, optional (default=None) + Group data of training data. + eval_set : list or None, optional (default=None) + A list of (X, y) tuple pairs to use as a validation sets for early-stopping. + eval_names: list of strings or None, optional (default=None) + Names of eval_set. + eval_sample_weight : list of arrays or None, optional (default=None) + Weights of eval data. + eval_init_score : list of arrays or None, optional (default=None) + Init score of eval data. + eval_group : list of arrays or None, optional (default=None) + Group data of eval data. + eval_metric : string, list of strings, callable or None, optional (default=None) + If string, it should be a built-in evaluation metric to use. + If callable, it should be a custom evaluation metric, see note for more details. + early_stopping_rounds : int or None, optional (default=None) + Activates early stopping. 
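The `eval_*` arguments documented above are parallel lists, one entry per `(X, y)` pair in `eval_set`. A sketch with two validation sets and synthetic data:

```python
import numpy as np
import lightgbm as lgb

rng = np.random.RandomState(0)
X_train = rng.rand(300, 4); y_train = (X_train[:, 0] > 0.5).astype(int)
X_val1 = rng.rand(50, 4);   y_val1 = (X_val1[:, 0] > 0.5).astype(int)
X_val2 = rng.rand(50, 4);   y_val2 = (X_val2[:, 0] > 0.5).astype(int)

clf = lgb.LGBMClassifier(n_estimators=30)
clf.fit(X_train, y_train,
        eval_set=[(X_val1, y_val1), (X_val2, y_val2)],
        eval_names=['holdout_1', 'holdout_2'],
        eval_sample_weight=[np.ones(50), np.ones(50)],
        early_stopping_rounds=5, verbose=False)

print(clf.evals_result_.keys())   # keyed by eval_names: holdout_1, holdout_2
```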
The model will train until the validation score stops improving. + Validation error needs to decrease at least every `early_stopping_rounds` round(s) + to continue training. + verbose : bool, optional (default=True) + If True and an evaluation set is used, writes the evaluation progress. + feature_name : list of strings or 'auto', optional (default="auto") + Feature names. + If 'auto' and data is pandas DataFrame, data columns names are used. + categorical_feature : list of strings or int, or 'auto', optional (default="auto") + Categorical features. + If list of int, interpreted as indices. + If list of strings, interpreted as feature names (need to specify feature_name as well). + If 'auto' and data is pandas DataFrame, pandas categorical columns are used. + callbacks : list of callback functions or None, optional (default=None) List of callback functions that are applied at each iteration. See Callbacks in Python-API.md for more information. + Returns + ------- + self : object + Returns self. + Note ---- Custom eval function expects a callable with following functions: - ``func(y_true, y_pred)``, ``func(y_true, y_pred, weight)`` - or ``func(y_true, y_pred, weight, group)``. - return (eval_name, eval_result, is_bigger_better) - or list of (eval_name, eval_result, is_bigger_better) - - y_true: array_like of shape [n_samples] - The target values - y_pred: array_like of shape [n_samples] or shape[n_samples * n_class] (for multi-class) - The predicted values - weight: array_like of shape [n_samples] - The weight of samples - group: array_like - group/query data, used for ranking task + ``func(y_true, y_pred)``, ``func(y_true, y_pred, weight)`` or + ``func(y_true, y_pred, weight, group)``. + Returns (eval_name, eval_result, is_bigger_better) or + list of (eval_name, eval_result, is_bigger_better) + + y_true: array-like of shape = [n_samples] + The target values. + y_pred: array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class) + The predicted values. + weight: array-like of shape = [n_samples] + The weight of samples. + group: array-like + Group/query data, used for ranking task. eval_name: str - name of evaluation + The name of evaluation. eval_result: float - eval result + The eval result. is_bigger_better: bool - is eval result bigger better, e.g. AUC is bigger_better. - for multi-class task, the y_pred is group by class_id first, then group by row_id - if you want to get i-th row y_pred in j-th class, the access way is y_pred[j*num_data+i] + Is eval result bigger better, e.g. AUC is bigger_better. + For multi-class task, the y_pred is group by class_id first, then group by row_id. + If you want to get i-th row y_pred in j-th class, the access way is y_pred[j * num_data + i]. 
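A custom `eval_metric` matching the contract in the note above takes the true labels and the predicted values and returns `(eval_name, eval_result, is_bigger_better)`. A regression sketch with synthetic data:

```python
import numpy as np
import lightgbm as lgb

def rmse(y_true, y_pred):
    # returns (eval_name, eval_result, is_bigger_better)
    return 'rmse', np.sqrt(np.mean((y_true - y_pred) ** 2)), False

rng = np.random.RandomState(0)
X = rng.rand(200, 4)
y = X[:, 0] + 0.1 * rng.rand(200)

gbm = lgb.LGBMRegressor(n_estimators=20)
gbm.fit(X[:150], y[:150],
        eval_set=[(X[150:], y[150:])],
        eval_metric=rmse,         # callable -> wrapped by _eval_function_wrapper
        verbose=False)
print(gbm.evals_result_['valid_0']['rmse'])
```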
""" + if not hasattr(self, '_objective'): + self._objective = self.objective + if self._objective is None: + if isinstance(self, LGBMRegressor): + self._objective = "regression" + elif isinstance(self, LGBMClassifier): + self._objective = "binary" + elif isinstance(self, LGBMRanker): + self._objective = "lambdarank" + else: + raise ValueError("Unknown LGBMModel type.") + if callable(self._objective): + self._fobj = _objective_function_wrapper(self._objective) + else: + self._fobj = None evals_result = {} params = self.get_params() # sklearn interface has another naming convention @@ -354,11 +390,12 @@ def fit(self, X, y, params['verbose'] = -1 params.pop('silent', None) params.pop('n_estimators', None) - if hasattr(self, 'n_classes_') and self.n_classes_ > 2: - params['num_class'] = self.n_classes_ - if hasattr(self, 'eval_at'): - params['ndcg_eval_at'] = self.eval_at - if self.fobj: + if self._n_classes is not None and self._n_classes > 2: + params['num_class'] = self._n_classes + if hasattr(self, '_eval_at'): + params['ndcg_eval_at'] = self._eval_at + params['objective'] = self._objective + if self._fobj: params['objective'] = 'None' # objective = nullptr for unknown objective if callable(eval_metric): @@ -367,6 +404,10 @@ def fit(self, X, y, feval = None params['metric'] = eval_metric + X, y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2) + _LGBMCheckConsistentLength(X, y, sample_weight) + self._n_features = X.shape[1] + def _construct_dataset(X, y, sample_weight, init_score, group, params): ret = Dataset(X, label=y, max_bin=self.max_bin, weight=sample_weight, group=group, params=params) ret.set_init_score(init_score) @@ -401,17 +442,17 @@ def get_meta_data(collection, i): self._Booster = train(params, train_set, self.n_estimators, valid_sets=valid_sets, valid_names=eval_names, early_stopping_rounds=early_stopping_rounds, - evals_result=evals_result, fobj=self.fobj, feval=feval, + evals_result=evals_result, fobj=self._fobj, feval=feval, verbose_eval=verbose, feature_name=feature_name, categorical_feature=categorical_feature, callbacks=callbacks) if evals_result: - self.evals_result = evals_result + self._evals_result = evals_result if early_stopping_rounds is not None: - self.best_iteration = self._Booster.best_iteration - self.best_score = self._Booster.best_score + self._best_iteration = self._Booster.best_iteration + self._best_score = self._Booster.best_score # free dataset self.booster_.free_dataset() @@ -419,63 +460,110 @@ def get_meta_data(collection, i): return self def predict(self, X, raw_score=False, num_iteration=0): - """ - Return the predicted value for each sample. + """Return the predicted value for each sample. Parameters ---------- - X : array_like, shape=[n_samples, n_features] + X : array-like or sparse matrix of shape = [n_samples, n_features] Input features matrix. - - num_iteration : int + raw_score : bool, optional (default=False) + Whether to predict raw scores. + num_iteration : int, optional (default=0) Limit number of iterations in the prediction; defaults to 0 (use all trees). Returns ------- - predicted_result : array_like, shape=[n_samples] or [n_samples, n_classes] + predicted_result : array-like of shape = [n_samples] or shape = [n_samples, n_classes] + The predicted values. 
""" + if self._n_features is None: + raise LGBMNotFittedError("Estimator not fitted, call `fit` before exploiting the model.") + X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False) + n_features = X.shape[1] + if self._n_features != n_features: + raise ValueError("Number of features of the model must " + "match the input. Model n_features_ is %s and " + "input n_features is %s " + % (self._n_features, n_features)) return self.booster_.predict(X, raw_score=raw_score, num_iteration=num_iteration) def apply(self, X, num_iteration=0): - """ - Return the predicted leaf every tree for each sample. + """Return the predicted leaf every tree for each sample. Parameters ---------- - X : array_like, shape=[n_samples, n_features] + X : array-like or sparse matrix of shape = [n_samples, n_features] Input features matrix. - - num_iteration : int + num_iteration : int, optional (default=0) Limit number of iterations in the prediction; defaults to 0 (use all trees). Returns ------- - X_leaves : array_like, shape=[n_samples, n_trees] + X_leaves : array-like of shape = [n_samples, n_trees] + The predicted leaf every tree for each sample. """ + if self._n_features is None: + raise LGBMNotFittedError("Estimator not fitted, call `fit` before exploiting the model.") + X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False) + n_features = X.shape[1] + if self._n_features != n_features: + raise ValueError("Number of features of the model must " + "match the input. Model n_features_ is %s and " + "input n_features is %s " + % (self._n_features, n_features)) return self.booster_.predict(X, pred_leaf=True, num_iteration=num_iteration) + @property + def n_features_(self): + """Get the number of features of fitted model.""" + if self._n_features is None: + raise LGBMNotFittedError('No n_features found. Need to call fit beforehand.') + return self._n_features + + @property + def best_score_(self): + """Get the best score of fitted model.""" + if self._n_features is None: + raise LGBMNotFittedError('No best_score found. Need to call fit beforehand.') + return self._best_score + + @property + def best_iteration_(self): + """Get the best iteration of fitted model.""" + if self._n_features is None: + raise LGBMNotFittedError('No best_iteration found. Need to call fit with early_stopping_rounds beforehand.') + return self._best_iteration + + @property + def objective_(self): + """Get the concrete objective used while fitting this model.""" + if self._n_features is None: + raise LGBMNotFittedError('No objective found. Need to call fit beforehand.') + return self._objective + @property def booster_(self): """Get the underlying lightgbm Booster of this model.""" if self._Booster is None: - raise LightGBMError('No booster found. Need to call fit beforehand.') + raise LGBMNotFittedError('No booster found. Need to call fit beforehand.') return self._Booster @property def evals_result_(self): """Get the evaluation results.""" - if self.evals_result is None: - raise LightGBMError('No results found. Need to call fit with eval set beforehand.') - return self.evals_result + if self._n_features is None: + raise LGBMNotFittedError('No results found. Need to call fit with eval_set beforehand.') + return self._evals_result @property def feature_importances_(self): - """ - Get feature importances. + """Get feature importances. 
Note: feature importance in sklearn interface used to normalize to 1, - it's deprecated after 2.0.4 and same as Booster.feature_importance() now + it's deprecated after 2.0.4 and same as Booster.feature_importance() now. """ + if self._n_features is None: + raise LGBMNotFittedError('No feature_importances found. Need to call fit beforehand.') return self.booster_.feature_importance() @LGBMDeprecated('Use attribute booster_ instead.') @@ -487,15 +575,14 @@ def feature_importance(self): return self.feature_importances_ -class LGBMRegressor(LGBMModel, LGBMRegressorBase): +class LGBMRegressor(LGBMModel, _LGBMRegressorBase): + """LightGBM regressor.""" def fit(self, X, y, sample_weight=None, init_score=None, eval_set=None, eval_names=None, eval_sample_weight=None, - eval_init_score=None, - eval_metric="l2", - early_stopping_rounds=None, verbose=True, - feature_name='auto', categorical_feature='auto', callbacks=None): + eval_init_score=None, eval_metric="l2", early_stopping_rounds=None, + verbose=True, feature_name='auto', categorical_feature='auto', callbacks=None): super(LGBMRegressor, self).fit(X, y, sample_weight=sample_weight, init_score=init_score, eval_set=eval_set, @@ -509,25 +596,30 @@ def fit(self, X, y, callbacks=callbacks) return self + base_doc = LGBMModel.fit.__doc__ + fit.__doc__ = (base_doc[:base_doc.find('eval_metric :')] + + 'eval_metric : string, list of strings, callable or None, optional (default="l2")\n' + + base_doc[base_doc.find(' If string, it should be a built-in evaluation metric to use.'):]) -class LGBMClassifier(LGBMModel, LGBMClassifierBase): + +class LGBMClassifier(LGBMModel, _LGBMClassifierBase): + """LightGBM classifier.""" def fit(self, X, y, sample_weight=None, init_score=None, eval_set=None, eval_names=None, eval_sample_weight=None, - eval_init_score=None, - eval_metric="logloss", + eval_init_score=None, eval_metric="logloss", early_stopping_rounds=None, verbose=True, - feature_name='auto', categorical_feature='auto', - callbacks=None): - self._le = LGBMLabelEncoder().fit(y) + feature_name='auto', categorical_feature='auto', callbacks=None): + _LGBMCheckClassificationTargets(y) + self._le = _LGBMLabelEncoder().fit(y) _y = self._le.transform(y) - self.classes = self._le.classes_ - self.n_classes = len(self.classes_) - if self.n_classes > 2: + self._classes = self._le.classes_ + self._n_classes = len(self._classes) + if self._n_classes > 2: # Switch to using a multiclass objective in the underlying LGBM instance - self.objective = "multiclass" + self._objective = "multiclass" if eval_metric == 'logloss' or eval_metric == 'binary_logloss': eval_metric = "multi_logloss" elif eval_metric == 'error' or eval_metric == 'binary_error': @@ -559,66 +651,73 @@ def fit(self, X, y, callbacks=callbacks) return self + base_doc = LGBMModel.fit.__doc__ + fit.__doc__ = (base_doc[:base_doc.find('eval_metric :')] + + 'eval_metric : string, list of strings, callable or None, optional (default="logloss")\n' + + base_doc[base_doc.find(' If string, it should be a built-in evaluation metric to use.'):]) + def predict(self, X, raw_score=False, num_iteration=0): class_probs = self.predict_proba(X, raw_score, num_iteration) class_index = np.argmax(class_probs, axis=1) return self._le.inverse_transform(class_index) def predict_proba(self, X, raw_score=False, num_iteration=0): - """ - Return the predicted probability for each class for each sample. + """Return the predicted probability for each class for each sample. 
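The `base_doc` splicing used by LGBMRegressor, LGBMClassifier and LGBMRanker above keeps a single `fit` docstring on `LGBMModel` and only swaps in the subclass-specific `eval_metric` default line. A toy sketch of the same `__doc__` surgery, with a simplified docstring rather than the real one:

```python
class Base(object):
    def fit(self):
        """Fit the model.

        eval_metric : string or None, optional (default=None)
            If string, it should be a built-in evaluation metric to use.
        verbose : bool, optional (default=True)
            Whether to print progress.
        """


class Regressor(Base):
    def fit(self):
        pass

    # rewrite the inherited docstring with a subclass-specific default
    _base_doc = Base.fit.__doc__
    fit.__doc__ = (_base_doc[:_base_doc.find('eval_metric :')]
                   + 'eval_metric : string or None, optional (default="l2")\n'
                   + _base_doc[_base_doc.find('            If string,'):])


print(Regressor.fit.__doc__)
```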
Parameters ---------- - X : array_like, shape=[n_samples, n_features] + X : array-like or sparse matrix of shape = [n_samples, n_features] Input features matrix. - - num_iteration : int + raw_score : bool, optional (default=False) + Whether to predict raw scores. + num_iteration : int, optional (default=0) Limit number of iterations in the prediction; defaults to 0 (use all trees). Returns ------- - predicted_probability : array_like, shape=[n_samples, n_classes] + predicted_probability : array-like of shape = [n_samples, n_classes] + The predicted probability for each class for each sample. """ + if self._n_features is None: + raise LGBMNotFittedError("Estimator not fitted, call `fit` before exploiting the model.") + X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False) + n_features = X.shape[1] + if self._n_features != n_features: + raise ValueError("Number of features of the model must " + "match the input. Model n_features_ is %s and " + "input n_features is %s " + % (self._n_features, n_features)) class_probs = self.booster_.predict(X, raw_score=raw_score, num_iteration=num_iteration) - if self.n_classes > 2: + if self._n_classes > 2: return class_probs else: return np.vstack((1. - class_probs, class_probs)).transpose() @property def classes_(self): - """Get class label array.""" - if self.classes is None: - raise LightGBMError('No classes found. Need to call fit beforehand.') - return self.classes + """Get the class label array.""" + if self._classes is None: + raise LGBMNotFittedError('No classes found. Need to call fit beforehand.') + return self._classes @property def n_classes_(self): - """Get number of classes""" - if self.n_classes is None: - raise LightGBMError('No classes found. Need to call fit beforehand.') - return self.n_classes + """Get the number of classes.""" + if self._n_classes is None: + raise LGBMNotFittedError('No classes found. 
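For the binary case in `predict_proba` above, the booster returns a single probability per sample (the probability of class 1), and the wrapper stacks it into the two-column layout scikit-learn expects:

```python
import numpy as np

class_probs = np.array([0.1, 0.8, 0.5])   # booster output: P(class 1), shape (n_samples,)
proba = np.vstack((1. - class_probs, class_probs)).transpose()
print(proba)        # shape (3, 2): columns are [P(class 0), P(class 1)], rows sum to 1
```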
Need to call fit beforehand.') + return self._n_classes class LGBMRanker(LGBMModel): + """LightGBM ranker.""" def fit(self, X, y, sample_weight=None, init_score=None, group=None, eval_set=None, eval_names=None, eval_sample_weight=None, - eval_init_score=None, eval_group=None, - eval_metric='ndcg', eval_at=1, - early_stopping_rounds=None, verbose=True, - feature_name='auto', categorical_feature='auto', - callbacks=None): - """ - Most arguments like common methods except following: - - eval_at : list of int - The evaulation positions of NDCG - """ - - """check group data""" + eval_init_score=None, eval_group=None, eval_metric='ndcg', + eval_at=[1], early_stopping_rounds=None, verbose=True, + feature_name='auto', categorical_feature='auto', callbacks=None): + # check group data if group is None: raise ValueError("Should set group for ranking task") @@ -626,13 +725,13 @@ def fit(self, X, y, if eval_group is None: raise ValueError("Eval_group cannot be None when eval_set is not None") elif len(eval_group) != len(eval_set): - raise ValueError("Length of eval_group should equal to eval_set") + raise ValueError("Length of eval_group should be equal to eval_set") elif (isinstance(eval_group, dict) and any(i not in eval_group or eval_group[i] is None for i in range_(len(eval_group)))) \ or (isinstance(eval_group, list) and any(group is None for group in eval_group)): - raise ValueError("Should set group for all eval dataset for ranking task; if you use dict, the index should start from 0") + raise ValueError("Should set group for all eval datasets for ranking task; " + "if you use dict, the index should start from 0") - if eval_at is not None: - self.eval_at = eval_at + self._eval_at = eval_at super(LGBMRanker, self).fit(X, y, sample_weight=sample_weight, init_score=init_score, group=group, eval_set=eval_set, eval_names=eval_names, @@ -644,3 +743,11 @@ def fit(self, X, y, categorical_feature=categorical_feature, callbacks=callbacks) return self + + base_doc = LGBMModel.fit.__doc__ + fit.__doc__ = (base_doc[:base_doc.find('eval_metric :')] + + 'eval_metric : string, list of strings, callable or None, optional (default="ndcg")\n' + + base_doc[base_doc.find(' If string, it should be a built-in evaluation metric to use.'):base_doc.find('early_stopping_rounds :')] + + 'eval_at : list of int, optional (default=[1])\n' + ' The evaluation positions of NDCG.\n' + + base_doc[base_doc.find(' early_stopping_rounds :'):]) diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py index b0237c7d3a73..39415e30548e 100644 --- a/tests/python_package_test/test_sklearn.py +++ b/tests/python_package_test/test_sklearn.py @@ -13,6 +13,9 @@ from sklearn.externals import joblib from sklearn.metrics import log_loss, mean_squared_error from sklearn.model_selection import GridSearchCV, train_test_split +from sklearn.utils.estimator_checks import (_yield_all_checks, SkipTest, + check_parameters_default_constructible, + check_no_fit_attributes_set_in_init) def multi_error(y_true, y_pred): @@ -32,7 +35,7 @@ def test_binary(self): gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=5, verbose=False) ret = log_loss(y_test, gbm.predict_proba(X_test)) self.assertLess(ret, 0.15) - self.assertAlmostEqual(ret, gbm.evals_result['valid_0']['binary_logloss'][gbm.best_iteration - 1], places=5) + self.assertAlmostEqual(ret, gbm.evals_result_['valid_0']['binary_logloss'][gbm.best_iteration_ - 1], places=5) def test_regreesion(self): X, y = load_boston(True) @@ -41,7 +44,7 @@ 
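`LGBMRanker.fit` above insists on query group sizes for the training set and for every entry of `eval_set`; `eval_at` (now defaulting to `[1]`) lists the NDCG cut-offs to report. A sketch with small synthetic queries, where group sizes must sum to the number of rows:

```python
import numpy as np
import lightgbm as lgb

rng = np.random.RandomState(0)
X_train = rng.rand(60, 3)
y_train = rng.randint(0, 3, size=60)      # graded relevance labels
group_train = [20, 20, 20]                # three queries of 20 documents each

X_val = rng.rand(20, 3)
y_val = rng.randint(0, 3, size=20)
group_val = [10, 10]

ranker = lgb.LGBMRanker(n_estimators=10)
ranker.fit(X_train, y_train, group=group_train,
           eval_set=[(X_val, y_val)], eval_group=[group_val],
           eval_at=[1, 3], verbose=False)
print(ranker.evals_result_['valid_0'].keys())   # NDCG at the requested positions
```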
def test_regreesion(self): gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=5, verbose=False) ret = mean_squared_error(y_test, gbm.predict(X_test)) self.assertLess(ret, 16) - self.assertAlmostEqual(ret, gbm.evals_result['valid_0']['l2'][gbm.best_iteration - 1], places=5) + self.assertAlmostEqual(ret, gbm.evals_result_['valid_0']['l2'][gbm.best_iteration_ - 1], places=5) def test_multiclass(self): X, y = load_digits(10, True) @@ -51,7 +54,7 @@ def test_multiclass(self): ret = multi_error(y_test, gbm.predict(X_test)) self.assertLess(ret, 0.2) ret = multi_logloss(y_test, gbm.predict_proba(X_test)) - self.assertAlmostEqual(ret, gbm.evals_result['valid_0']['multi_logloss'][gbm.best_iteration - 1], places=5) + self.assertAlmostEqual(ret, gbm.evals_result_['valid_0']['multi_logloss'][gbm.best_iteration_ - 1], places=5) def test_lambdarank(self): X_train, y_train = load_svmlight_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../examples/lambdarank/rank.train')) @@ -74,7 +77,7 @@ def objective_ls(y_true, y_pred): gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=5, verbose=False) ret = mean_squared_error(y_test, gbm.predict(X_test)) self.assertLess(ret, 100) - self.assertAlmostEqual(ret, gbm.evals_result['valid_0']['l2'][gbm.best_iteration - 1], places=5) + self.assertAlmostEqual(ret, gbm.evals_result_['valid_0']['l2'][gbm.best_iteration_ - 1], places=5) def test_binary_classification_with_custom_objective(self): def logregobj(y_true, y_pred): @@ -177,3 +180,19 @@ def test_sklearn_backward_compatibility(self): clf_2.set_params(nthread=-1).fit(X_train, y_train) self.assertEqual(len(w), 2) self.assertTrue(issubclass(w[-1].category, Warning)) + + def test_sklearn_integration(self): + # we cannot use `check_estimator` directly since there is no skip test mechanism + for name, estimator in ((lgb.sklearn.LGBMClassifier.__name__, lgb.sklearn.LGBMClassifier), + (lgb.sklearn.LGBMRegressor.__name__, lgb.sklearn.LGBMRegressor)): + check_parameters_default_constructible(name, estimator) + check_no_fit_attributes_set_in_init(name, estimator) + # we cannot leave default params (see https://github.com/Microsoft/LightGBM/issues/833) + estimator = estimator(min_data=1, min_data_in_bin=1) + for check in _yield_all_checks(name, estimator): + if check.__name__ == 'check_estimators_nan_inf': + continue # skip test because LightGBM deals with nan + try: + check(name, estimator) + except SkipTest as message: + warnings.warn(message, SkipTestWarning)