-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add 'sycl' devices to the context #9691
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -94,29 +94,45 @@ DeviceOrd CUDAOrdinal(DeviceOrd device, bool) { | |
StringView msg{R"(Invalid argument for `device`. Expected to be one of the following: | ||
- cpu | ||
- cuda | ||
- cuda:<device ordinal> # e.g. cuda:0 | ||
- cuda:<device ordinal> # e.g. cuda:0 | ||
- gpu | ||
- gpu:<device ordinal> # e.g. gpu:0 | ||
- gpu:<device ordinal> # e.g. gpu:0 | ||
- sycl | ||
- sycl:<device ordinal> # e.g. sycl:0 | ||
- sycl:<cpu, gpu> | ||
- sycl:<cpu, gpu>:<device ordinal> # e.g. sycl:gpu:0 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's remove the error message change for now and bring it back when there's a publicly available feature that can be enabled. At this point, the message is only going to confuse users. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. |
||
)"}; | ||
auto fatal = [&] { LOG(FATAL) << msg << "Got: `" << input << "`."; }; | ||
|
||
#if defined(__MINGW32__) | ||
// mingw hangs on regex using rtools 430. Basic checks only. | ||
CHECK_GE(input.size(), 3) << msg; | ||
auto substr = input.substr(0, 3); | ||
bool valid = substr == "cpu" || substr == "cud" || substr == "gpu"; | ||
bool valid = substr == "cpu" || substr == "cud" || substr == "gpu" || substr == "syc"; | ||
CHECK(valid) << msg; | ||
#else | ||
std::regex pattern{"gpu(:[0-9]+)?|cuda(:[0-9]+)?|cpu"}; | ||
std::regex pattern{"gpu(:[0-9]+)?|cuda(:[0-9]+)?|cpu|sycl(:cpu|:gpu)?(:-1|:[0-9]+)?"}; | ||
if (!std::regex_match(input, pattern)) { | ||
fatal(); | ||
} | ||
#endif // defined(__MINGW32__) | ||
|
||
// handle alias | ||
std::string s_device = std::regex_replace(input, std::regex{"gpu"}, DeviceSym::CUDA()); | ||
std::string s_device = input; | ||
if (!std::regex_match(s_device, std::regex("sycl(:cpu|:gpu)?(:-1|:[0-9]+)?"))) | ||
s_device = std::regex_replace(s_device, std::regex{"gpu"}, DeviceSym::CUDA()); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Shouldn't this mean There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I guess in the future it should work the way you have described, but let's postpone this fix for the future. |
||
|
||
auto split_it = std::find(s_device.cbegin(), s_device.cend(), ':'); | ||
if (std::regex_match(s_device, std::regex("sycl:(cpu|gpu)?"))) split_it = s_device.cend(); | ||
|
||
// For s_device like "sycl:gpu:1" | ||
if (split_it != s_device.cend()) { | ||
auto second_split_it = std::find(split_it + 1, s_device.cend(), ':'); | ||
if (second_split_it != s_device.cend()) { | ||
split_it = second_split_it; | ||
} | ||
} | ||
|
||
DeviceOrd device; | ||
device.ordinal = DeviceOrd::InvalidOrdinal(); // mark it invalid for check. | ||
if (split_it == s_device.cend()) { | ||
|
@@ -125,15 +141,22 @@ DeviceOrd CUDAOrdinal(DeviceOrd device, bool) { | |
device = DeviceOrd::CPU(); | ||
} else if (s_device == DeviceSym::CUDA()) { | ||
device = DeviceOrd::CUDA(0); // use 0 as default; | ||
} else if (s_device == DeviceSym::SyclDefault()) { | ||
device = DeviceOrd::SyclDefault(); | ||
} else if (s_device == DeviceSym::SyclCPU()) { | ||
device = DeviceOrd::SyclCPU(); | ||
} else if (s_device == DeviceSym::SyclGPU()) { | ||
device = DeviceOrd::SyclGPU(); | ||
} else { | ||
fatal(); | ||
} | ||
} else { | ||
// must be CUDA when ordinal is specifed. | ||
// must be CUDA or SYCL when ordinal is specifed. | ||
// +1 for colon | ||
std::size_t offset = std::distance(s_device.cbegin(), split_it) + 1; | ||
// substr | ||
StringView s_ordinal = {s_device.data() + offset, s_device.size() - offset}; | ||
StringView s_type = {s_device.data(), offset - 1}; | ||
if (s_ordinal.empty()) { | ||
fatal(); | ||
} | ||
|
@@ -143,13 +166,23 @@ DeviceOrd CUDAOrdinal(DeviceOrd device, bool) { | |
} | ||
CHECK_LE(opt_id.value(), std::numeric_limits<bst_d_ordinal_t>::max()) | ||
<< "Ordinal value too large."; | ||
device = DeviceOrd::CUDA(opt_id.value()); | ||
if (s_type == DeviceSym::SyclDefault()) { | ||
device = DeviceOrd::SyclDefault(opt_id.value()); | ||
} else if (s_type == DeviceSym::SyclCPU()) { | ||
device = DeviceOrd::SyclCPU(opt_id.value()); | ||
} else if (s_type == DeviceSym::SyclGPU()) { | ||
device = DeviceOrd::SyclGPU(opt_id.value()); | ||
} else { | ||
device = DeviceOrd::CUDA(opt_id.value()); | ||
} | ||
} | ||
|
||
if (device.ordinal < DeviceOrd::CPUOrdinal()) { | ||
fatal(); | ||
} | ||
device = CUDAOrdinal(device, fail_on_invalid_gpu_id); | ||
if (device.IsCUDA()) { | ||
device = CUDAOrdinal(device, fail_on_invalid_gpu_id); | ||
} | ||
|
||
return device; | ||
} | ||
|
@@ -216,7 +249,7 @@ void Context::SetDeviceOrdinal(Args const& kwargs) { | |
|
||
if (this->IsCPU()) { | ||
CHECK_EQ(this->device_.ordinal, DeviceOrd::CPUOrdinal()); | ||
} else { | ||
} else if (this->IsCUDA()) { | ||
CHECK_GT(this->device_.ordinal, DeviceOrd::CPUOrdinal()); | ||
} | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is "sycl:{ordinal}"?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It means using the SYCL device with index `{ordinal}`. It can be a CPU or a GPU, depending on the user's system settings.