Skip to content

Commit

Permalink
Merge pull request #849 from 0x12CC/cooperative_kernels
Browse files Browse the repository at this point in the history
Add cooperative kernels experimental feature
  • Loading branch information
kbenzie authored Sep 15, 2023
2 parents bcf2b2a + bb542f3 commit 12c8312
Show file tree
Hide file tree
Showing 14 changed files with 1,388 additions and 0 deletions.
59 changes: 59 additions & 0 deletions include/ur.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ class ur_function_v(IntEnum):
COMMAND_BUFFER_APPEND_MEM_BUFFER_WRITE_RECT_EXP = 190 ## Enumerator for ::urCommandBufferAppendMemBufferWriteRectExp
COMMAND_BUFFER_APPEND_MEM_BUFFER_READ_RECT_EXP = 191## Enumerator for ::urCommandBufferAppendMemBufferReadRectExp
COMMAND_BUFFER_APPEND_MEM_BUFFER_FILL_EXP = 192 ## Enumerator for ::urCommandBufferAppendMemBufferFillExp
ENQUEUE_COOPERATIVE_KERNEL_LAUNCH_EXP = 193 ## Enumerator for ::urEnqueueCooperativeKernelLaunchExp
KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP = 194## Enumerator for ::urKernelSuggestMaxCooperativeGroupCountExp

class ur_function_t(c_int):
def __str__(self):
Expand Down Expand Up @@ -2272,6 +2274,11 @@ class ur_exp_command_buffer_sync_point_t(c_ulong):
class ur_exp_command_buffer_handle_t(c_void_p):
pass

###############################################################################
## @brief The extension string which defines support for cooperative-kernels
## which is returned when querying device extensions.
UR_COOPERATIVE_KERNELS_EXTENSION_STRING_EXP = "ur_exp_cooperative_kernels"

###############################################################################
## @brief Supported peer info
class ur_exp_peer_info_v(IntEnum):
Expand Down Expand Up @@ -2715,6 +2722,21 @@ class ur_kernel_dditable_t(Structure):
("pfnSetSpecializationConstants", c_void_p) ## _urKernelSetSpecializationConstants_t
]

###############################################################################
## @brief Function-pointer for urKernelSuggestMaxCooperativeGroupCountExp
if __use_win_types:
_urKernelSuggestMaxCooperativeGroupCountExp_t = WINFUNCTYPE( ur_result_t, ur_kernel_handle_t, POINTER(c_ulong) )
else:
_urKernelSuggestMaxCooperativeGroupCountExp_t = CFUNCTYPE( ur_result_t, ur_kernel_handle_t, POINTER(c_ulong) )


###############################################################################
## @brief Table of KernelExp functions pointers
class ur_kernel_exp_dditable_t(Structure):
_fields_ = [
("pfnSuggestMaxCooperativeGroupCountExp", c_void_p) ## _urKernelSuggestMaxCooperativeGroupCountExp_t
]

###############################################################################
## @brief Function-pointer for urSamplerCreate
if __use_win_types:
Expand Down Expand Up @@ -3142,6 +3164,21 @@ class ur_enqueue_dditable_t(Structure):
("pfnWriteHostPipe", c_void_p) ## _urEnqueueWriteHostPipe_t
]

###############################################################################
## @brief Function-pointer for urEnqueueCooperativeKernelLaunchExp
if __use_win_types:
_urEnqueueCooperativeKernelLaunchExp_t = WINFUNCTYPE( ur_result_t, ur_queue_handle_t, ur_kernel_handle_t, c_ulong, POINTER(c_size_t), POINTER(c_size_t), POINTER(c_size_t), c_ulong, POINTER(ur_event_handle_t), POINTER(ur_event_handle_t) )
else:
_urEnqueueCooperativeKernelLaunchExp_t = CFUNCTYPE( ur_result_t, ur_queue_handle_t, ur_kernel_handle_t, c_ulong, POINTER(c_size_t), POINTER(c_size_t), POINTER(c_size_t), c_ulong, POINTER(ur_event_handle_t), POINTER(ur_event_handle_t) )


###############################################################################
## @brief Table of EnqueueExp functions pointers
class ur_enqueue_exp_dditable_t(Structure):
_fields_ = [
("pfnCooperativeKernelLaunchExp", c_void_p) ## _urEnqueueCooperativeKernelLaunchExp_t
]

###############################################################################
## @brief Function-pointer for urQueueGetInfo
if __use_win_types:
Expand Down Expand Up @@ -3774,11 +3811,13 @@ class ur_dditable_t(Structure):
("Event", ur_event_dditable_t),
("Program", ur_program_dditable_t),
("Kernel", ur_kernel_dditable_t),
("KernelExp", ur_kernel_exp_dditable_t),
("Sampler", ur_sampler_dditable_t),
("Mem", ur_mem_dditable_t),
("PhysicalMem", ur_physical_mem_dditable_t),
("Global", ur_global_dditable_t),
("Enqueue", ur_enqueue_dditable_t),
("EnqueueExp", ur_enqueue_exp_dditable_t),
("Queue", ur_queue_dditable_t),
("BindlessImagesExp", ur_bindless_images_exp_dditable_t),
("USM", ur_usm_dditable_t),
Expand Down Expand Up @@ -3899,6 +3938,16 @@ def __init__(self, version : ur_api_version_t):
self.urKernelSetArgMemObj = _urKernelSetArgMemObj_t(self.__dditable.Kernel.pfnSetArgMemObj)
self.urKernelSetSpecializationConstants = _urKernelSetSpecializationConstants_t(self.__dditable.Kernel.pfnSetSpecializationConstants)

# call driver to get function pointers
KernelExp = ur_kernel_exp_dditable_t()
r = ur_result_v(self.__dll.urGetKernelExpProcAddrTable(version, byref(KernelExp)))
if r != ur_result_v.SUCCESS:
raise Exception(r)
self.__dditable.KernelExp = KernelExp

# attach function interface to function address
self.urKernelSuggestMaxCooperativeGroupCountExp = _urKernelSuggestMaxCooperativeGroupCountExp_t(self.__dditable.KernelExp.pfnSuggestMaxCooperativeGroupCountExp)

# call driver to get function pointers
Sampler = ur_sampler_dditable_t()
r = ur_result_v(self.__dll.urGetSamplerProcAddrTable(version, byref(Sampler)))
Expand Down Expand Up @@ -3993,6 +4042,16 @@ def __init__(self, version : ur_api_version_t):
self.urEnqueueReadHostPipe = _urEnqueueReadHostPipe_t(self.__dditable.Enqueue.pfnReadHostPipe)
self.urEnqueueWriteHostPipe = _urEnqueueWriteHostPipe_t(self.__dditable.Enqueue.pfnWriteHostPipe)

# call driver to get function pointers
EnqueueExp = ur_enqueue_exp_dditable_t()
r = ur_result_v(self.__dll.urGetEnqueueExpProcAddrTable(version, byref(EnqueueExp)))
if r != ur_result_v.SUCCESS:
raise Exception(r)
self.__dditable.EnqueueExp = EnqueueExp

# attach function interface to function address
self.urEnqueueCooperativeKernelLaunchExp = _urEnqueueCooperativeKernelLaunchExp_t(self.__dditable.EnqueueExp.pfnCooperativeKernelLaunchExp)

# call driver to get function pointers
Queue = ur_queue_dditable_t()
r = ur_result_v(self.__dll.urGetQueueProcAddrTable(version, byref(Queue)))
Expand Down
111 changes: 111 additions & 0 deletions include/ur_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ typedef enum ur_function_t {
UR_FUNCTION_COMMAND_BUFFER_APPEND_MEM_BUFFER_WRITE_RECT_EXP = 190, ///< Enumerator for ::urCommandBufferAppendMemBufferWriteRectExp
UR_FUNCTION_COMMAND_BUFFER_APPEND_MEM_BUFFER_READ_RECT_EXP = 191, ///< Enumerator for ::urCommandBufferAppendMemBufferReadRectExp
UR_FUNCTION_COMMAND_BUFFER_APPEND_MEM_BUFFER_FILL_EXP = 192, ///< Enumerator for ::urCommandBufferAppendMemBufferFillExp
UR_FUNCTION_ENQUEUE_COOPERATIVE_KERNEL_LAUNCH_EXP = 193, ///< Enumerator for ::urEnqueueCooperativeKernelLaunchExp
UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP = 194, ///< Enumerator for ::urKernelSuggestMaxCooperativeGroupCountExp
/// @cond
UR_FUNCTION_FORCE_UINT32 = 0x7fffffff
/// @endcond
Expand Down Expand Up @@ -8171,6 +8173,90 @@ urCommandBufferEnqueueExp(
///< command-buffer execution instance.
);

#if !defined(__GNUC__)
#pragma endregion
#endif
// Intel 'oneAPI' Unified Runtime Experimental APIs for Cooperative Kernels
#if !defined(__GNUC__)
#pragma region cooperative kernels(experimental)
#endif
///////////////////////////////////////////////////////////////////////////////
#ifndef UR_COOPERATIVE_KERNELS_EXTENSION_STRING_EXP
/// @brief The extension string which defines support for cooperative-kernels
/// which is returned when querying device extensions.
#define UR_COOPERATIVE_KERNELS_EXTENSION_STRING_EXP "ur_exp_cooperative_kernels"
#endif // UR_COOPERATIVE_KERNELS_EXTENSION_STRING_EXP

///////////////////////////////////////////////////////////////////////////////
/// @brief Enqueue a command to execute a cooperative kernel
///
/// @returns
/// - ::UR_RESULT_SUCCESS
/// - ::UR_RESULT_ERROR_UNINITIALIZED
/// - ::UR_RESULT_ERROR_DEVICE_LOST
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hQueue`
/// + `NULL == hKernel`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == pGlobalWorkOffset`
/// + `NULL == pGlobalWorkSize`
/// - ::UR_RESULT_ERROR_INVALID_QUEUE
/// - ::UR_RESULT_ERROR_INVALID_KERNEL
/// - ::UR_RESULT_ERROR_INVALID_EVENT
/// - ::UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST
/// + `phEventWaitList == NULL && numEventsInWaitList > 0`
/// + `phEventWaitList != NULL && numEventsInWaitList == 0`
/// + If event objects in phEventWaitList are not valid events.
/// - ::UR_RESULT_ERROR_INVALID_WORK_DIMENSION
/// - ::UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE
/// - ::UR_RESULT_ERROR_INVALID_VALUE
/// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY
/// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES
UR_APIEXPORT ur_result_t UR_APICALL
urEnqueueCooperativeKernelLaunchExp(
ur_queue_handle_t hQueue, ///< [in] handle of the queue object
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
uint32_t workDim, ///< [in] number of dimensions, from 1 to 3, to specify the global and
///< work-group work-items
const size_t *pGlobalWorkOffset, ///< [in] pointer to an array of workDim unsigned values that specify the
///< offset used to calculate the global ID of a work-item
const size_t *pGlobalWorkSize, ///< [in] pointer to an array of workDim unsigned values that specify the
///< number of global work-items in workDim that will execute the kernel
///< function
const size_t *pLocalWorkSize, ///< [in][optional] pointer to an array of workDim unsigned values that
///< specify the number of local work-items forming a work-group that will
///< execute the kernel function.
///< If nullptr, the runtime implementation will choose the work-group
///< size.
uint32_t numEventsInWaitList, ///< [in] size of the event wait list
const ur_event_handle_t *phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
///< events that must be complete before the kernel execution.
///< If nullptr, the numEventsInWaitList must be 0, indicating that no wait
///< event.
ur_event_handle_t *phEvent ///< [out][optional] return an event object that identifies this particular
///< kernel execution instance.
);

///////////////////////////////////////////////////////////////////////////////
/// @brief Query the maximum number of work groups for a cooperative kernel
///
/// @returns
/// - ::UR_RESULT_SUCCESS
/// - ::UR_RESULT_ERROR_UNINITIALIZED
/// - ::UR_RESULT_ERROR_DEVICE_LOST
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hKernel`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == pGroupCountRet`
/// - ::UR_RESULT_ERROR_INVALID_KERNEL
UR_APIEXPORT ur_result_t UR_APICALL
urKernelSuggestMaxCooperativeGroupCountExp(
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
);

#if !defined(__GNUC__)
#pragma endregion
#endif
Expand Down Expand Up @@ -8939,6 +9025,15 @@ typedef struct ur_kernel_set_specialization_constants_params_t {
const ur_specialization_constant_info_t **ppSpecConstants;
} ur_kernel_set_specialization_constants_params_t;

///////////////////////////////////////////////////////////////////////////////
/// @brief Function parameters for urKernelSuggestMaxCooperativeGroupCountExp
/// @details Each entry is a pointer to the parameter passed to the function;
/// allowing the callback the ability to modify the parameter's value
typedef struct ur_kernel_suggest_max_cooperative_group_count_exp_params_t {
ur_kernel_handle_t *phKernel;
uint32_t **ppGroupCountRet;
} ur_kernel_suggest_max_cooperative_group_count_exp_params_t;

///////////////////////////////////////////////////////////////////////////////
/// @brief Function parameters for urSamplerCreate
/// @details Each entry is a pointer to the parameter passed to the function;
Expand Down Expand Up @@ -9586,6 +9681,22 @@ typedef struct ur_enqueue_write_host_pipe_params_t {
ur_event_handle_t **pphEvent;
} ur_enqueue_write_host_pipe_params_t;

///////////////////////////////////////////////////////////////////////////////
/// @brief Function parameters for urEnqueueCooperativeKernelLaunchExp
/// @details Each entry is a pointer to the parameter passed to the function;
/// allowing the callback the ability to modify the parameter's value
typedef struct ur_enqueue_cooperative_kernel_launch_exp_params_t {
ur_queue_handle_t *phQueue;
ur_kernel_handle_t *phKernel;
uint32_t *pworkDim;
const size_t **ppGlobalWorkOffset;
const size_t **ppGlobalWorkSize;
const size_t **ppLocalWorkSize;
uint32_t *pnumEventsInWaitList;
const ur_event_handle_t **pphEventWaitList;
ur_event_handle_t **pphEvent;
} ur_enqueue_cooperative_kernel_launch_exp_params_t;

///////////////////////////////////////////////////////////////////////////////
/// @brief Function parameters for urQueueGetInfo
/// @details Each entry is a pointer to the parameter passed to the function;
Expand Down
75 changes: 75 additions & 0 deletions include/ur_ddi.h
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,39 @@ typedef ur_result_t(UR_APICALL *ur_pfnGetKernelProcAddrTable_t)(
ur_api_version_t,
ur_kernel_dditable_t *);

///////////////////////////////////////////////////////////////////////////////
/// @brief Function-pointer for urKernelSuggestMaxCooperativeGroupCountExp
typedef ur_result_t(UR_APICALL *ur_pfnKernelSuggestMaxCooperativeGroupCountExp_t)(
ur_kernel_handle_t,
uint32_t *);

///////////////////////////////////////////////////////////////////////////////
/// @brief Table of KernelExp functions pointers
typedef struct ur_kernel_exp_dditable_t {
ur_pfnKernelSuggestMaxCooperativeGroupCountExp_t pfnSuggestMaxCooperativeGroupCountExp;
} ur_kernel_exp_dditable_t;

///////////////////////////////////////////////////////////////////////////////
/// @brief Exported function for filling application's KernelExp table
/// with current process' addresses
///
/// @returns
/// - ::UR_RESULT_SUCCESS
/// - ::UR_RESULT_ERROR_UNINITIALIZED
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION
UR_DLLEXPORT ur_result_t UR_APICALL
urGetKernelExpProcAddrTable(
ur_api_version_t version, ///< [in] API version requested
ur_kernel_exp_dditable_t *pDdiTable ///< [in,out] pointer to table of DDI function pointers
);

///////////////////////////////////////////////////////////////////////////////
/// @brief Function-pointer for urGetKernelExpProcAddrTable
typedef ur_result_t(UR_APICALL *ur_pfnGetKernelExpProcAddrTable_t)(
ur_api_version_t,
ur_kernel_exp_dditable_t *);

///////////////////////////////////////////////////////////////////////////////
/// @brief Function-pointer for urSamplerCreate
typedef ur_result_t(UR_APICALL *ur_pfnSamplerCreate_t)(
Expand Down Expand Up @@ -1246,6 +1279,46 @@ typedef ur_result_t(UR_APICALL *ur_pfnGetEnqueueProcAddrTable_t)(
ur_api_version_t,
ur_enqueue_dditable_t *);

///////////////////////////////////////////////////////////////////////////////
/// @brief Function-pointer for urEnqueueCooperativeKernelLaunchExp
typedef ur_result_t(UR_APICALL *ur_pfnEnqueueCooperativeKernelLaunchExp_t)(
ur_queue_handle_t,
ur_kernel_handle_t,
uint32_t,
const size_t *,
const size_t *,
const size_t *,
uint32_t,
const ur_event_handle_t *,
ur_event_handle_t *);

///////////////////////////////////////////////////////////////////////////////
/// @brief Table of EnqueueExp functions pointers
typedef struct ur_enqueue_exp_dditable_t {
ur_pfnEnqueueCooperativeKernelLaunchExp_t pfnCooperativeKernelLaunchExp;
} ur_enqueue_exp_dditable_t;

///////////////////////////////////////////////////////////////////////////////
/// @brief Exported function for filling application's EnqueueExp table
/// with current process' addresses
///
/// @returns
/// - ::UR_RESULT_SUCCESS
/// - ::UR_RESULT_ERROR_UNINITIALIZED
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION
UR_DLLEXPORT ur_result_t UR_APICALL
urGetEnqueueExpProcAddrTable(
ur_api_version_t version, ///< [in] API version requested
ur_enqueue_exp_dditable_t *pDdiTable ///< [in,out] pointer to table of DDI function pointers
);

///////////////////////////////////////////////////////////////////////////////
/// @brief Function-pointer for urGetEnqueueExpProcAddrTable
typedef ur_result_t(UR_APICALL *ur_pfnGetEnqueueExpProcAddrTable_t)(
ur_api_version_t,
ur_enqueue_exp_dditable_t *);

///////////////////////////////////////////////////////////////////////////////
/// @brief Function-pointer for urQueueGetInfo
typedef ur_result_t(UR_APICALL *ur_pfnQueueGetInfo_t)(
Expand Down Expand Up @@ -2154,11 +2227,13 @@ typedef struct ur_dditable_t {
ur_event_dditable_t Event;
ur_program_dditable_t Program;
ur_kernel_dditable_t Kernel;
ur_kernel_exp_dditable_t KernelExp;
ur_sampler_dditable_t Sampler;
ur_mem_dditable_t Mem;
ur_physical_mem_dditable_t PhysicalMem;
ur_global_dditable_t Global;
ur_enqueue_dditable_t Enqueue;
ur_enqueue_exp_dditable_t EnqueueExp;
ur_queue_dditable_t Queue;
ur_bindless_images_exp_dditable_t BindlessImagesExp;
ur_usm_dditable_t USM;
Expand Down
Loading

0 comments on commit 12c8312

Please sign in to comment.