-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[pyspark] support gpu transform #9542
Conversation
try: | ||
import cupy | ||
|
||
return True | ||
except ImportError: | ||
return False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test for cuDF availability originally had the comment:
# Checking by `importing` instead of check `importlib.util.find_spec("cudf") is not None`
WeichenXu123 marked this conversation as resolved.
# because user might install cudf successfully but importing cudf raises issues (e.g. saying
# running on mismatched cuda version)
which to me was weird, that's a mismanaged environment and I'm not sure it's necessary for xgboost to work around it (or even a good idea to workaround anything since users with cuDF installed might expect GPU to be used).
python-package/xgboost/spark/core.py
Outdated
def set_device(self, value: str) -> "_SparkXGBParams": | ||
"""Set device (cpu, cuda, gpu)""" | ||
assert value in ("cpu", "cuda", "gpu") | ||
self.set(self.device, value) | ||
return self | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have a check here:
xgboost/python-package/xgboost/core.py
Line 284 in 3b9e590
def _check_distributed_params(kwargs: Dict[str, Any]) -> None: |
Line 95 in 3b9e590
StringView msg{R"(Invalid argument for `device`. Expected to be one of the following: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems _check_distributed_params just checks the type of the value, it seems we also need to restrict it to be one of (CPU ,GPU, cuda)
accf7e1
to
f74fccc
Compare
--------- Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
PR to support gpu transform. @WeichenXu123 @trivialfis Please help to review this PR. Thx