Skip to content

Commit

Permalink
Merge pull request #1565 from hdelan/cuda-multi-dev-ctx
Browse files Browse the repository at this point in the history
[CUDA] CUDA adapter multi device context
  • Loading branch information
kbenzie authored May 22, 2024
2 parents c911a9b + 7142006 commit d3502dc
Show file tree
Hide file tree
Showing 27 changed files with 1,157 additions and 706 deletions.
57 changes: 34 additions & 23 deletions source/adapters/cuda/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,10 @@ static ur_result_t enqueueCommandBufferFillHelper(
}
}

UR_CHECK_ERROR(cuGraphAddMemsetNode(
&GraphNode, CommandBuffer->CudaGraph, DepsList.data(),
DepsList.size(), &NodeParams, CommandBuffer->Device->getContext()));
UR_CHECK_ERROR(
cuGraphAddMemsetNode(&GraphNode, CommandBuffer->CudaGraph,
DepsList.data(), DepsList.size(), &NodeParams,
CommandBuffer->Device->getNativeContext()));

// Get sync point and register the cuNode with it.
*SyncPoint =
Expand Down Expand Up @@ -237,7 +238,7 @@ static ur_result_t enqueueCommandBufferFillHelper(
UR_CHECK_ERROR(cuGraphAddMemsetNode(
&GraphNodeFirst, CommandBuffer->CudaGraph, DepsList.data(),
DepsList.size(), &NodeParamsStepFirst,
CommandBuffer->Device->getContext()));
CommandBuffer->Device->getNativeContext()));

// Get sync point and register the cuNode with it.
*SyncPoint = CommandBuffer->addSyncPoint(
Expand Down Expand Up @@ -269,7 +270,7 @@ static ur_result_t enqueueCommandBufferFillHelper(
UR_CHECK_ERROR(cuGraphAddMemsetNode(
&GraphNode, CommandBuffer->CudaGraph, DepsList.data(),
DepsList.size(), &NodeParamsStep,
CommandBuffer->Device->getContext()));
CommandBuffer->Device->getNativeContext()));

GraphNodePtr = std::make_shared<CUgraphNode>(GraphNode);
// Get sync point and register the cuNode with it.
Expand Down Expand Up @@ -478,7 +479,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMMemcpyExp(

UR_CHECK_ERROR(cuGraphAddMemcpyNode(
&GraphNode, hCommandBuffer->CudaGraph, DepsList.data(), DepsList.size(),
&NodeParams, hCommandBuffer->Device->getContext()));
&NodeParams, hCommandBuffer->Device->getNativeContext()));

// Get sync point and register the cuNode with it.
*pSyncPoint =
Expand Down Expand Up @@ -513,16 +514,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp(
}

try {
auto Src = std::get<BufferMem>(hSrcMem->Mem).get() + srcOffset;
auto Dst = std::get<BufferMem>(hDstMem->Mem).get() + dstOffset;
auto Src = std::get<BufferMem>(hSrcMem->Mem)
.getPtrWithOffset(hCommandBuffer->Device, srcOffset);
auto Dst = std::get<BufferMem>(hDstMem->Mem)
.getPtrWithOffset(hCommandBuffer->Device, dstOffset);

CUDA_MEMCPY3D NodeParams = {};
setCopyParams(&Src, CU_MEMORYTYPE_DEVICE, &Dst, CU_MEMORYTYPE_DEVICE, size,
NodeParams);

UR_CHECK_ERROR(cuGraphAddMemcpyNode(
&GraphNode, hCommandBuffer->CudaGraph, DepsList.data(), DepsList.size(),
&NodeParams, hCommandBuffer->Device->getContext()));
&NodeParams, hCommandBuffer->Device->getNativeContext()));

// Get sync point and register the cuNode with it.
*pSyncPoint =
Expand Down Expand Up @@ -553,8 +556,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(
}

try {
CUdeviceptr SrcPtr = std::get<BufferMem>(hSrcMem->Mem).get();
CUdeviceptr DstPtr = std::get<BufferMem>(hDstMem->Mem).get();
auto SrcPtr =
std::get<BufferMem>(hSrcMem->Mem).getPtr(hCommandBuffer->Device);
auto DstPtr =
std::get<BufferMem>(hDstMem->Mem).getPtr(hCommandBuffer->Device);
CUDA_MEMCPY3D NodeParams = {};

setCopyRectParams(region, &SrcPtr, CU_MEMORYTYPE_DEVICE, srcOrigin,
Expand All @@ -563,7 +568,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(

UR_CHECK_ERROR(cuGraphAddMemcpyNode(
&GraphNode, hCommandBuffer->CudaGraph, DepsList.data(), DepsList.size(),
&NodeParams, hCommandBuffer->Device->getContext()));
&NodeParams, hCommandBuffer->Device->getNativeContext()));

// Get sync point and register the cuNode with it.
*pSyncPoint =
Expand Down Expand Up @@ -593,15 +598,16 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteExp(
}

try {
auto Dst = std::get<BufferMem>(hBuffer->Mem).get() + offset;
auto Dst = std::get<BufferMem>(hBuffer->Mem)
.getPtrWithOffset(hCommandBuffer->Device, offset);

CUDA_MEMCPY3D NodeParams = {};
setCopyParams(pSrc, CU_MEMORYTYPE_HOST, &Dst, CU_MEMORYTYPE_DEVICE, size,
NodeParams);

UR_CHECK_ERROR(cuGraphAddMemcpyNode(
&GraphNode, hCommandBuffer->CudaGraph, DepsList.data(), DepsList.size(),
&NodeParams, hCommandBuffer->Device->getContext()));
&NodeParams, hCommandBuffer->Device->getNativeContext()));

// Get sync point and register the cuNode with it.
*pSyncPoint =
Expand Down Expand Up @@ -630,15 +636,16 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadExp(
}

try {
auto Src = std::get<BufferMem>(hBuffer->Mem).get() + offset;
auto Src = std::get<BufferMem>(hBuffer->Mem)
.getPtrWithOffset(hCommandBuffer->Device, offset);

CUDA_MEMCPY3D NodeParams = {};
setCopyParams(&Src, CU_MEMORYTYPE_DEVICE, pDst, CU_MEMORYTYPE_HOST, size,
NodeParams);

UR_CHECK_ERROR(cuGraphAddMemcpyNode(
&GraphNode, hCommandBuffer->CudaGraph, DepsList.data(), DepsList.size(),
&NodeParams, hCommandBuffer->Device->getContext()));
&NodeParams, hCommandBuffer->Device->getNativeContext()));

// Get sync point and register the cuNode with it.
*pSyncPoint =
Expand Down Expand Up @@ -670,7 +677,8 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteRectExp(
}

try {
CUdeviceptr DstPtr = std::get<BufferMem>(hBuffer->Mem).get();
auto DstPtr =
std::get<BufferMem>(hBuffer->Mem).getPtr(hCommandBuffer->Device);
CUDA_MEMCPY3D NodeParams = {};

setCopyRectParams(region, pSrc, CU_MEMORYTYPE_HOST, hostOffset,
Expand All @@ -680,7 +688,7 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteRectExp(

UR_CHECK_ERROR(cuGraphAddMemcpyNode(
&GraphNode, hCommandBuffer->CudaGraph, DepsList.data(), DepsList.size(),
&NodeParams, hCommandBuffer->Device->getContext()));
&NodeParams, hCommandBuffer->Device->getNativeContext()));

// Get sync point and register the cuNode with it.
*pSyncPoint =
Expand Down Expand Up @@ -712,7 +720,8 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
}

try {
CUdeviceptr SrcPtr = std::get<BufferMem>(hBuffer->Mem).get();
auto SrcPtr =
std::get<BufferMem>(hBuffer->Mem).getPtr(hCommandBuffer->Device);
CUDA_MEMCPY3D NodeParams = {};

setCopyRectParams(region, &SrcPtr, CU_MEMORYTYPE_DEVICE, bufferOffset,
Expand All @@ -722,7 +731,7 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(

UR_CHECK_ERROR(cuGraphAddMemcpyNode(
&GraphNode, hCommandBuffer->CudaGraph, DepsList.data(), DepsList.size(),
&NodeParams, hCommandBuffer->Device->getContext()));
&NodeParams, hCommandBuffer->Device->getNativeContext()));

// Get sync point and register the cuNode with it.
*pSyncPoint =
Expand Down Expand Up @@ -821,7 +830,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
PatternSizeIsValid,
UR_RESULT_ERROR_INVALID_SIZE);

auto DstDevice = std::get<BufferMem>(hBuffer->Mem).get() + offset;
auto DstDevice = std::get<BufferMem>(hBuffer->Mem)
.getPtrWithOffset(hCommandBuffer->Device, offset);

return enqueueCommandBufferFillHelper(
hCommandBuffer, &DstDevice, CU_MEMORYTYPE_DEVICE, pPattern, patternSize,
Expand Down Expand Up @@ -854,7 +864,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(

try {
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
uint32_t StreamToken;
ur_stream_guard_ Guard;
CUstream CuStream = hQueue->getNextComputeStream(
Expand Down Expand Up @@ -972,7 +982,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
if (ArgValue == nullptr) {
Kernel->setKernelArg(ArgIndex, 0, nullptr);
} else {
CUdeviceptr CuPtr = std::get<BufferMem>(ArgValue->Mem).get();
CUdeviceptr CuPtr =
std::get<BufferMem>(ArgValue->Mem).getPtr(CommandBuffer->Device);
Kernel->setKernelArg(ArgIndex, sizeof(CUdeviceptr), (void *)&CuPtr);
}
} catch (ur_result_t Err) {
Expand Down
24 changes: 12 additions & 12 deletions source/adapters/cuda/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,23 +46,19 @@ UR_APIEXPORT ur_result_t UR_APICALL
urContextCreate(uint32_t DeviceCount, const ur_device_handle_t *phDevices,
const ur_context_properties_t *pProperties,
ur_context_handle_t *phContext) {
std::ignore = DeviceCount;
std::ignore = pProperties;

assert(DeviceCount == 1);
ur_result_t RetErr = UR_RESULT_SUCCESS;

std::unique_ptr<ur_context_handle_t_> ContextPtr{nullptr};
try {
ContextPtr = std::unique_ptr<ur_context_handle_t_>(
new ur_context_handle_t_{*phDevices});
new ur_context_handle_t_{phDevices, DeviceCount});
*phContext = ContextPtr.release();
} catch (ur_result_t Err) {
RetErr = Err;
return Err;
} catch (...) {
RetErr = UR_RESULT_ERROR_OUT_OF_RESOURCES;
return UR_RESULT_ERROR_OUT_OF_RESOURCES;
}
return RetErr;
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urContextGetInfo(
Expand All @@ -72,9 +68,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextGetInfo(

switch (static_cast<uint32_t>(ContextInfoType)) {
case UR_CONTEXT_INFO_NUM_DEVICES:
return ReturnValue(1);
return ReturnValue(static_cast<uint32_t>(hContext->getDevices().size()));
case UR_CONTEXT_INFO_DEVICES:
return ReturnValue(hContext->getDevice());
return ReturnValue(hContext->getDevices().data(),
hContext->getDevices().size());
case UR_CONTEXT_INFO_REFERENCE_COUNT:
return ReturnValue(hContext->getReferenceCount());
case UR_CONTEXT_INFO_ATOMIC_MEMORY_ORDER_CAPABILITIES: {
Expand All @@ -88,7 +85,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextGetInfo(
int Major = 0;
UR_CHECK_ERROR(cuDeviceGetAttribute(
&Major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
hContext->getDevice()->get()));
hContext->getDevices()[0]->get()));
uint32_t Capabilities =
(Major >= 7) ? UR_MEMORY_SCOPE_CAPABILITY_FLAG_WORK_ITEM |
UR_MEMORY_SCOPE_CAPABILITY_FLAG_SUB_GROUP |
Expand Down Expand Up @@ -137,7 +134,10 @@ urContextRetain(ur_context_handle_t hContext) {

UR_APIEXPORT ur_result_t UR_APICALL urContextGetNativeHandle(
ur_context_handle_t hContext, ur_native_handle_t *phNativeContext) {
*phNativeContext = reinterpret_cast<ur_native_handle_t>(hContext->get());
// FIXME: this entry point has been deprecated in the SYCL RT and should be
// changed to unsupported once the deprecation period has elapsed.
*phNativeContext = reinterpret_cast<ur_native_handle_t>(
hContext->getDevices()[0]->getNativeContext());
return UR_RESULT_SUCCESS;
}

Expand Down
77 changes: 45 additions & 32 deletions source/adapters/cuda/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,27 +33,26 @@ typedef void (*ur_context_extended_deleter_t)(void *user_data);
///
/// One of the main differences between the UR API and the CUDA driver API is
/// that the second modifies the state of the threads by assigning
/// `CUcontext` objects to threads. `CUcontext` objects store data associated
/// \c CUcontext objects to threads. \c CUcontext objects store data associated
/// with a given device and control access to said device from the user side.
/// UR API context are objects that are passed to functions, and not bound
/// to threads.
/// The ur_context_handle_t_ object doesn't implement this behavior. It only
/// holds the CUDA context data. The RAII object \ref ScopedContext implements
/// the active context behavior.
///
/// <b> Primary vs User-defined context </b>
/// Since the \c ur_context_handle_t can contain multiple devices, and a \c
/// CUcontext refers to only a single device, the \c CUcontext is more tightly
/// coupled to a \c ur_device_handle_t than a \c ur_context_handle_t. In order
/// to remove some ambiguities about the different semantics of
/// \c ur_context_handle_t and native \c CUcontext, we access the native \c
/// CUcontext solely through the \c ur_device_handle_t class, by using the
/// object \ref ScopedContext, which sets the active device (by setting the
/// active native \c CUcontext).
///
/// CUDA has two different types of context, the Primary context,
/// which is usable by all threads on a given process for a given device, and
/// the aforementioned custom contexts.
/// The CUDA documentation, confirmed with performance analysis, suggest using
/// the Primary context whenever possible.
/// The Primary context is also used by the CUDA Runtime API.
/// For UR applications to interop with CUDA Runtime API, they have to use
/// the primary context - and make that active in the thread.
/// The `ur_context_handle_t_` object can be constructed with a `kind` parameter
/// that allows to construct a Primary or `user-defined` context, so that
/// the UR object interface is always the same.
/// <b> Primary vs User-defined \c CUcontext </b>
///
/// CUDA has two different types of \c CUcontext, the Primary context, which is
/// usable by all threads on a given process for a given device, and the
/// aforementioned custom \c CUcontext s. The CUDA documentation, confirmed with
/// performance analysis, suggest using the Primary context whenever possible.
///
/// <b> Destructor callback </b>
///
Expand All @@ -63,6 +62,18 @@ typedef void (*ur_context_extended_deleter_t)(void *user_data);
/// See proposal for details.
/// https://github.com/codeplaysoftware/standards-proposals/blob/master/extended-context-destruction/index.md
///
///
/// <b> Memory Management for Devices in a Context </b>
///
/// A \c ur_mem_handle_t is associated with a \c ur_context_handle_t_, which
/// may refer to multiple devices. Therefore the \c ur_mem_handle_t must
/// handle a native allocation for each device in the context. UR is
/// responsible for automatically handling event dependencies for kernels
/// writing to or reading from the same \c ur_mem_handle_t and migrating memory
/// between native allocations for devices in the same \c ur_context_handle_t_
/// if necessary.
///
///
struct ur_context_handle_t_ {

struct deleter_data {
Expand All @@ -72,18 +83,21 @@ struct ur_context_handle_t_ {
void operator()() { Function(UserData); }
};

using native_type = CUcontext;

native_type CUContext;
ur_device_handle_t DeviceID;
std::vector<ur_device_handle_t> Devices;
std::atomic_uint32_t RefCount;

ur_context_handle_t_(ur_device_handle_t_ *DevID)
: CUContext{DevID->getContext()}, DeviceID{DevID}, RefCount{1} {
urDeviceRetain(DeviceID);
ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices)
: Devices{Devs, Devs + NumDevices}, RefCount{1} {
for (auto &Dev : Devices) {
urDeviceRetain(Dev);
}
};

~ur_context_handle_t_() { urDeviceRelease(DeviceID); }
~ur_context_handle_t_() {
for (auto &Dev : Devices) {
urDeviceRelease(Dev);
}
}

void invokeExtendedDeleters() {
std::lock_guard<std::mutex> Guard(Mutex);
Expand All @@ -98,9 +112,9 @@ struct ur_context_handle_t_ {
ExtendedDeleters.emplace_back(deleter_data{Function, UserData});
}

ur_device_handle_t getDevice() const noexcept { return DeviceID; }

native_type get() const noexcept { return CUContext; }
const std::vector<ur_device_handle_t> &getDevices() const noexcept {
return Devices;
}

uint32_t incrementReferenceCount() noexcept { return ++RefCount; }

Expand All @@ -123,12 +137,11 @@ struct ur_context_handle_t_ {
namespace {
class ScopedContext {
public:
ScopedContext(ur_context_handle_t Context) {
if (!Context) {
throw UR_RESULT_ERROR_INVALID_CONTEXT;
ScopedContext(ur_device_handle_t Device) {
if (!Device) {
throw UR_RESULT_ERROR_INVALID_DEVICE;
}

setContext(Context->get());
setContext(Device->getNativeContext());
}

ScopedContext(CUcontext NativeContext) { setContext(NativeContext); }
Expand Down
4 changes: 2 additions & 2 deletions source/adapters/cuda/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,

static constexpr uint32_t MaxWorkItemDimensions = 3u;

ScopedContext Active(hDevice->getContext());
ScopedContext Active(hDevice);

switch ((uint32_t)propName) {
case UR_DEVICE_INFO_TYPE: {
Expand Down Expand Up @@ -1234,7 +1234,7 @@ ur_result_t UR_APICALL urDeviceGetGlobalTimestamps(ur_device_handle_t hDevice,
uint64_t *pDeviceTimestamp,
uint64_t *pHostTimestamp) {
CUevent Event;
ScopedContext Active(hDevice->getContext());
ScopedContext Active(hDevice);

if (pDeviceTimestamp) {
UR_CHECK_ERROR(cuEventCreate(&Event, CU_EVENT_DEFAULT));
Expand Down
Loading

0 comments on commit d3502dc

Please sign in to comment.