Skip to content

Commit

Permalink
Add support for older sklearn versions
Browse files Browse the repository at this point in the history
  • Loading branch information
olegkkruglov committed Jul 14, 2023
1 parent a7ccae7 commit 05c32e8
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
10 changes: 8 additions & 2 deletions sklearnex/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,19 @@ def get_patch_map():
import sklearn.linear_model as linear_model_module
import sklearn.neighbors as neighbors_module
import sklearn.svm as svm_module
import sklearn.utils.parallel as parallel_module
if sklearn_check_version('1.2.1'):
import sklearn.utils.parallel as parallel_module
else:
import sklearn.utils.fixes as parallel_module

# Classes and functions for patching
from ._config import config_context as config_context_sklearnex
from ._config import get_config as get_config_sklearnex
from ._config import set_config as set_config_sklearnex
from .utils.parallel import _FuncWrapper as _FuncWrapper_sklearnex
if sklearn_check_version('1.2.1'):
from .utils.parallel import _FuncWrapper as _FuncWrapper_sklearnex
else:
from .utils.parallel import _FuncWrapperOld as _FuncWrapper_sklearnex
from .neighbors import KNeighborsClassifier as KNeighborsClassifier_sklearnex
from .neighbors import KNeighborsRegressor as KNeighborsRegressor_sklearnex
from .neighbors import LocalOutlierFactor as LocalOutlierFactor_sklearnex
Expand Down
15 changes: 14 additions & 1 deletion sklearnex/utils/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

import warnings
from functools import update_wrapper
from .._config import config_context
from .._config import config_context, get_config


class _FuncWrapper:
"""Load the global configuration before calling the function."""
Expand All @@ -42,3 +43,15 @@ def __call__(self, *args, **kwargs):
config = {}
with config_context(**config):
return self.function(*args, **kwargs)

class _FuncWrapperOld:
""" "Load the global configuration before calling the function."""

def __init__(self, function):
self.function = function
self.config = get_config()
update_wrapper(self, self.function)

def __call__(self, *args, **kwargs):
with config_context(**self.config):
return self.function(*args, **kwargs)

0 comments on commit 05c32e8

Please sign in to comment.