From 217688a0b95ef760841eec12b1ac12ca77b6d752 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin Date: Fri, 11 Oct 2024 14:50:19 +0200 Subject: [PATCH] Fix for multinode (#65) * fix for multinode * one more fix --------- Co-authored-by: Dmitry Razdoburdin <> --- plugin/sycl/device_manager.cc | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/plugin/sycl/device_manager.cc b/plugin/sycl/device_manager.cc index 0ddbf144083b..021ced67ecaf 100644 --- a/plugin/sycl/device_manager.cc +++ b/plugin/sycl/device_manager.cc @@ -20,18 +20,25 @@ ::sycl::device DeviceManager::GetDevice(const DeviceOrd& device_spec) const { (collective::IsDistributed()); if (not_use_default_selector) { DeviceRegister& device_register = GetDevicesRegister(); - const int device_idx = - collective::IsDistributed() ? collective::GetRank() : device_spec.ordinal; if (device_spec.IsSyclDefault()) { auto& devices = device_register.devices; + const int device_idx = collective::IsDistributed() + ? collective::GetRank() % devices.size() + : device_spec.ordinal; CHECK_LT(device_idx, devices.size()); return devices[device_idx]; } else if (device_spec.IsSyclCPU()) { auto& cpu_devices = device_register.cpu_devices; + const int device_idx = collective::IsDistributed() + ? collective::GetRank() % cpu_devices.size() + : device_spec.ordinal; CHECK_LT(device_idx, cpu_devices.size()); return cpu_devices[device_idx]; } else { auto& gpu_devices = device_register.gpu_devices; + const int device_idx = collective::IsDistributed() + ? collective::GetRank() % gpu_devices.size() + : device_spec.ordinal; CHECK_LT(device_idx, gpu_devices.size()); return gpu_devices[device_idx]; } @@ -63,18 +70,25 @@ ::sycl::queue DeviceManager::GetQueue(const DeviceOrd& device_spec) const { std::lock_guard guard(queue_registering_mutex); if (not_use_default_selector) { DeviceRegister& device_register = GetDevicesRegister(); - const int device_idx = - collective::IsDistributed() ? collective::GetRank() : device_spec.ordinal; if (device_spec.IsSyclDefault()) { auto& devices = device_register.devices; + const int device_idx = collective::IsDistributed() + ? collective::GetRank() % devices.size() + : device_spec.ordinal; CHECK_LT(device_idx, devices.size()); queue_register[device_spec.Name()] = ::sycl::queue(devices[device_idx]); } else if (device_spec.IsSyclCPU()) { auto& cpu_devices = device_register.cpu_devices; + const int device_idx = collective::IsDistributed() + ? collective::GetRank() % cpu_devices.size() + : device_spec.ordinal; CHECK_LT(device_idx, cpu_devices.size()); queue_register[device_spec.Name()] = ::sycl::queue(cpu_devices[device_idx]); } else if (device_spec.IsSyclGPU()) { auto& gpu_devices = device_register.gpu_devices; + const int device_idx = collective::IsDistributed() + ? collective::GetRank() % gpu_devices.size() + : device_spec.ordinal; CHECK_LT(device_idx, gpu_devices.size()); queue_register[device_spec.Name()] = ::sycl::queue(gpu_devices[device_idx]); }