-
Notifications
You must be signed in to change notification settings - Fork 174
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implemented proper work with multiple threads #1361
Changes from 11 commits
284f7a7
53b59fd
674a666
bb2cd80
67e080c
d7ed459
6420de9
1670cee
18ea5bd
17ceebc
e331715
f016726
d542401
5648d30
20b0e84
9f74cee
a69e57f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,50 @@ | ||||||
# ============================================================================== | ||||||
# Copyright 2023 Intel Corporation | ||||||
# | ||||||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
# you may not use this file except in compliance with the License. | ||||||
# You may obtain a copy of the License at | ||||||
# | ||||||
# http://www.apache.org/licenses/LICENSE-2.0 | ||||||
# | ||||||
# Unless required by applicable law or agreed to in writing, software | ||||||
# distributed under the License is distributed on an "AS IS" BASIS, | ||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
# See the License for the specific language governing permissions and | ||||||
# limitations under the License. | ||||||
# ============================================================================== | ||||||
import pytest | ||||||
|
||||||
from sklearnex import config_context, patch_sklearn | ||||||
|
||||||
patch_sklearn() | ||||||
|
||||||
from sklearn.datasets import make_classification | ||||||
from sklearn.ensemble import BaggingClassifier | ||||||
from sklearn.svm import SVC | ||||||
|
||||||
try: | ||||||
import dpctl | ||||||
|
||||||
dpctl_is_available = True | ||||||
gpu_is_available = dpctl.has_gpu_devices() | ||||||
except (ImportError, ModuleNotFoundError): | ||||||
dpctl_is_available = False | ||||||
|
||||||
|
||||||
@pytest.mark.skipif( | ||||||
not dpctl_is_available or gpu_is_available, | ||||||
ethanglaser marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
reason="GPU device should not be available for this test " | ||||||
"to see raised 'SyclQueueCreationError'. " | ||||||
"'dpctl' module is required for test.", | ||||||
) | ||||||
def test_config_context_in_parallel(): | ||||||
x, y = make_classification(random_state=42) | ||||||
try: | ||||||
with config_context(target_offload="gpu"): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually maybe this modification isn't necessary, but I am a bit confused by the test - when would dpctl be available but no GPU? Thanks for adding the test. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. dpctl is not only for gpu devices. For example, CI has dpctl installed without gpu in azure pipelines used instances. |
||||||
BaggingClassifier(SVC(), n_jobs=2).fit(x, y) | ||||||
raise ValueError( | ||||||
"'SyclQueueCreationError' wasn't raised " "for non-existing 'gpu' device" | ||||||
) | ||||||
except dpctl._sycl_queue.SyclQueueCreationError: | ||||||
pass |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# =============================================================================== | ||
# Copyright 2023 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# =============================================================================== | ||
|
||
import warnings | ||
from functools import update_wrapper | ||
|
||
from .._config import config_context, get_config | ||
|
||
|
||
class _FuncWrapper: | ||
"""Load the global configuration before calling the function.""" | ||
|
||
def __init__(self, function): | ||
self.function = function | ||
update_wrapper(self, self.function) | ||
|
||
def with_config(self, config): | ||
self.config = config | ||
return self | ||
|
||
def __call__(self, *args, **kwargs): | ||
config = getattr(self, "config", None) | ||
if config is None: | ||
warnings.warn( | ||
"`sklearn.utils.parallel.delayed` should be used with " | ||
"`sklearn.utils.parallel.Parallel` to make it possible to propagate " | ||
"the scikit-learn configuration of the current thread to the " | ||
"joblib workers.", | ||
UserWarning, | ||
) | ||
config = {} | ||
with config_context(**config): | ||
return self.function(*args, **kwargs) | ||
|
||
olegkkruglov marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
class _FuncWrapperOld: | ||
"""Load the global configuration before calling the function.""" | ||
|
||
def __init__(self, function): | ||
self.function = function | ||
self.config = get_config() | ||
update_wrapper(self, self.function) | ||
|
||
def __call__(self, *args, **kwargs): | ||
with config_context(**self.config): | ||
return self.function(*args, **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍