From 05c32e8583e35f78e89f0e58b84ee17d1635dbef Mon Sep 17 00:00:00 2001 From: "Kruglov, Oleg" Date: Thu, 13 Jul 2023 18:32:15 -0700 Subject: [PATCH] Add support for older sklearn versions --- sklearnex/dispatcher.py | 10 ++++++++-- sklearnex/utils/parallel.py | 15 ++++++++++++++- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/sklearnex/dispatcher.py b/sklearnex/dispatcher.py index bffed7d64b..4e263e2cb3 100644 --- a/sklearnex/dispatcher.py +++ b/sklearnex/dispatcher.py @@ -47,13 +47,19 @@ def get_patch_map(): import sklearn.linear_model as linear_model_module import sklearn.neighbors as neighbors_module import sklearn.svm as svm_module - import sklearn.utils.parallel as parallel_module + if sklearn_check_version('1.2.1'): + import sklearn.utils.parallel as parallel_module + else: + import sklearn.utils.fixes as parallel_module # Classes and functions for patching from ._config import config_context as config_context_sklearnex from ._config import get_config as get_config_sklearnex from ._config import set_config as set_config_sklearnex - from .utils.parallel import _FuncWrapper as _FuncWrapper_sklearnex + if sklearn_check_version('1.2.1'): + from .utils.parallel import _FuncWrapper as _FuncWrapper_sklearnex + else: + from .utils.parallel import _FuncWrapperOld as _FuncWrapper_sklearnex from .neighbors import KNeighborsClassifier as KNeighborsClassifier_sklearnex from .neighbors import KNeighborsRegressor as KNeighborsRegressor_sklearnex from .neighbors import LocalOutlierFactor as LocalOutlierFactor_sklearnex diff --git a/sklearnex/utils/parallel.py b/sklearnex/utils/parallel.py index 479f8aae3c..d0a7310830 100644 --- a/sklearnex/utils/parallel.py +++ b/sklearnex/utils/parallel.py @@ -16,7 +16,8 @@ import warnings from functools import update_wrapper -from .._config import config_context +from .._config import config_context, get_config + class _FuncWrapper: """Load the global configuration before calling the function.""" @@ -42,3 +43,15 @@ def __call__(self, *args, **kwargs): config = {} with config_context(**config): return self.function(*args, **kwargs) + +class _FuncWrapperOld: + """ "Load the global configuration before calling the function.""" + + def __init__(self, function): + self.function = function + self.config = get_config() + update_wrapper(self, self.function) + + def __call__(self, *args, **kwargs): + with config_context(**self.config): + return self.function(*args, **kwargs)