diff --git a/.ci/scripts/run_sklearn_tests.sh b/.ci/scripts/run_sklearn_tests.sh
index 8a05eaf47b..1ff00d24ca 100755
--- a/.ci/scripts/run_sklearn_tests.sh
+++ b/.ci/scripts/run_sklearn_tests.sh
@@ -39,5 +39,10 @@ if [ -n "${OCL_ICD_FILENAMES}" ]; then
     echo "OCL_ICD_FILENAMES is set to ${OCL_ICD_FILENAMES}"
 fi

+# Show devices listed by dpctl
+if [ -n "$(pip list | grep dpctl)" ]; then
+    python -c "import dpctl; print(dpctl.get_devices())"
+fi
+
 python scripts/run_sklearn_tests.py -d ${1:-none}
 exit $?
diff --git a/conda-recipe/run_test.sh b/conda-recipe/run_test.sh
index 879b37736b..0eea868ec7 100755
--- a/conda-recipe/run_test.sh
+++ b/conda-recipe/run_test.sh
@@ -54,7 +54,9 @@ pytest --verbose --pyargs ${daal4py_dir}/daal4py/sklearn
 return_code=$(($return_code + $?))

 echo "Pytest of sklearnex running ..."
-pytest --verbose --pyargs ${daal4py_dir}/sklearnex
+# TODO: investigate why test_monkeypatch.py might cause failures of other tests
+pytest --verbose --pyargs --deselect sklearnex/tests/test_monkeypatch.py ${daal4py_dir}/sklearnex
+pytest --verbose ${daal4py_dir}/sklearnex/tests/test_monkeypatch.py
 return_code=$(($return_code + $?))

 echo "Pytest of onedal running ..."
diff --git a/deselected_tests.yaml b/deselected_tests.yaml
index 8b514e3ef6..60de1bc44c 100755
--- a/deselected_tests.yaml
+++ b/deselected_tests.yaml
@@ -728,7 +728,6 @@ gpu:
   - neighbors/tests/test_neighbors.py::test_neigh_predictions_algorithm_agnosticity[float64-KNeighborsRegressor-50-500-l2-1000-5-100]
   - neighbors/tests/test_neighbors.py::test_neigh_predictions_algorithm_agnosticity[float64-KNeighborsRegressor-100-1000-l2-1000-5-100]
   # failing due to numeric/code error
-  - ensemble/tests/test_bagging.py::test_parallel_classification
   - linear_model/tests/test_common.py::test_balance_property[42-False-LogisticRegressionCV]
   - sklearn/manifold/tests/test_t_sne.py::test_n_iter_without_progress
   - model_selection/tests/test_search.py::test_searchcv_raise_warning_with_non_finite_score[RandomizedSearchCV-specialized_params1-False]
diff --git a/sklearnex/dispatcher.py b/sklearnex/dispatcher.py
index ebd30dee95..06343b6c92 100644
--- a/sklearnex/dispatcher.py
+++ b/sklearnex/dispatcher.py
@@ -48,11 +48,23 @@ def get_patch_map():
     import sklearn.neighbors as neighbors_module
     import sklearn.svm as svm_module

+    if sklearn_check_version("1.2.1"):
+        import sklearn.utils.parallel as parallel_module
+    else:
+        import sklearn.utils.fixes as parallel_module
+
     # Classes and functions for patching
     from ._config import config_context as config_context_sklearnex
     from ._config import get_config as get_config_sklearnex
     from ._config import set_config as set_config_sklearnex
+
+    if sklearn_check_version("1.2.1"):
+        from .utils.parallel import _FuncWrapper as _FuncWrapper_sklearnex
+    else:
+        from .utils.parallel import _FuncWrapperOld as _FuncWrapper_sklearnex
+
     from .cluster import DBSCAN as DBSCAN_sklearnex
+
     from .neighbors import KNeighborsClassifier as KNeighborsClassifier_sklearnex
     from .neighbors import KNeighborsRegressor as KNeighborsRegressor_sklearnex
     from .neighbors import LocalOutlierFactor as LocalOutlierFactor_sklearnex
@@ -226,6 +238,14 @@ def get_patch_map():
     mapping["config_context"] = [
         [(base_module, "config_context", config_context_sklearnex), None]
     ]
+
+    # Necessary for proper work with multiple threads
+    mapping["parallel.get_config"] = [
+        [(parallel_module, "get_config", get_config_sklearnex), None]
+    ]
+    mapping["_funcwrapper"] = [
+        [(parallel_module, "_FuncWrapper", _FuncWrapper_sklearnex), None]
+    ]

     return mapping
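The two mapping entries added above are the core of the fix: scikit-learn's `Parallel` resolves `get_config` and `_FuncWrapper` from its own module (`sklearn.utils.parallel` since 1.2.1, `sklearn.utils.fixes` before that), so replacing them there makes every joblib worker re-enter the sklearnex configuration, `target_offload` included. A minimal sketch of how the rewiring could be checked, assuming sklearn >= 1.2.1 and this patch applied:

```python
# Sketch only (not part of the patch): confirm that patch_sklearn()
# swaps sklearn's parallel helpers for the sklearnex implementations
# registered in the patch map above.
from sklearnex import patch_sklearn

patch_sklearn()

import sklearn.utils.parallel as parallel_module

from sklearnex._config import get_config
from sklearnex.utils.parallel import _FuncWrapper

# "parallel.get_config" entry: worker config snapshots should now include
# sklearnex-only options such as target_offload.
assert parallel_module.get_config is get_config

# "_funcwrapper" entry: each task should be replayed inside sklearnex's
# config_context rather than stock sklearn's.
assert parallel_module._FuncWrapper is _FuncWrapper
```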
diff --git a/sklearnex/svm/_common.py b/sklearnex/svm/_common.py
index f0be2fad06..00b2cb936a 100644
--- a/sklearnex/svm/_common.py
+++ b/sklearnex/svm/_common.py
@@ -100,39 +100,30 @@ def _compute_balanced_class_weight(self, y):
         return recip_freq[le.transform(classes)]

     def _fit_proba(self, X, y, sample_weight=None, queue=None):
-        from .._config import config_context, get_config
-
         params = self.get_params()
         params["probability"] = False
         params["decision_function_shape"] = "ovr"
         clf_base = self.__class__(**params)

-        # We use stock metaestimators below, so the only way
-        # to pass a queue is using config_context.
-        cfg = get_config()
-        cfg["target_offload"] = queue
-        with config_context(**cfg):
-            try:
-                n_splits = 5
-                n_jobs = n_splits if queue is None or queue.sycl_device.is_cpu else 1
-                cv = StratifiedKFold(
-                    n_splits=n_splits, shuffle=True, random_state=self.random_state
-                )
-                if sklearn_check_version("0.24"):
-                    self.clf_prob = CalibratedClassifierCV(
-                        clf_base, ensemble=False, cv=cv, method="sigmoid", n_jobs=n_jobs
-                    )
-                else:
-                    self.clf_prob = CalibratedClassifierCV(
-                        clf_base, cv=cv, method="sigmoid"
-                    )
-                self.clf_prob.fit(X, y, sample_weight)
-            except ValueError:
-                clf_base = clf_base.fit(X, y, sample_weight)
+        try:
+            n_splits = 5
+            n_jobs = n_splits if queue is None or queue.sycl_device.is_cpu else 1
+            cv = StratifiedKFold(
+                n_splits=n_splits, shuffle=True, random_state=self.random_state
+            )
+            if sklearn_check_version("0.24"):
                 self.clf_prob = CalibratedClassifierCV(
-                    clf_base, cv="prefit", method="sigmoid"
+                    clf_base, ensemble=False, cv=cv, method="sigmoid", n_jobs=n_jobs
                 )
-                self.clf_prob.fit(X, y, sample_weight)
+            else:
+                self.clf_prob = CalibratedClassifierCV(clf_base, cv=cv, method="sigmoid")
+            self.clf_prob.fit(X, y, sample_weight)
+        except ValueError:
+            clf_base = clf_base.fit(X, y, sample_weight)
+            self.clf_prob = CalibratedClassifierCV(
+                clf_base, cv="prefit", method="sigmoid"
+            )
+            self.clf_prob.fit(X, y, sample_weight)

     def _save_attributes(self):
         self.support_vectors_ = self._onedal_estimator.support_vectors_
diff --git a/sklearnex/tests/test_memory_usage.py b/sklearnex/tests/test_memory_usage.py
index ff8df2e91f..853c165c16 100644
--- a/sklearnex/tests/test_memory_usage.py
+++ b/sklearnex/tests/test_memory_usage.py
@@ -90,7 +90,7 @@ def get_patched_estimators(ban_list, output_list):
         estimator, name = listing[0][0][2], listing[0][0][1]
         if not isinstance(estimator, types.FunctionType):
             if name not in ban_list:
-                if isinstance(estimator(), BaseEstimator):
+                if issubclass(estimator, BaseEstimator):
                     if hasattr(estimator, "fit"):
                         output_list.append(estimator)

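With propagation handled by the patched `_FuncWrapper`, `_fit_proba` above no longer needs to rebuild a `config_context` around the stock `CalibratedClassifierCV`: whatever context the caller established is replayed inside each joblib worker automatically. A hedged illustration of the behavior this relies on (the `target_offload="cpu"` value assumes a SYCL CPU device is exposed through dpctl):

```python
# Sketch only: the ambient sklearnex configuration now reaches the
# workers spawned by a stock metaestimator, which is what the
# simplified _fit_proba depends on.
from sklearnex import config_context, patch_sklearn

patch_sklearn()

from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.svm import SVC

X, y = make_classification(random_state=0)
with config_context(target_offload="cpu"):
    # On sklearn >= 1.2.1 each of the n_jobs workers re-enters this
    # configuration via _FuncWrapper.with_config() before fitting its fold.
    CalibratedClassifierCV(SVC(), cv=3, n_jobs=3).fit(X, y)
```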
diff --git a/sklearnex/tests/test_parallel.py b/sklearnex/tests/test_parallel.py
new file mode 100644
index 0000000000..2736c93cb5
--- /dev/null
+++ b/sklearnex/tests/test_parallel.py
@@ -0,0 +1,50 @@
+# ==============================================================================
+# Copyright 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import pytest
+
+from sklearnex import config_context, patch_sklearn
+
+patch_sklearn()
+
+from sklearn.datasets import make_classification
+from sklearn.ensemble import BaggingClassifier
+from sklearn.svm import SVC
+
+try:
+    import dpctl
+
+    dpctl_is_available = True
+    gpu_is_available = dpctl.has_gpu_devices()
+except (ImportError, ModuleNotFoundError):
+    dpctl_is_available = False
+
+
+@pytest.mark.skipif(
+    not dpctl_is_available or gpu_is_available,
+    reason="GPU device should not be available for this test "
+    "to see raised 'SyclQueueCreationError'. "
+    "'dpctl' module is required for test.",
+)
+def test_config_context_in_parallel():
+    x, y = make_classification(random_state=42)
+    try:
+        with config_context(target_offload="gpu", allow_fallback_to_host=False):
+            BaggingClassifier(SVC(), n_jobs=2).fit(x, y)
+        raise ValueError(
+            "'SyclQueueCreationError' wasn't raised " "for non-existing 'gpu' device"
+        )
+    except dpctl._sycl_queue.SyclQueueCreationError:
+        pass
diff --git a/sklearnex/utils/parallel.py b/sklearnex/utils/parallel.py
new file mode 100644
index 0000000000..52e9a0b6f6
--- /dev/null
+++ b/sklearnex/utils/parallel.py
@@ -0,0 +1,59 @@
+# ===============================================================================
+# Copyright 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===============================================================================
+
+import warnings
+from functools import update_wrapper
+
+from .._config import config_context, get_config
+
+
+class _FuncWrapper:
+    """Load the global configuration before calling the function."""
+
+    def __init__(self, function):
+        self.function = function
+        update_wrapper(self, self.function)
+
+    def with_config(self, config):
+        self.config = config
+        return self
+
+    def __call__(self, *args, **kwargs):
+        config = getattr(self, "config", None)
+        if config is None:
+            warnings.warn(
+                "`sklearn.utils.parallel.delayed` should be used with "
+                "`sklearn.utils.parallel.Parallel` to make it possible to propagate "
+                "the scikit-learn configuration of the current thread to the "
+                "joblib workers.",
+                UserWarning,
+            )
+            config = {}
+        with config_context(**config):
+            return self.function(*args, **kwargs)
+
+
+class _FuncWrapperOld:
+    """Load the global configuration before calling the function."""
+
+    def __init__(self, function):
+        self.function = function
+        self.config = get_config()
+        update_wrapper(self, self.function)
+
+    def __call__(self, *args, **kwargs):
+        with config_context(**self.config):
+            return self.function(*args, **kwargs)
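The two wrappers differ only in when they snapshot the configuration: `_FuncWrapper` (sklearn >= 1.2.1) receives it per task through `with_config`, which the dispatching thread calls just before queueing the task, while `_FuncWrapperOld` captures it once in `__init__`. A minimal sketch of both behaviors, using only the sklearnex config API imported above (the `"cpu:0"` filter string is illustrative; it is stored in the config, not validated against a device here):

```python
# Sketch only: both wrappers replay the configuration that was active
# on the dispatching thread, even when invoked outside the context.
from sklearnex._config import config_context, get_config
from sklearnex.utils.parallel import _FuncWrapper, _FuncWrapperOld


def report():
    # Runs "in the worker" and reads the propagated configuration.
    return get_config()["target_offload"]


with config_context(target_offload="cpu:0"):
    task = _FuncWrapper(report).with_config(get_config())  # per-task snapshot
    old_task = _FuncWrapperOld(report)  # snapshot taken in __init__

# Outside the context the thread-local default ("auto") is back in force,
# yet both wrappers restore "cpu:0" around the call.
assert task() == "cpu:0"
assert old_task() == "cpu:0"
```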