From feb0001ac01d7a77aebdb5356b6bd440393ff84a Mon Sep 17 00:00:00 2001
From: "Vinogradov, Sergei"
Date: Thu, 17 Oct 2024 08:59:43 -0700
Subject: [PATCH] IPC API implementation for CUDA provider

---
 src/provider/provider_cuda.c | 109 ++++++++++++++++++++++++++++++++++-
 1 file changed, 107 insertions(+), 2 deletions(-)

diff --git a/src/provider/provider_cuda.c b/src/provider/provider_cuda.c
index 292e85ac7..eaa529c45 100644
--- a/src/provider/provider_cuda.c
+++ b/src/provider/provider_cuda.c
@@ -53,6 +53,10 @@ typedef struct cu_ops_t {
     CUresult (*cuGetErrorString)(CUresult error, const char **pStr);
     CUresult (*cuCtxGetCurrent)(CUcontext *pctx);
     CUresult (*cuCtxSetCurrent)(CUcontext ctx);
+    CUresult (*cuIpcGetMemHandle)(CUipcMemHandle *pHandle, CUdeviceptr dptr);
+    CUresult (*cuIpcOpenMemHandle)(CUdeviceptr *pdptr, CUipcMemHandle handle,
+                                   unsigned int Flags);
+    CUresult (*cuIpcCloseMemHandle)(CUdeviceptr dptr);
 } cu_ops_t;
 
 static cu_ops_t g_cu_ops;
@@ -123,12 +127,20 @@
         utils_get_symbol_addr(0, "cuCtxGetCurrent", lib_name);
     *(void **)&g_cu_ops.cuCtxSetCurrent =
         utils_get_symbol_addr(0, "cuCtxSetCurrent", lib_name);
+    *(void **)&g_cu_ops.cuIpcGetMemHandle =
+        utils_get_symbol_addr(0, "cuIpcGetMemHandle", lib_name);
+    *(void **)&g_cu_ops.cuIpcOpenMemHandle =
+        utils_get_symbol_addr(0, "cuIpcOpenMemHandle_v2", lib_name);
+    *(void **)&g_cu_ops.cuIpcCloseMemHandle =
+        utils_get_symbol_addr(0, "cuIpcCloseMemHandle", lib_name);
 
     if (!g_cu_ops.cuMemGetAllocationGranularity || !g_cu_ops.cuMemAlloc ||
         !g_cu_ops.cuMemAllocHost || !g_cu_ops.cuMemAllocManaged ||
         !g_cu_ops.cuMemFree || !g_cu_ops.cuMemFreeHost ||
         !g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString ||
-        !g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxSetCurrent) {
+        !g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxSetCurrent ||
+        !g_cu_ops.cuIpcGetMemHandle || !g_cu_ops.cuIpcOpenMemHandle ||
+        !g_cu_ops.cuIpcCloseMemHandle) {
         LOG_ERR("Required CUDA symbols not found.");
         Init_cu_global_state_failed = true;
     }
@@ -396,6 +408,99 @@ static const char *cu_memory_provider_get_name(void *provider) {
     return "CUDA";
 }
 
+typedef CUipcMemHandle cu_ipc_data_t;
+
+static umf_result_t cu_memory_provider_get_ipc_handle_size(void *provider,
+                                                            size_t *size) {
+    if (provider == NULL || size == NULL) {
+        return UMF_RESULT_ERROR_INVALID_ARGUMENT;
+    }
+
+    *size = sizeof(cu_ipc_data_t);
+    return UMF_RESULT_SUCCESS;
+}
+
+static umf_result_t cu_memory_provider_get_ipc_handle(void *provider,
+                                                       const void *ptr,
+                                                       size_t size,
+                                                       void *providerIpcData) {
+    (void)size;
+
+    if (provider == NULL || ptr == NULL || providerIpcData == NULL) {
+        return UMF_RESULT_ERROR_INVALID_ARGUMENT;
+    }
+
+    CUresult cu_result;
+    cu_ipc_data_t *cu_ipc_data = (cu_ipc_data_t *)providerIpcData;
+
+    cu_result = g_cu_ops.cuIpcGetMemHandle(cu_ipc_data, (CUdeviceptr)ptr);
+    if (cu_result != CUDA_SUCCESS) {
+        LOG_ERR("cuIpcGetMemHandle() failed.");
+        return cu2umf_result(cu_result);
+    }
+
+    return UMF_RESULT_SUCCESS;
+}
+
+static umf_result_t cu_memory_provider_put_ipc_handle(void *provider,
+                                                       void *providerIpcData) {
+    if (provider == NULL || providerIpcData == NULL) {
+        return UMF_RESULT_ERROR_INVALID_ARGUMENT;
+    }
+
+    return UMF_RESULT_SUCCESS;
+}
+
+static umf_result_t cu_memory_provider_open_ipc_handle(void *provider,
+                                                        void *providerIpcData,
+                                                        void **ptr) {
+    if (provider == NULL || ptr == NULL || providerIpcData == NULL) {
+        return UMF_RESULT_ERROR_INVALID_ARGUMENT;
+    }
+
+    cu_memory_provider_t *cu_provider = (cu_memory_provider_t *)provider;
+
+    CUresult cu_result;
+    cu_ipc_data_t *cu_ipc_data = (cu_ipc_data_t *)providerIpcData;
+
+    // Remember current context and set the one from the provider
+    CUcontext restore_ctx = NULL;
+    umf_result_t umf_result = set_context(cu_provider->context, &restore_ctx);
+    if (umf_result != UMF_RESULT_SUCCESS) {
+        return umf_result;
+    }
+
+    cu_result = g_cu_ops.cuIpcOpenMemHandle((CUdeviceptr *)ptr, *cu_ipc_data,
+                                            CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS);
+
+    if (cu_result != CUDA_SUCCESS) {
+        LOG_ERR("cuIpcOpenMemHandle() failed.");
+    }
+
+    set_context(restore_ctx, &restore_ctx);
+
+    return cu2umf_result(cu_result);
+}
+
+static umf_result_t
+cu_memory_provider_close_ipc_handle(void *provider, void *ptr, size_t size) {
+    (void)size;
+
+    if (provider == NULL || ptr == NULL) {
+        return UMF_RESULT_ERROR_INVALID_ARGUMENT;
+    }
+
+    CUresult cu_result;
+
+    cu_result = g_cu_ops.cuIpcCloseMemHandle((CUdeviceptr)ptr);
+    if (cu_result != CUDA_SUCCESS) {
+        LOG_ERR("cuIpcCloseMemHandle() failed.");
+        return cu2umf_result(cu_result);
+    }
+
+    return UMF_RESULT_SUCCESS;
+}
+
 static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
     .version = UMF_VERSION_CURRENT,
     .initialize = cu_memory_provider_initialize,
@@ -412,12 +517,12 @@ static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
     .ext.purge_force = cu_memory_provider_purge_force,
     .ext.allocation_merge = cu_memory_provider_allocation_merge,
     .ext.allocation_split = cu_memory_provider_allocation_split,
+    */
     .ipc.get_ipc_handle_size = cu_memory_provider_get_ipc_handle_size,
     .ipc.get_ipc_handle = cu_memory_provider_get_ipc_handle,
     .ipc.put_ipc_handle = cu_memory_provider_put_ipc_handle,
     .ipc.open_ipc_handle = cu_memory_provider_open_ipc_handle,
     .ipc.close_ipc_handle = cu_memory_provider_close_ipc_handle,
-    */
 };
 
 umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void) {
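
Not part of the patch: a minimal producer/consumer sketch of how these provider ops are expected to be reached from application code through UMF's public IPC entry points (umfGetIPCHandle, umfPutIPCHandle, umfOpenIPCHandle, umfCloseIPCHandle). The CUDA-backed pool setup, the exact entry-point signatures, and the transport used to pass the handle between processes are assumptions for illustration and may differ between UMF releases.

// Hypothetical usage sketch (not part of this patch); assumes a CUDA-backed
// UMF pool and the public UMF IPC API; names and signatures may differ
// between releases.
#include <umf/ipc.h>
#include <umf/memory_pool.h>

// Producer process: allocate device memory and publish an IPC handle for it.
static int producer(umf_memory_pool_handle_t cuda_pool) {
    void *buf = umfPoolMalloc(cuda_pool, 4096);
    if (buf == NULL) {
        return -1;
    }

    umf_ipc_handle_t ipc_handle = NULL;
    size_t handle_size = 0;
    // Ends up in cu_memory_provider_get_ipc_handle(), which wraps
    // cuIpcGetMemHandle() for the device pointer.
    if (umfGetIPCHandle(buf, &ipc_handle, &handle_size) != UMF_RESULT_SUCCESS) {
        return -1;
    }

    // ... send the handle bytes to the consumer (pipe, socket, shm, ...) ...

    umfPutIPCHandle(ipc_handle); // cu_memory_provider_put_ipc_handle() is a no-op
    return 0;
}

// Consumer process: map the producer's allocation into this process.
static int consumer(umf_memory_pool_handle_t cuda_pool,
                    umf_ipc_handle_t received_handle) {
    void *mapped = NULL;
    // Ends up in cu_memory_provider_open_ipc_handle(): switches to the
    // provider's CUDA context and calls cuIpcOpenMemHandle().
    if (umfOpenIPCHandle(cuda_pool, received_handle, &mapped) !=
        UMF_RESULT_SUCCESS) {
        return -1;
    }

    // ... use the mapped device pointer ...

    // Unmaps via cu_memory_provider_close_ipc_handle() -> cuIpcCloseMemHandle().
    return umfCloseIPCHandle(mapped) == UMF_RESULT_SUCCESS ? 0 : -1;
}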