diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index c4853e10f05a..248b58643f03 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -607,6 +607,8 @@ def _get_meta_data(collection, name, i): self._best_score = self._Booster.best_score + self.fitted_ = True + # free dataset self._Booster.free_dataset() del train_set, valid_sets diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py index 47d0697b2e68..7dc39de23224 100644 --- a/tests/python_package_test/test_sklearn.py +++ b/tests/python_package_test/test_sklearn.py @@ -20,6 +20,7 @@ RegressorChain) from sklearn.utils.estimator_checks import (_yield_all_checks, SkipTest, check_parameters_default_constructible) +from sklearn.utils.validation import check_is_fitted decreasing_generator = itertools.count(0, -1) @@ -1091,3 +1092,23 @@ def test_continue_training_with_model(self): self.assertEqual(len(init_gbm.evals_result_['valid_0']['multi_logloss']), 5) self.assertLess(gbm.evals_result_['valid_0']['multi_logloss'][-1], init_gbm.evals_result_['valid_0']['multi_logloss'][-1]) + + # sklearn < 0.22 requires passing "attributes" argument + @unittest.skipIf(sk_version < '0.22.0', 'scikit-learn version is less than 0.22') + def test_check_is_fitted(self): + X, y = load_digits(n_class=2, return_X_y=True) + est = lgb.LGBMModel(n_estimators=5, objective="binary") + clf = lgb.LGBMClassifier(n_estimators=5) + reg = lgb.LGBMRegressor(n_estimators=5) + rnk = lgb.LGBMRanker(n_estimators=5) + models = (est, clf, reg, rnk) + for model in models: + self.assertRaises(lgb.compat.LGBMNotFittedError, + check_is_fitted, + model) + est.fit(X, y) + clf.fit(X, y) + reg.fit(X, y) + rnk.fit(X, y, group=np.ones(X.shape[0])) + for model in models: + check_is_fitted(model)