Skip to content

Commit

Permalink
Fix for multinode (#65)
Browse files (browse the repository at this point in the history)
* fix for multinode

* one more fix

---------

Co-authored-by: Dmitry Razdoburdin <>
  • Loading branch information
razdoburdin authored Oct 11, 2024
1 parent c25ed19 commit 217688a
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions plugin/sycl/device_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,25 @@ ::sycl::device DeviceManager::GetDevice(const DeviceOrd& device_spec) const {
(collective::IsDistributed());
if (not_use_default_selector) {
DeviceRegister& device_register = GetDevicesRegister();
const int device_idx =
collective::IsDistributed() ? collective::GetRank() : device_spec.ordinal;
if (device_spec.IsSyclDefault()) {
auto& devices = device_register.devices;
const int device_idx = collective::IsDistributed()
? collective::GetRank() % devices.size()
: device_spec.ordinal;
CHECK_LT(device_idx, devices.size());
return devices[device_idx];
} else if (device_spec.IsSyclCPU()) {
auto& cpu_devices = device_register.cpu_devices;
const int device_idx = collective::IsDistributed()
? collective::GetRank() % cpu_devices.size()
: device_spec.ordinal;
CHECK_LT(device_idx, cpu_devices.size());
return cpu_devices[device_idx];
} else {
auto& gpu_devices = device_register.gpu_devices;
const int device_idx = collective::IsDistributed()
? collective::GetRank() % gpu_devices.size()
: device_spec.ordinal;
CHECK_LT(device_idx, gpu_devices.size());
return gpu_devices[device_idx];
}
Expand Down Expand Up @@ -63,18 +70,25 @@ ::sycl::queue DeviceManager::GetQueue(const DeviceOrd& device_spec) const {
std::lock_guard<std::mutex> guard(queue_registering_mutex);
if (not_use_default_selector) {
DeviceRegister& device_register = GetDevicesRegister();
const int device_idx =
collective::IsDistributed() ? collective::GetRank() : device_spec.ordinal;
if (device_spec.IsSyclDefault()) {
auto& devices = device_register.devices;
const int device_idx = collective::IsDistributed()
? collective::GetRank() % devices.size()
: device_spec.ordinal;
CHECK_LT(device_idx, devices.size());
queue_register[device_spec.Name()] = ::sycl::queue(devices[device_idx]);
} else if (device_spec.IsSyclCPU()) {
auto& cpu_devices = device_register.cpu_devices;
const int device_idx = collective::IsDistributed()
? collective::GetRank() % cpu_devices.size()
: device_spec.ordinal;
CHECK_LT(device_idx, cpu_devices.size());
queue_register[device_spec.Name()] = ::sycl::queue(cpu_devices[device_idx]);
} else if (device_spec.IsSyclGPU()) {
auto& gpu_devices = device_register.gpu_devices;
const int device_idx = collective::IsDistributed()
? collective::GetRank() % gpu_devices.size()
: device_spec.ordinal;
CHECK_LT(device_idx, gpu_devices.size());
queue_register[device_spec.Name()] = ::sycl::queue(gpu_devices[device_idx]);
}
Expand Down

0 comments on commit 217688a

Please sign in to comment.