-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add 'sycl' devices to the context #9691
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,6 +22,9 @@ struct CUDAContext; | |
/**
 * @brief String names of the device kinds.
 *
 * Every accessor returns a pointer to a string literal.
 */
struct DeviceSym {
  /** @brief Name of the CPU device. */
  static constexpr auto CPU() -> char const* { return "cpu"; }
  /** @brief Name of the CUDA device. */
  static constexpr auto CUDA() -> char const* { return "cuda"; }
  /** @brief Name of the default SYCL device. */
  static constexpr auto SYCL_default() -> char const* { return "sycl"; }
  /** @brief Name of the SYCL CPU device. */
  static constexpr auto SYCL_CPU() -> char const* { return "sycl:cpu"; }
  /** @brief Name of the SYCL GPU device. */
  static constexpr auto SYCL_GPU() -> char const* { return "sycl:gpu"; }
};
|
||
/** | ||
|
@@ -33,12 +36,19 @@ struct DeviceOrd { | |
static bst_d_ordinal_t constexpr CPUOrdinal() { return -1; } | ||
static bst_d_ordinal_t constexpr InvalidOrdinal() { return -2; } | ||
|
||
enum Type : std::int16_t { kCPU = 0, kCUDA = 1 } device{kCPU}; | ||
// CUDA device ordinal. | ||
enum Type : std::int16_t { kCPU = 0, kCUDA = 1, | ||
kSyclDefault = 2, kSyclCPU = 3, kSyclGPU = 4} device{kCPU}; | ||
// CUDA or Sycl device ordinal. | ||
bst_d_ordinal_t ordinal{CPUOrdinal()}; | ||
|
||
[[nodiscard]] bool IsCUDA() const { return device == kCUDA; } | ||
[[nodiscard]] bool IsCPU() const { return device == kCPU; } | ||
[[nodiscard]] bool IsSyclDefault() const { return device == kSyclDefault; } | ||
[[nodiscard]] bool IsSyclCPU() const { return device == kSyclCPU; } | ||
[[nodiscard]] bool IsSyclGPU() const { return device == kSyclGPU; } | ||
[[nodiscard]] bool IsSycl() const { return (IsSyclDefault() || | ||
IsSyclCPU() || | ||
IsSyclGPU()); } | ||
|
||
constexpr DeviceOrd() = default; | ||
constexpr DeviceOrd(Type type, bst_d_ordinal_t ord) : device{type}, ordinal{ord} {} | ||
|
@@ -60,6 +70,31 @@ struct DeviceOrd { | |
[[nodiscard]] static constexpr auto CUDA(bst_d_ordinal_t ordinal) { | ||
return DeviceOrd{kCUDA, ordinal}; | ||
} | ||
/** | ||
* @brief Constructor for SYCL. | ||
* | ||
* @param ordinal SYCL device ordinal. | ||
*/ | ||
[[nodiscard]] constexpr static auto SYCL_default(bst_d_ordinal_t ordinal = -1) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please follow the function naming convention "SYCLDefault"/"SyclDefault". There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. |
||
return DeviceOrd{kSyclDefault, ordinal}; | ||
} | ||
/** | ||
* @brief Constructor for SYCL CPU. | ||
* | ||
* @param ordinal SYCL CPU device ordinal. | ||
*/ | ||
[[nodiscard]] constexpr static auto SYCL_CPU(bst_d_ordinal_t ordinal = -1) { | ||
return DeviceOrd{kSyclCPU, ordinal}; | ||
} | ||
|
||
/** | ||
* @brief Constructor for SYCL GPU. | ||
* | ||
* @param ordinal SYCL GPU device ordinal. | ||
*/ | ||
[[nodiscard]] constexpr static auto SYCL_GPU(bst_d_ordinal_t ordinal = -1) { | ||
return DeviceOrd{kSyclGPU, ordinal}; | ||
} | ||
|
||
[[nodiscard]] bool operator==(DeviceOrd const& that) const { | ||
return device == that.device && ordinal == that.ordinal; | ||
|
@@ -74,6 +109,12 @@ struct DeviceOrd { | |
return DeviceSym::CPU(); | ||
case DeviceOrd::kCUDA: | ||
return DeviceSym::CUDA() + (':' + std::to_string(ordinal)); | ||
case DeviceOrd::kSyclDefault: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What is "sycl:{ordinal}"? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It means using the SYCL device with index `ordinal`. |
||
return DeviceSym::SYCL_default() + (':' + std::to_string(ordinal)); | ||
case DeviceOrd::kSyclCPU: | ||
return DeviceSym::SYCL_CPU() + (':' + std::to_string(ordinal)); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What is "sycl:cpu:ordinal"? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It means using the CPU with the specific index on multi-CPU systems. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Multi-socket or multi-core? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Multi-socket. |
||
case DeviceOrd::kSyclGPU: | ||
return DeviceSym::SYCL_GPU() + (':' + std::to_string(ordinal)); | ||
default: { | ||
LOG(FATAL) << "Unknown device."; | ||
return ""; | ||
|
@@ -142,6 +183,25 @@ struct Context : public XGBoostParameter<Context> { | |
* @brief Is XGBoost running on a CUDA device? | ||
*/ | ||
[[nodiscard]] bool IsCUDA() const { return Device().IsCUDA(); } | ||
/** | ||
* @brief Is XGBoost running on the default SYCL device? | ||
*/ | ||
[[nodiscard]] bool IsSyclDefault() const { return Device().IsSyclDefault(); } | ||
/** | ||
* @brief Is XGBoost running on a SYCL CPU? | ||
*/ | ||
[[nodiscard]] bool IsSyclCPU() const { return Device().IsSyclCPU(); } | ||
/** | ||
* @brief Is XGBoost running on a SYCL GPU? | ||
*/ | ||
[[nodiscard]] bool IsSyclGPU() const { return Device().IsSyclGPU(); } | ||
/** | ||
* @brief Is XGBoost running on any SYCL device? | ||
*/ | ||
[[nodiscard]] bool IsSycl() const { return IsSyclDefault() | ||
|| IsSyclCPU() | ||
|| IsSyclGPU(); } | ||
|
||
/** | ||
* @brief Get the current device and ordinal. | ||
*/ | ||
|
@@ -175,6 +235,7 @@ struct Context : public XGBoostParameter<Context> { | |
Context ctx = *this; | ||
return ctx.SetDevice(DeviceOrd::CPU()); | ||
} | ||
|
||
/** | ||
* @brief Call function based on the current device. | ||
*/ | ||
|
@@ -196,6 +257,20 @@ struct Context : public XGBoostParameter<Context> { | |
return std::invoke_result_t<CPUFn>(); | ||
} | ||
|
||
/** | ||
* @brief Call function for sycl devices | ||
*/ | ||
template <typename CPUFn, typename CUDAFn, typename SYCLFn> | ||
decltype(auto) DispatchDevice(CPUFn&& cpu_fn, CUDAFn&& cuda_fn, SYCLFn&& sycl_fn) const { | ||
static_assert(std::is_same_v<std::invoke_result_t<CPUFn>, std::invoke_result_t<CUDAFn>>); | ||
static_assert(std::is_same_v<std::invoke_result_t<CPUFn>, std::invoke_result_t<SYCLFn>>); | ||
if (this->Device().IsSycl()) { | ||
return sycl_fn(); | ||
} else { | ||
return DispatchDevice(cpu_fn, cuda_fn); | ||
} | ||
} | ||
|
||
// declare parameters | ||
DMLC_DECLARE_PARAMETER(Context) { | ||
DMLC_DECLARE_FIELD(seed) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -94,29 +94,45 @@ DeviceOrd CUDAOrdinal(DeviceOrd device, bool) { | |
StringView msg{R"(Invalid argument for `device`. Expected to be one of the following: | ||
- cpu | ||
- cuda | ||
- cuda:<device ordinal> # e.g. cuda:0 | ||
- cuda:<device ordinal> # e.g. cuda:0 | ||
- gpu | ||
- gpu:<device ordinal> # e.g. gpu:0 | ||
- gpu:<device ordinal> # e.g. gpu:0 | ||
- sycl | ||
- sycl:<device ordinal> # e.g. sycl:0 | ||
- sycl:<cpu, gpu> | ||
- sycl:<cpu, gpu>:<device ordinal> # e.g. sycl:gpu:0 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's remove the error message change for now and bring it back when there's a publicly ready feature that can be enabled. At this point, the message is only going to confuse users. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. |
||
)"}; | ||
auto fatal = [&] { LOG(FATAL) << msg << "Got: `" << input << "`."; }; | ||
|
||
#if defined(__MINGW32__) | ||
// mingw hangs on regex using rtools 430. Basic checks only. | ||
CHECK_GE(input.size(), 3) << msg; | ||
auto substr = input.substr(0, 3); | ||
bool valid = substr == "cpu" || substr == "cud" || substr == "gpu"; | ||
bool valid = substr == "cpu" || substr == "cud" || substr == "gpu" || substr == "syc"; | ||
CHECK(valid) << msg; | ||
#else | ||
std::regex pattern{"gpu(:[0-9]+)?|cuda(:[0-9]+)?|cpu"}; | ||
std::regex pattern{"gpu(:[0-9]+)?|cuda(:[0-9]+)?|cpu|sycl(:cpu|:gpu)?(:-1|:[0-9]+)?"}; | ||
if (!std::regex_match(input, pattern)) { | ||
fatal(); | ||
} | ||
#endif // defined(__MINGW32__) | ||
|
||
// handle alias | ||
std::string s_device = std::regex_replace(input, std::regex{"gpu"}, DeviceSym::CUDA()); | ||
std::string s_device = input; | ||
if (!std::regex_match(s_device, std::regex("sycl(:cpu|:gpu)?(:-1|:[0-9]+)?"))) | ||
s_device = std::regex_replace(s_device, std::regex{"gpu"}, DeviceSym::CUDA()); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Shouldn't this mean […]? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I guess in the future it should work as you have described. But let's postpone this fix for the future. |
||
|
||
auto split_it = std::find(s_device.cbegin(), s_device.cend(), ':'); | ||
if (std::regex_match(s_device, std::regex("sycl:(cpu|gpu)?"))) split_it = s_device.cend(); | ||
|
||
// For s_device like "sycl:gpu:1" | ||
if (split_it != s_device.cend()) { | ||
auto second_split_it = std::find(split_it + 1, s_device.cend(), ':'); | ||
if (second_split_it != s_device.cend()) { | ||
split_it = second_split_it; | ||
} | ||
} | ||
|
||
DeviceOrd device; | ||
device.ordinal = DeviceOrd::InvalidOrdinal(); // mark it invalid for check. | ||
if (split_it == s_device.cend()) { | ||
|
@@ -125,15 +141,22 @@ DeviceOrd CUDAOrdinal(DeviceOrd device, bool) { | |
device = DeviceOrd::CPU(); | ||
} else if (s_device == DeviceSym::CUDA()) { | ||
device = DeviceOrd::CUDA(0); // use 0 as default; | ||
} else if (s_device == DeviceSym::SYCL_default()) { | ||
device = DeviceOrd::SYCL_default(); | ||
} else if (s_device == DeviceSym::SYCL_CPU()) { | ||
device = DeviceOrd::SYCL_CPU(); | ||
} else if (s_device == DeviceSym::SYCL_GPU()) { | ||
device = DeviceOrd::SYCL_GPU(); | ||
} else { | ||
fatal(); | ||
} | ||
} else { | ||
// must be CUDA when ordinal is specifed. | ||
// must be CUDA or SYCL when ordinal is specified. | ||
// +1 for colon | ||
std::size_t offset = std::distance(s_device.cbegin(), split_it) + 1; | ||
// substr | ||
StringView s_ordinal = {s_device.data() + offset, s_device.size() - offset}; | ||
StringView s_type = {s_device.data(), offset - 1}; | ||
if (s_ordinal.empty()) { | ||
fatal(); | ||
} | ||
|
@@ -143,13 +166,23 @@ DeviceOrd CUDAOrdinal(DeviceOrd device, bool) { | |
} | ||
CHECK_LE(opt_id.value(), std::numeric_limits<bst_d_ordinal_t>::max()) | ||
<< "Ordinal value too large."; | ||
device = DeviceOrd::CUDA(opt_id.value()); | ||
if (s_type == DeviceSym::SYCL_default()) { | ||
device = DeviceOrd::SYCL_default(opt_id.value()); | ||
} else if (s_type == DeviceSym::SYCL_CPU()) { | ||
device = DeviceOrd::SYCL_CPU(opt_id.value()); | ||
} else if (s_type == DeviceSym::SYCL_GPU()) { | ||
device = DeviceOrd::SYCL_GPU(opt_id.value()); | ||
} else { | ||
device = DeviceOrd::CUDA(opt_id.value()); | ||
} | ||
} | ||
|
||
if (device.ordinal < DeviceOrd::CPUOrdinal()) { | ||
fatal(); | ||
} | ||
device = CUDAOrdinal(device, fail_on_invalid_gpu_id); | ||
if (device.IsCUDA()) { | ||
device = CUDAOrdinal(device, fail_on_invalid_gpu_id); | ||
} | ||
|
||
return device; | ||
} | ||
|
@@ -216,7 +249,7 @@ void Context::SetDeviceOrdinal(Args const& kwargs) { | |
|
||
if (this->IsCPU()) { | ||
CHECK_EQ(this->device_.ordinal, DeviceOrd::CPUOrdinal()); | ||
} else { | ||
} else if (this->IsCUDA()) { | ||
CHECK_GT(this->device_.ordinal, DeviceOrd::CPUOrdinal()); | ||
} | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So, what's the default by specification?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SYCL tries to run on the GPU. If a GPU isn't available for some reason, it launches on the CPU.