Skip to content

Commit

Permalink
[common] Do not use loader APIs in ur_pool_manager
Browse files Browse the repository at this point in the history
Calling loader APIs is incorrect - handles would have
to be translated to and from loader handles.

Also, using loader APIs without explictly linking with
loaders results in linking failure on Windows.

Fix this, by using function pointers.
  • Loading branch information
igchor committed Sep 27, 2024
1 parent 622ce27 commit b00c00e
Showing 1 changed file with 50 additions and 14 deletions.
64 changes: 50 additions & 14 deletions source/common/ur_pool_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#ifndef USM_POOL_MANAGER_HPP
#define USM_POOL_MANAGER_HPP 1

#include <ur_ddi.h>

#include "logger/ur_logger.hpp"
#include "umf_helpers.hpp"
#include "ur_api.h"
Expand All @@ -26,6 +28,26 @@

namespace usm {

namespace detail {
struct ddiTables {
ddiTables() {
auto ret =
urGetDeviceProcAddrTable(UR_API_VERSION_CURRENT, &deviceDdiTable);
if (ret != UR_RESULT_SUCCESS) {
throw ret;
}

ret =
urGetContextProcAddrTable(UR_API_VERSION_CURRENT, &contextDdiTable);
if (ret != UR_RESULT_SUCCESS) {
throw ret;
}
}
ur_device_dditable_t deviceDdiTable;
ur_context_dditable_t contextDdiTable;
};
} // namespace detail

/// @brief describes an internal USM pool instance.
struct pool_descriptor {
ur_usm_pool_handle_t poolHandle;
Expand All @@ -44,9 +66,12 @@ struct pool_descriptor {

static inline std::pair<ur_result_t, std::vector<ur_device_handle_t>>
urGetSubDevices(ur_device_handle_t hDevice) {
static detail::ddiTables ddi;

uint32_t nComputeUnits;
auto ret = urDeviceGetInfo(hDevice, UR_DEVICE_INFO_MAX_COMPUTE_UNITS,
sizeof(nComputeUnits), &nComputeUnits, nullptr);
auto ret = ddi.deviceDdiTable.pfnGetInfo(
hDevice, UR_DEVICE_INFO_MAX_COMPUTE_UNITS, sizeof(nComputeUnits),
&nComputeUnits, nullptr);
if (ret != UR_RESULT_SUCCESS) {
return {ret, {}};
}
Expand All @@ -64,15 +89,16 @@ urGetSubDevices(ur_device_handle_t hDevice) {

// Get the number of devices that will be created
uint32_t deviceCount;
ret = urDevicePartition(hDevice, &properties, 0, nullptr, &deviceCount);
ret = ddi.deviceDdiTable.pfnPartition(hDevice, &properties, 0, nullptr,
&deviceCount);
if (ret != UR_RESULT_SUCCESS) {
return {ret, {}};
}

std::vector<ur_device_handle_t> sub_devices(deviceCount);
ret = urDevicePartition(hDevice, &properties,
static_cast<uint32_t>(sub_devices.size()),
sub_devices.data(), nullptr);
ret = ddi.deviceDdiTable.pfnPartition(
hDevice, &properties, static_cast<uint32_t>(sub_devices.size()),
sub_devices.data(), nullptr);
if (ret != UR_RESULT_SUCCESS) {
return {ret, {}};
}
Expand All @@ -82,17 +108,20 @@ urGetSubDevices(ur_device_handle_t hDevice) {

inline std::pair<ur_result_t, std::vector<ur_device_handle_t>>
urGetAllDevicesAndSubDevices(ur_context_handle_t hContext) {
static detail::ddiTables ddi;

size_t deviceCount = 0;
auto ret = urContextGetInfo(hContext, UR_CONTEXT_INFO_NUM_DEVICES,
sizeof(deviceCount), &deviceCount, nullptr);
auto ret = ddi.contextDdiTable.pfnGetInfo(
hContext, UR_CONTEXT_INFO_NUM_DEVICES, sizeof(deviceCount),
&deviceCount, nullptr);
if (ret != UR_RESULT_SUCCESS || deviceCount == 0) {
return {ret, {}};
}

std::vector<ur_device_handle_t> devices(deviceCount);
ret = urContextGetInfo(hContext, UR_CONTEXT_INFO_DEVICES,
sizeof(ur_device_handle_t) * deviceCount,
devices.data(), nullptr);
ret = ddi.contextDdiTable.pfnGetInfo(
hContext, UR_CONTEXT_INFO_DEVICES,
sizeof(ur_device_handle_t) * deviceCount, devices.data(), nullptr);
if (ret != UR_RESULT_SUCCESS) {
return {ret, {}};
}
Expand Down Expand Up @@ -135,6 +164,8 @@ isSharedAllocationReadOnlyOnDevice(const pool_descriptor &desc) {
}

inline bool pool_descriptor::operator==(const pool_descriptor &other) const {
static usm::detail::ddiTables ddi;

const pool_descriptor &lhs = *this;
const pool_descriptor &rhs = other;
ur_native_handle_t lhsNative = 0, rhsNative = 0;
Expand All @@ -145,14 +176,16 @@ inline bool pool_descriptor::operator==(const pool_descriptor &other) const {
// Ref: https://github.com/intel/llvm/commit/86511c5dc84b5781dcfd828caadcb5cac157eae1
// TODO: is this L0 specific?
if (lhs.hDevice) {
auto ret = urDeviceGetNativeHandle(lhs.hDevice, &lhsNative);
auto ret =
ddi.deviceDdiTable.pfnGetNativeHandle(lhs.hDevice, &lhsNative);
if (ret != UR_RESULT_SUCCESS) {
throw ret;
}
}

if (rhs.hDevice) {
auto ret = urDeviceGetNativeHandle(rhs.hDevice, &rhsNative);
auto ret =
ddi.deviceDdiTable.pfnGetNativeHandle(rhs.hDevice, &rhsNative);
if (ret != UR_RESULT_SUCCESS) {
throw ret;
}
Expand Down Expand Up @@ -264,9 +297,12 @@ namespace std {
/// @brief hash specialization for usm::pool_descriptor
template <> struct hash<usm::pool_descriptor> {
inline size_t operator()(const usm::pool_descriptor &desc) const {
static usm::detail::ddiTables ddi;

ur_native_handle_t native = 0;
if (desc.hDevice) {
auto ret = urDeviceGetNativeHandle(desc.hDevice, &native);
auto ret =
ddi.deviceDdiTable.pfnGetNativeHandle(desc.hDevice, &native);
if (ret != UR_RESULT_SUCCESS) {
throw ret;
}
Expand Down

0 comments on commit b00c00e

Please sign in to comment.