diff --git a/.gitignore b/.gitignore index 6e4ffb35670a..f2f2c6c2316f 100644 --- a/.gitignore +++ b/.gitignore @@ -318,6 +318,8 @@ htmlcov/ .coverage.* .cache nosetests.xml +prof/ +*.prof coverage.xml *,cover .hypothesis/ diff --git a/tests/python_package_test/__init__.py b/tests/python_package_test/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index ee58da5484c8..a0ce5b8f8b66 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -7,9 +7,11 @@ import numpy as np from scipy import sparse -from sklearn.datasets import load_breast_cancer, dump_svmlight_file, load_svmlight_file +from sklearn.datasets import dump_svmlight_file, load_svmlight_file from sklearn.model_selection import train_test_split +from .utils import load_breast_cancer + class TestBasic(unittest.TestCase): diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 0c90a6bada87..de8689fd3ea5 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -10,8 +10,7 @@ import lightgbm as lgb import numpy as np from scipy.sparse import csr_matrix, isspmatrix_csr, isspmatrix_csc -from sklearn.datasets import (load_boston, load_breast_cancer, load_digits, - load_iris, load_svmlight_file, make_multilabel_classification) +from sklearn.datasets import load_svmlight_file, make_multilabel_classification from sklearn.metrics import log_loss, mean_absolute_error, mean_squared_error, roc_auc_score, average_precision_score from sklearn.model_selection import train_test_split, TimeSeriesSplit, GroupKFold @@ -20,6 +19,8 @@ except ImportError: import pickle +from .utils import load_boston, load_breast_cancer, load_digits, load_iris + decreasing_generator = itertools.count(0, -1) @@ -2524,6 +2525,7 @@ def test_average_precision_metric(self): sklearn_ap = average_precision_score(y, pred) self.assertAlmostEqual(ap, sklearn_ap) # test that average precision is 1 where model predicts perfectly + y = y.copy() y[:] = 1 lgb_X = lgb.Dataset(X, label=y) lgb.train(params, lgb_X, num_boost_round=1, valid_sets=[lgb_X], evals_result=res) diff --git a/tests/python_package_test/test_plotting.py b/tests/python_package_test/test_plotting.py index 786b79760910..293012348ac3 100644 --- a/tests/python_package_test/test_plotting.py +++ b/tests/python_package_test/test_plotting.py @@ -3,7 +3,6 @@ import lightgbm as lgb from lightgbm.compat import MATPLOTLIB_INSTALLED, GRAPHVIZ_INSTALLED -from sklearn.datasets import load_breast_cancer from sklearn.model_selection import train_test_split if MATPLOTLIB_INSTALLED: @@ -12,6 +11,8 @@ if GRAPHVIZ_INSTALLED: import graphviz +from .utils import load_breast_cancer + class TestBasic(unittest.TestCase): diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py index 4aa9831c2cea..623f83a517a5 100644 --- a/tests/python_package_test/test_sklearn.py +++ b/tests/python_package_test/test_sklearn.py @@ -10,9 +10,7 @@ import numpy as np from sklearn import __version__ as sk_version from sklearn.base import clone -from sklearn.datasets import (load_boston, load_breast_cancer, load_digits, - load_iris, load_linnerud, load_svmlight_file, - make_multilabel_classification) +from sklearn.datasets import load_svmlight_file, make_multilabel_classification from sklearn.exceptions import SkipTestWarning from sklearn.metrics import log_loss, mean_squared_error from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, train_test_split @@ -22,6 +20,8 @@ check_parameters_default_constructible) from sklearn.utils.validation import check_is_fitted +from .utils import load_boston, load_breast_cancer, load_digits, load_iris, load_linnerud + decreasing_generator = itertools.count(0, -1) diff --git a/tests/python_package_test/utils.py b/tests/python_package_test/utils.py new file mode 100644 index 000000000000..f0b160d60dfb --- /dev/null +++ b/tests/python_package_test/utils.py @@ -0,0 +1,45 @@ +# coding: utf-8 +import sklearn.datasets + +try: + from functools import lru_cache +except ImportError: + import warnings + warnings.warn("Could not import functools.lru_cache", RuntimeWarning) + + def lru_cache(maxsize=None): + cache = {} + + def _lru_wrapper(user_function): + def wrapper(*args, **kwargs): + arg_key = (args, tuple(kwargs.items())) + if arg_key not in cache: + cache[arg_key] = user_function(*args, **kwargs) + return cache[arg_key] + return wrapper + return _lru_wrapper + + +@lru_cache(maxsize=None) +def load_boston(**kwargs): + return sklearn.datasets.load_boston(**kwargs) + + +@lru_cache(maxsize=None) +def load_breast_cancer(**kwargs): + return sklearn.datasets.load_breast_cancer(**kwargs) + + +@lru_cache(maxsize=None) +def load_digits(**kwargs): + return sklearn.datasets.load_digits(**kwargs) + + +@lru_cache(maxsize=None) +def load_iris(**kwargs): + return sklearn.datasets.load_iris(**kwargs) + + +@lru_cache(maxsize=None) +def load_linnerud(**kwargs): + return sklearn.datasets.load_linnerud(**kwargs)