Skip to content

Commit

Permalink
Add 'sycl' devices to the context (#9691)
Browse files Browse the repository at this point in the history
Co-authored-by: Dmitry Razdoburdin <>
  • Loading branch information
razdoburdin authored Oct 26, 2023
1 parent d4d7097 commit f41a08f
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 9 deletions.
79 changes: 77 additions & 2 deletions include/xgboost/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ struct CUDAContext;
/**
 * @brief String symbols used to name each supported device kind.
 *
 * Every accessor returns a pointer to a string literal and is usable in constant expressions.
 */
struct DeviceSym {
  /** @brief Symbol of the CPU device. */
  static constexpr char const* CPU() { return "cpu"; }
  /** @brief Symbol of a CUDA device. */
  static constexpr char const* CUDA() { return "cuda"; }
  /** @brief Symbol of the default SYCL device. */
  static constexpr char const* SyclDefault() { return "sycl"; }
  /** @brief Symbol of a SYCL CPU device. */
  static constexpr char const* SyclCPU() { return "sycl:cpu"; }
  /** @brief Symbol of a SYCL GPU device. */
  static constexpr char const* SyclGPU() { return "sycl:gpu"; }
};

/**
Expand All @@ -33,12 +36,19 @@ struct DeviceOrd {
// Ordinal carried by the CPU device (-1); it is also the default value of `ordinal`.
static bst_d_ordinal_t constexpr CPUOrdinal() { return -1; }
// Sentinel (-2) marking an ordinal as invalid; it compares below CPUOrdinal().
static bst_d_ordinal_t constexpr InvalidOrdinal() { return -2; }

enum Type : std::int16_t { kCPU = 0, kCUDA = 1 } device{kCPU};
// CUDA device ordinal.
enum Type : std::int16_t { kCPU = 0, kCUDA = 1,
kSyclDefault = 2, kSyclCPU = 3, kSyclGPU = 4} device{kCPU};
// CUDA or Sycl device ordinal.
bst_d_ordinal_t ordinal{CPUOrdinal()};

// Device type predicates: each one compares `device` against a single enumerator.
// True when the device type is kCUDA.
[[nodiscard]] bool IsCUDA() const { return device == kCUDA; }
// True when the device type is kCPU.
[[nodiscard]] bool IsCPU() const { return device == kCPU; }
// True when the device type is kSyclDefault.
[[nodiscard]] bool IsSyclDefault() const { return device == kSyclDefault; }
// True when the device type is kSyclCPU.
[[nodiscard]] bool IsSyclCPU() const { return device == kSyclCPU; }
// True when the device type is kSyclGPU.
[[nodiscard]] bool IsSyclGPU() const { return device == kSyclGPU; }
// True when the device is any of the SYCL kinds (default, CPU or GPU).
[[nodiscard]] bool IsSycl() const {
  if (IsSyclDefault()) {
    return true;
  }
  if (IsSyclCPU()) {
    return true;
  }
  return IsSyclGPU();
}

// Default construction relies on the member initializers (kCPU device, CPUOrdinal()).
constexpr DeviceOrd() = default;
// Build a device from an explicit type and ordinal; the pair is stored as given, unchecked.
constexpr DeviceOrd(Type type, bst_d_ordinal_t ord) : device{type}, ordinal{ord} {}
Expand All @@ -60,6 +70,31 @@ struct DeviceOrd {
[[nodiscard]] static constexpr auto CUDA(bst_d_ordinal_t ordinal) {
return DeviceOrd{kCUDA, ordinal};
}
/**
* @brief Constructor for the default SYCL device (type kSyclDefault).
*
* @param ordinal SYCL device ordinal. Defaults to -1 when no ordinal is given.
*
* @return A DeviceOrd holding kSyclDefault and the given ordinal.
*/
[[nodiscard]] constexpr static auto SyclDefault(bst_d_ordinal_t ordinal = -1) {
return DeviceOrd{kSyclDefault, ordinal};
}
/**
* @brief Constructor for a SYCL CPU device (type kSyclCPU).
*
* @param ordinal SYCL CPU device ordinal. Defaults to -1 when no ordinal is given.
*
* @return A DeviceOrd holding kSyclCPU and the given ordinal.
*/
[[nodiscard]] constexpr static auto SyclCPU(bst_d_ordinal_t ordinal = -1) {
return DeviceOrd{kSyclCPU, ordinal};
}

/**
* @brief Constructor for a SYCL GPU device (type kSyclGPU).
*
* @param ordinal SYCL GPU device ordinal. Defaults to -1 when no ordinal is given.
*
* @return A DeviceOrd holding kSyclGPU and the given ordinal.
*/
[[nodiscard]] constexpr static auto SyclGPU(bst_d_ordinal_t ordinal = -1) {
return DeviceOrd{kSyclGPU, ordinal};
}

[[nodiscard]] bool operator==(DeviceOrd const& that) const {
return device == that.device && ordinal == that.ordinal;
Expand All @@ -74,6 +109,12 @@ struct DeviceOrd {
return DeviceSym::CPU();
case DeviceOrd::kCUDA:
return DeviceSym::CUDA() + (':' + std::to_string(ordinal));
case DeviceOrd::kSyclDefault:
return DeviceSym::SyclDefault() + (':' + std::to_string(ordinal));
case DeviceOrd::kSyclCPU:
return DeviceSym::SyclCPU() + (':' + std::to_string(ordinal));
case DeviceOrd::kSyclGPU:
return DeviceSym::SyclGPU() + (':' + std::to_string(ordinal));
default: {
LOG(FATAL) << "Unknown device.";
return "";
Expand Down Expand Up @@ -142,6 +183,25 @@ struct Context : public XGBoostParameter<Context> {
* @brief Is XGBoost running on a CUDA device?
*/
[[nodiscard]] bool IsCUDA() const { return Device().IsCUDA(); }
/**
* @brief Is XGBoost running on the default SYCL device?
*
* Forwards to DeviceOrd::IsSyclDefault() on the current device.
*/
[[nodiscard]] bool IsSyclDefault() const { return Device().IsSyclDefault(); }
/**
* @brief Is XGBoost running on a SYCL CPU?
*
* Forwards to DeviceOrd::IsSyclCPU() on the current device.
*/
[[nodiscard]] bool IsSyclCPU() const { return Device().IsSyclCPU(); }
/**
* @brief Is XGBoost running on a SYCL GPU?
*
* Forwards to DeviceOrd::IsSyclGPU() on the current device.
*/
[[nodiscard]] bool IsSyclGPU() const { return Device().IsSyclGPU(); }
/**
* @brief Is XGBoost running on any SYCL device?
*/
[[nodiscard]] bool IsSycl() const { return IsSyclDefault()
|| IsSyclCPU()
|| IsSyclGPU(); }

/**
* @brief Get the current device and ordinal.
*/
Expand Down Expand Up @@ -175,6 +235,7 @@ struct Context : public XGBoostParameter<Context> {
Context ctx = *this;
return ctx.SetDevice(DeviceOrd::CPU());
}

/**
* @brief Call function based on the current device.
*/
Expand All @@ -196,6 +257,20 @@ struct Context : public XGBoostParameter<Context> {
return std::invoke_result_t<CPUFn>();
}

/**
* @brief Call function based on the current device, with a dedicated branch for SYCL.
*
* The three callables take no arguments and must share one return type; this is checked at
* compile time.
*
* @param cpu_fn  Passed on to the two-argument DispatchDevice when the device is not SYCL.
* @param cuda_fn Passed on to the two-argument DispatchDevice when the device is not SYCL.
* @param sycl_fn Invoked when the current device is any SYCL device.
*
* @return Whatever the selected callable returns.
*/
template <typename CPUFn, typename CUDAFn, typename SYCLFn>
decltype(auto) DispatchDevice(CPUFn&& cpu_fn, CUDAFn&& cuda_fn, SYCLFn&& sycl_fn) const {
  using CPUResult = std::invoke_result_t<CPUFn>;
  static_assert(std::is_same_v<CPUResult, std::invoke_result_t<CUDAFn>>);
  static_assert(std::is_same_v<CPUResult, std::invoke_result_t<SYCLFn>>);
  // Anything that is not SYCL is handled by the CPU/CUDA overload.
  if (!this->Device().IsSycl()) {
    return DispatchDevice(cpu_fn, cuda_fn);
  }
  return sycl_fn();
}

// declare parameters
DMLC_DECLARE_PARAMETER(Context) {
DMLC_DECLARE_FIELD(seed)
Expand Down
44 changes: 37 additions & 7 deletions src/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,19 +104,32 @@ DeviceOrd CUDAOrdinal(DeviceOrd device, bool) {
// mingw hangs on regex using rtools 430. Basic checks only.
CHECK_GE(input.size(), 3) << msg;
auto substr = input.substr(0, 3);
bool valid = substr == "cpu" || substr == "cud" || substr == "gpu";
bool valid = substr == "cpu" || substr == "cud" || substr == "gpu" || substr == "syc";
CHECK(valid) << msg;
#else
std::regex pattern{"gpu(:[0-9]+)?|cuda(:[0-9]+)?|cpu"};
std::regex pattern{"gpu(:[0-9]+)?|cuda(:[0-9]+)?|cpu|sycl(:cpu|:gpu)?(:-1|:[0-9]+)?"};
if (!std::regex_match(input, pattern)) {
fatal();
}
#endif // defined(__MINGW32__)

// handle alias
std::string s_device = std::regex_replace(input, std::regex{"gpu"}, DeviceSym::CUDA());
std::string s_device = input;
if (!std::regex_match(s_device, std::regex("sycl(:cpu|:gpu)?(:-1|:[0-9]+)?"))) {
s_device = std::regex_replace(s_device, std::regex{"gpu"}, DeviceSym::CUDA());
}

auto split_it = std::find(s_device.cbegin(), s_device.cend(), ':');
if (std::regex_match(s_device, std::regex("sycl:(cpu|gpu)?"))) split_it = s_device.cend();

// For s_device like "sycl:gpu:1"
if (split_it != s_device.cend()) {
auto second_split_it = std::find(split_it + 1, s_device.cend(), ':');
if (second_split_it != s_device.cend()) {
split_it = second_split_it;
}
}

DeviceOrd device;
device.ordinal = DeviceOrd::InvalidOrdinal(); // mark it invalid for check.
if (split_it == s_device.cend()) {
Expand All @@ -125,15 +138,22 @@ DeviceOrd CUDAOrdinal(DeviceOrd device, bool) {
device = DeviceOrd::CPU();
} else if (s_device == DeviceSym::CUDA()) {
device = DeviceOrd::CUDA(0); // use 0 as default;
} else if (s_device == DeviceSym::SyclDefault()) {
device = DeviceOrd::SyclDefault();
} else if (s_device == DeviceSym::SyclCPU()) {
device = DeviceOrd::SyclCPU();
} else if (s_device == DeviceSym::SyclGPU()) {
device = DeviceOrd::SyclGPU();
} else {
fatal();
}
} else {
// must be CUDA when ordinal is specifed.
// must be CUDA or SYCL when ordinal is specifed.
// +1 for colon
std::size_t offset = std::distance(s_device.cbegin(), split_it) + 1;
// substr
StringView s_ordinal = {s_device.data() + offset, s_device.size() - offset};
StringView s_type = {s_device.data(), offset - 1};
if (s_ordinal.empty()) {
fatal();
}
Expand All @@ -143,13 +163,23 @@ DeviceOrd CUDAOrdinal(DeviceOrd device, bool) {
}
CHECK_LE(opt_id.value(), std::numeric_limits<bst_d_ordinal_t>::max())
<< "Ordinal value too large.";
device = DeviceOrd::CUDA(opt_id.value());
if (s_type == DeviceSym::SyclDefault()) {
device = DeviceOrd::SyclDefault(opt_id.value());
} else if (s_type == DeviceSym::SyclCPU()) {
device = DeviceOrd::SyclCPU(opt_id.value());
} else if (s_type == DeviceSym::SyclGPU()) {
device = DeviceOrd::SyclGPU(opt_id.value());
} else {
device = DeviceOrd::CUDA(opt_id.value());
}
}

if (device.ordinal < DeviceOrd::CPUOrdinal()) {
fatal();
}
device = CUDAOrdinal(device, fail_on_invalid_gpu_id);
if (device.IsCUDA()) {
device = CUDAOrdinal(device, fail_on_invalid_gpu_id);
}

return device;
}
Expand Down Expand Up @@ -216,7 +246,7 @@ void Context::SetDeviceOrdinal(Args const& kwargs) {

if (this->IsCPU()) {
CHECK_EQ(this->device_.ordinal, DeviceOrd::CPUOrdinal());
} else {
} else if (this->IsCUDA()) {
CHECK_GT(this->device_.ordinal, DeviceOrd::CPUOrdinal());
}
}
Expand Down
93 changes: 93 additions & 0 deletions tests/cpp/test_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,97 @@ TEST(Context, ErrorInit) {
ASSERT_NE(msg.find("foo"), std::string::npos);
}
}

// Each accepted "sycl" device string must parse to the matching DeviceOrd, be routed to the
// SYCL branch of DispatchDevice, and print back with an explicit ordinal.
TEST(Context, SYCL) {
  struct Case {
    std::string device;    // value passed as the "device" parameter
    DeviceOrd expected;    // device the context must report afterwards
    std::int32_t ordinal;  // expected result of ctx.Ordinal()
    std::string printed;   // expected stream output of ctx.Device()
  };
  Case const cases[] = {
      // Default SYCL device, without and with an index
      {"sycl", DeviceOrd::SyclDefault(), -1, "sycl:-1"},
      {"sycl:42", DeviceOrd::SyclDefault(42), 42, "sycl:42"},
      // SYCL cpu, without and with an index
      {"sycl:cpu", DeviceOrd::SyclCPU(), -1, "sycl:cpu:-1"},
      {"sycl:cpu:42", DeviceOrd::SyclCPU(42), 42, "sycl:cpu:42"},
      // SYCL gpu, without and with an index
      {"sycl:gpu", DeviceOrd::SyclGPU(), -1, "sycl:gpu:-1"},
      {"sycl:gpu:42", DeviceOrd::SyclGPU(42), 42, "sycl:gpu:42"},
  };

  // A single context is reconfigured for every case, in the order listed above.
  Context ctx;
  for (auto const& c : cases) {
    ctx.UpdateAllowUnknown(Args{{"device", c.device}});
    ASSERT_EQ(ctx.Device(), c.expected);
    ASSERT_EQ(ctx.Ordinal(), c.ordinal);

    // Only the third (SYCL) callable may run.
    std::int32_t flag{0};
    ctx.DispatchDevice([&] { flag = -1; }, [&] { flag = 1; }, [&] { flag = 2; });
    ASSERT_EQ(flag, 2);

    std::stringstream ss;
    ss << ctx.Device();
    ASSERT_EQ(ss.str(), c.printed);
  }
}
} // namespace xgboost

0 comments on commit f41a08f

Please sign in to comment.