diff --git a/sklearnex/dispatcher.py b/sklearnex/dispatcher.py index f93e5e554b..bffed7d64b 100644 --- a/sklearnex/dispatcher.py +++ b/sklearnex/dispatcher.py @@ -47,11 +47,13 @@ def get_patch_map(): import sklearn.linear_model as linear_model_module import sklearn.neighbors as neighbors_module import sklearn.svm as svm_module + import sklearn.utils.parallel as parallel_module # Classes and functions for patching from ._config import config_context as config_context_sklearnex from ._config import get_config as get_config_sklearnex from ._config import set_config as set_config_sklearnex + from .utils.parallel import _FuncWrapper as _FuncWrapper_sklearnex from .neighbors import KNeighborsClassifier as KNeighborsClassifier_sklearnex from .neighbors import KNeighborsRegressor as KNeighborsRegressor_sklearnex from .neighbors import LocalOutlierFactor as LocalOutlierFactor_sklearnex @@ -202,6 +204,14 @@ def get_patch_map(): mapping["config_context"] = [ [(base_module, "config_context", config_context_sklearnex), None] ] + + # Necessary for proper work with multiple threads + mapping["parallel.get_config"] = [ + [(parallel_module, "get_config", get_config_sklearnex), None] + ] + mapping["_funcwrapper"] = [ + [(parallel_module, "_FuncWrapper", _FuncWrapper_sklearnex), None] + ] return mapping diff --git a/sklearnex/utils/parallel.py b/sklearnex/utils/parallel.py new file mode 100644 index 0000000000..479f8aae3c --- /dev/null +++ b/sklearnex/utils/parallel.py @@ -0,0 +1,44 @@ +#=============================================================================== +# Copyright 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#=============================================================================== + +import warnings +from functools import update_wrapper +from .._config import config_context + +class _FuncWrapper: + """Load the global configuration before calling the function.""" + + def __init__(self, function): + self.function = function + update_wrapper(self, self.function) + + def with_config(self, config): + self.config = config + return self + + def __call__(self, *args, **kwargs): + config = getattr(self, "config", None) + if config is None: + warnings.warn( + "`sklearn.utils.parallel.delayed` should be used with " + "`sklearn.utils.parallel.Parallel` to make it possible to propagate " + "the scikit-learn configuration of the current thread to the " + "joblib workers.", + UserWarning, + ) + config = {} + with config_context(**config): + return self.function(*args, **kwargs)