IPC API implementation for CUDA provider
vinser52 committed Oct 18, 2024
1 parent 11c6408 commit feb0001
Showing 1 changed file with 107 additions and 2 deletions: src/provider/provider_cuda.c
@@ -53,6 +53,10 @@ typedef struct cu_ops_t {
CUresult (*cuGetErrorString)(CUresult error, const char **pStr);
CUresult (*cuCtxGetCurrent)(CUcontext *pctx);
CUresult (*cuCtxSetCurrent)(CUcontext ctx);
CUresult (*cuIpcGetMemHandle)(CUipcMemHandle *pHandle, CUdeviceptr dptr);
CUresult (*cuIpcOpenMemHandle)(CUdeviceptr *pdptr, CUipcMemHandle handle,
unsigned int Flags);
CUresult (*cuIpcCloseMemHandle)(CUdeviceptr dptr);
} cu_ops_t;

static cu_ops_t g_cu_ops;
@@ -123,12 +127,20 @@ static void init_cu_global_state(void) {
utils_get_symbol_addr(0, "cuCtxGetCurrent", lib_name);
*(void **)&g_cu_ops.cuCtxSetCurrent =
utils_get_symbol_addr(0, "cuCtxSetCurrent", lib_name);
*(void **)&g_cu_ops.cuIpcGetMemHandle =
utils_get_symbol_addr(0, "cuIpcGetMemHandle", lib_name);
*(void **)&g_cu_ops.cuIpcOpenMemHandle =
utils_get_symbol_addr(0, "cuIpcOpenMemHandle_v2", lib_name);
*(void **)&g_cu_ops.cuIpcCloseMemHandle =
utils_get_symbol_addr(0, "cuIpcCloseMemHandle", lib_name);

if (!g_cu_ops.cuMemGetAllocationGranularity || !g_cu_ops.cuMemAlloc ||
!g_cu_ops.cuMemAllocHost || !g_cu_ops.cuMemAllocManaged ||
!g_cu_ops.cuMemFree || !g_cu_ops.cuMemFreeHost ||
!g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString ||
!g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxSetCurrent) {
!g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxSetCurrent ||
!g_cu_ops.cuIpcGetMemHandle || !g_cu_ops.cuIpcOpenMemHandle ||
!g_cu_ops.cuIpcCloseMemHandle) {
LOG_ERR("Required CUDA symbols not found.");
Init_cu_global_state_failed = true;
}
Expand Down Expand Up @@ -396,6 +408,99 @@ static const char *cu_memory_provider_get_name(void *provider) {
return "CUDA";
}

typedef CUipcMemHandle cu_ipc_data_t;

static umf_result_t cu_memory_provider_get_ipc_handle_size(void *provider,
size_t *size) {
if (provider == NULL || size == NULL) {
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
}

*size = sizeof(cu_ipc_data_t);
return UMF_RESULT_SUCCESS;
}

static umf_result_t cu_memory_provider_get_ipc_handle(void *provider,
const void *ptr,
size_t size,
void *providerIpcData) {
(void)size;

if (provider == NULL || ptr == NULL || providerIpcData == NULL) {
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
}

CUresult cu_result;
cu_ipc_data_t *cu_ipc_data = (cu_ipc_data_t *)providerIpcData;

cu_result = g_cu_ops.cuIpcGetMemHandle(cu_ipc_data, (CUdeviceptr)ptr);
if (cu_result != CUDA_SUCCESS) {
LOG_ERR("cuIpcGetMemHandle() failed.");
return cu2umf_result(cu_result);
}

return UMF_RESULT_SUCCESS;
}

static umf_result_t cu_memory_provider_put_ipc_handle(void *provider,
void *providerIpcData) {
if (provider == NULL || providerIpcData == NULL) {
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
}

return UMF_RESULT_SUCCESS;
}

static umf_result_t cu_memory_provider_open_ipc_handle(void *provider,
void *providerIpcData,
void **ptr) {
if (provider == NULL || ptr == NULL || providerIpcData == NULL) {
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
}

cu_memory_provider_t *cu_provider = (cu_memory_provider_t *)provider;

CUresult cu_result;
cu_ipc_data_t *cu_ipc_data = (cu_ipc_data_t *)providerIpcData;

// Remember current context and set the one from the provider
CUcontext restore_ctx = NULL;
umf_result_t umf_result = set_context(cu_provider->context, &restore_ctx);
if (umf_result != UMF_RESULT_SUCCESS) {
return umf_result;
}

cu_result = g_cu_ops.cuIpcOpenMemHandle((CUdeviceptr *)ptr, *cu_ipc_data,
CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS);

if (cu_result != CUDA_SUCCESS) {
LOG_ERR("cuIpcOpenMemHandle() failed.");
}

set_context(restore_ctx, &restore_ctx);

return cu2umf_result(cu_result);
}

static umf_result_t
cu_memory_provider_close_ipc_handle(void *provider, void *ptr, size_t size) {
(void)size;

if (provider == NULL || ptr == NULL) {
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
}

CUresult cu_result;

cu_result = g_cu_ops.cuIpcCloseMemHandle((CUdeviceptr)ptr);
if (cu_result != CUDA_SUCCESS) {
LOG_ERR("cuIpcCloseMemHandle() failed.");
return cu2umf_result(cu_result);
}

return UMF_RESULT_SUCCESS;
}

static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
.version = UMF_VERSION_CURRENT,
.initialize = cu_memory_provider_initialize,
@@ -412,12 +517,12 @@ static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
.ext.purge_force = cu_memory_provider_purge_force,
.ext.allocation_merge = cu_memory_provider_allocation_merge,
.ext.allocation_split = cu_memory_provider_allocation_split,
*/
.ipc.get_ipc_handle_size = cu_memory_provider_get_ipc_handle_size,
.ipc.get_ipc_handle = cu_memory_provider_get_ipc_handle,
.ipc.put_ipc_handle = cu_memory_provider_put_ipc_handle,
.ipc.open_ipc_handle = cu_memory_provider_open_ipc_handle,
.ipc.close_ipc_handle = cu_memory_provider_close_ipc_handle,
*/
};

umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void) {
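For context, below is a minimal producer/consumer sketch of how the new CUDA IPC callbacks are exercised through UMF's provider-level IPC API (umfMemoryProviderGetIPCHandleSize, umfMemoryProviderGetIPCHandle, umfMemoryProviderPutIPCHandle, umfMemoryProviderOpenIPCHandle, umfMemoryProviderCloseIPCHandle). It assumes those public wrappers mirror the ops signatures added in this commit and that the two provider handles are created elsewhere; it is illustrative only and not part of the change.

// Illustrative sketch (not part of this commit): a single-function walk-through
// of the IPC flow the CUDA provider now implements. hProducer and hConsumer are
// assumed to be CUDA memory providers created elsewhere, e.g. via
// umfMemoryProviderCreate(umfCUDAMemoryProviderOps(), params, ...).
#include <stdlib.h>
#include <umf/memory_provider.h>

static umf_result_t
share_cuda_allocation(umf_memory_provider_handle_t hProducer,
                      umf_memory_provider_handle_t hConsumer,
                      void *devicePtr, size_t size) {
    size_t handleSize = 0;
    umf_result_t ret = umfMemoryProviderGetIPCHandleSize(hProducer, &handleSize);
    if (ret != UMF_RESULT_SUCCESS) {
        return ret; // for this provider the size is sizeof(CUipcMemHandle)
    }

    void *ipcData = malloc(handleSize);
    if (ipcData == NULL) {
        return UMF_RESULT_ERROR_OUT_OF_HOST_MEMORY;
    }

    // Producer side: fill the CUipcMemHandle for devicePtr (cuIpcGetMemHandle).
    ret = umfMemoryProviderGetIPCHandle(hProducer, devicePtr, size, ipcData);
    if (ret != UMF_RESULT_SUCCESS) {
        free(ipcData);
        return ret;
    }

    // ... in a real scenario ipcData is sent to another process here ...

    // Consumer side: map the exported allocation (cuIpcOpenMemHandle).
    void *mappedPtr = NULL;
    ret = umfMemoryProviderOpenIPCHandle(hConsumer, ipcData, &mappedPtr);
    if (ret == UMF_RESULT_SUCCESS) {
        // ... use mappedPtr ...
        // Unmap on the consumer side (cuIpcCloseMemHandle).
        umfMemoryProviderCloseIPCHandle(hConsumer, mappedPtr, size);
    }

    umfMemoryProviderPutIPCHandle(hProducer, ipcData);
    free(ipcData);
    return ret;
}

Note that, as the committed code shows, cu_memory_provider_put_ipc_handle only validates its arguments and returns success: CUDA IPC handles obtained from cuIpcGetMemHandle need no explicit release on the producer side, so the put step is effectively a no-op for this provider.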
