From bb542f3563eba6a4e9bd5056340f668e7f2e546b Mon Sep 17 00:00:00 2001 From: Michael Aziz Date: Tue, 5 Sep 2023 13:37:56 -0700 Subject: [PATCH] Add cooperative kernels experimental feature Signed-off-by: Michael Aziz --- include/ur.py | 59 +++++ include/ur_api.h | 111 +++++++++ include/ur_ddi.h | 75 ++++++ scripts/core/EXP-COOPERATIVE-KERNELS.rst | 68 ++++++ scripts/core/exp-cooperative-kernels.yml | 85 +++++++ scripts/core/registry.yml | 6 + source/adapters/null/ur_nullddi.cpp | 136 +++++++++++ source/common/ur_params.hpp | 93 ++++++++ source/loader/layers/tracing/ur_trcddi.cpp | 173 ++++++++++++++ source/loader/layers/validation/ur_valddi.cpp | 182 +++++++++++++++ source/loader/ur_ldrddi.cpp | 213 ++++++++++++++++++ source/loader/ur_libapi.cpp | 97 ++++++++ source/loader/ur_libddi.cpp | 10 + source/ur_api.cpp | 80 +++++++ 14 files changed, 1388 insertions(+) create mode 100644 scripts/core/EXP-COOPERATIVE-KERNELS.rst create mode 100644 scripts/core/exp-cooperative-kernels.yml diff --git a/include/ur.py b/include/ur.py index c6ce722221..6860304683 100644 --- a/include/ur.py +++ b/include/ur.py @@ -198,6 +198,8 @@ class ur_function_v(IntEnum): COMMAND_BUFFER_APPEND_MEM_BUFFER_WRITE_RECT_EXP = 190 ## Enumerator for ::urCommandBufferAppendMemBufferWriteRectExp COMMAND_BUFFER_APPEND_MEM_BUFFER_READ_RECT_EXP = 191## Enumerator for ::urCommandBufferAppendMemBufferReadRectExp COMMAND_BUFFER_APPEND_MEM_BUFFER_FILL_EXP = 192 ## Enumerator for ::urCommandBufferAppendMemBufferFillExp + ENQUEUE_COOPERATIVE_KERNEL_LAUNCH_EXP = 193 ## Enumerator for ::urEnqueueCooperativeKernelLaunchExp + KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP = 194## Enumerator for ::urKernelSuggestMaxCooperativeGroupCountExp class ur_function_t(c_int): def __str__(self): @@ -2272,6 +2274,11 @@ class ur_exp_command_buffer_sync_point_t(c_ulong): class ur_exp_command_buffer_handle_t(c_void_p): pass +############################################################################### +## @brief The extension string which defines support for cooperative-kernels +## which is returned when querying device extensions. +UR_COOPERATIVE_KERNELS_EXTENSION_STRING_EXP = "ur_exp_cooperative_kernels" + ############################################################################### ## @brief Supported peer info class ur_exp_peer_info_v(IntEnum): @@ -2715,6 +2722,21 @@ class ur_kernel_dditable_t(Structure): ("pfnSetSpecializationConstants", c_void_p) ## _urKernelSetSpecializationConstants_t ] +############################################################################### +## @brief Function-pointer for urKernelSuggestMaxCooperativeGroupCountExp +if __use_win_types: + _urKernelSuggestMaxCooperativeGroupCountExp_t = WINFUNCTYPE( ur_result_t, ur_kernel_handle_t, POINTER(c_ulong) ) +else: + _urKernelSuggestMaxCooperativeGroupCountExp_t = CFUNCTYPE( ur_result_t, ur_kernel_handle_t, POINTER(c_ulong) ) + + +############################################################################### +## @brief Table of KernelExp functions pointers +class ur_kernel_exp_dditable_t(Structure): + _fields_ = [ + ("pfnSuggestMaxCooperativeGroupCountExp", c_void_p) ## _urKernelSuggestMaxCooperativeGroupCountExp_t + ] + ############################################################################### ## @brief Function-pointer for urSamplerCreate if __use_win_types: @@ -3142,6 +3164,21 @@ class ur_enqueue_dditable_t(Structure): ("pfnWriteHostPipe", c_void_p) ## _urEnqueueWriteHostPipe_t ] +############################################################################### +## @brief Function-pointer for urEnqueueCooperativeKernelLaunchExp +if __use_win_types: + _urEnqueueCooperativeKernelLaunchExp_t = WINFUNCTYPE( ur_result_t, ur_queue_handle_t, ur_kernel_handle_t, c_ulong, POINTER(c_size_t), POINTER(c_size_t), POINTER(c_size_t), c_ulong, POINTER(ur_event_handle_t), POINTER(ur_event_handle_t) ) +else: + _urEnqueueCooperativeKernelLaunchExp_t = CFUNCTYPE( ur_result_t, ur_queue_handle_t, ur_kernel_handle_t, c_ulong, POINTER(c_size_t), POINTER(c_size_t), POINTER(c_size_t), c_ulong, POINTER(ur_event_handle_t), POINTER(ur_event_handle_t) ) + + +############################################################################### +## @brief Table of EnqueueExp functions pointers +class ur_enqueue_exp_dditable_t(Structure): + _fields_ = [ + ("pfnCooperativeKernelLaunchExp", c_void_p) ## _urEnqueueCooperativeKernelLaunchExp_t + ] + ############################################################################### ## @brief Function-pointer for urQueueGetInfo if __use_win_types: @@ -3774,11 +3811,13 @@ class ur_dditable_t(Structure): ("Event", ur_event_dditable_t), ("Program", ur_program_dditable_t), ("Kernel", ur_kernel_dditable_t), + ("KernelExp", ur_kernel_exp_dditable_t), ("Sampler", ur_sampler_dditable_t), ("Mem", ur_mem_dditable_t), ("PhysicalMem", ur_physical_mem_dditable_t), ("Global", ur_global_dditable_t), ("Enqueue", ur_enqueue_dditable_t), + ("EnqueueExp", ur_enqueue_exp_dditable_t), ("Queue", ur_queue_dditable_t), ("BindlessImagesExp", ur_bindless_images_exp_dditable_t), ("USM", ur_usm_dditable_t), @@ -3899,6 +3938,16 @@ def __init__(self, version : ur_api_version_t): self.urKernelSetArgMemObj = _urKernelSetArgMemObj_t(self.__dditable.Kernel.pfnSetArgMemObj) self.urKernelSetSpecializationConstants = _urKernelSetSpecializationConstants_t(self.__dditable.Kernel.pfnSetSpecializationConstants) + # call driver to get function pointers + KernelExp = ur_kernel_exp_dditable_t() + r = ur_result_v(self.__dll.urGetKernelExpProcAddrTable(version, byref(KernelExp))) + if r != ur_result_v.SUCCESS: + raise Exception(r) + self.__dditable.KernelExp = KernelExp + + # attach function interface to function address + self.urKernelSuggestMaxCooperativeGroupCountExp = _urKernelSuggestMaxCooperativeGroupCountExp_t(self.__dditable.KernelExp.pfnSuggestMaxCooperativeGroupCountExp) + # call driver to get function pointers Sampler = ur_sampler_dditable_t() r = ur_result_v(self.__dll.urGetSamplerProcAddrTable(version, byref(Sampler))) @@ -3993,6 +4042,16 @@ def __init__(self, version : ur_api_version_t): self.urEnqueueReadHostPipe = _urEnqueueReadHostPipe_t(self.__dditable.Enqueue.pfnReadHostPipe) self.urEnqueueWriteHostPipe = _urEnqueueWriteHostPipe_t(self.__dditable.Enqueue.pfnWriteHostPipe) + # call driver to get function pointers + EnqueueExp = ur_enqueue_exp_dditable_t() + r = ur_result_v(self.__dll.urGetEnqueueExpProcAddrTable(version, byref(EnqueueExp))) + if r != ur_result_v.SUCCESS: + raise Exception(r) + self.__dditable.EnqueueExp = EnqueueExp + + # attach function interface to function address + self.urEnqueueCooperativeKernelLaunchExp = _urEnqueueCooperativeKernelLaunchExp_t(self.__dditable.EnqueueExp.pfnCooperativeKernelLaunchExp) + # call driver to get function pointers Queue = ur_queue_dditable_t() r = ur_result_v(self.__dll.urGetQueueProcAddrTable(version, byref(Queue))) diff --git a/include/ur_api.h b/include/ur_api.h index 6c445fc408..84d26d1c31 100644 --- a/include/ur_api.h +++ b/include/ur_api.h @@ -207,6 +207,8 @@ typedef enum ur_function_t { UR_FUNCTION_COMMAND_BUFFER_APPEND_MEM_BUFFER_WRITE_RECT_EXP = 190, ///< Enumerator for ::urCommandBufferAppendMemBufferWriteRectExp UR_FUNCTION_COMMAND_BUFFER_APPEND_MEM_BUFFER_READ_RECT_EXP = 191, ///< Enumerator for ::urCommandBufferAppendMemBufferReadRectExp UR_FUNCTION_COMMAND_BUFFER_APPEND_MEM_BUFFER_FILL_EXP = 192, ///< Enumerator for ::urCommandBufferAppendMemBufferFillExp + UR_FUNCTION_ENQUEUE_COOPERATIVE_KERNEL_LAUNCH_EXP = 193, ///< Enumerator for ::urEnqueueCooperativeKernelLaunchExp + UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP = 194, ///< Enumerator for ::urKernelSuggestMaxCooperativeGroupCountExp /// @cond UR_FUNCTION_FORCE_UINT32 = 0x7fffffff /// @endcond @@ -8171,6 +8173,90 @@ urCommandBufferEnqueueExp( ///< command-buffer execution instance. ); +#if !defined(__GNUC__) +#pragma endregion +#endif +// Intel 'oneAPI' Unified Runtime Experimental APIs for Cooperative Kernels +#if !defined(__GNUC__) +#pragma region cooperative kernels(experimental) +#endif +/////////////////////////////////////////////////////////////////////////////// +#ifndef UR_COOPERATIVE_KERNELS_EXTENSION_STRING_EXP +/// @brief The extension string which defines support for cooperative-kernels +/// which is returned when querying device extensions. +#define UR_COOPERATIVE_KERNELS_EXTENSION_STRING_EXP "ur_exp_cooperative_kernels" +#endif // UR_COOPERATIVE_KERNELS_EXTENSION_STRING_EXP + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Enqueue a command to execute a cooperative kernel +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_DEVICE_LOST +/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC +/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE +/// + `NULL == hQueue` +/// + `NULL == hKernel` +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// + `NULL == pGlobalWorkOffset` +/// + `NULL == pGlobalWorkSize` +/// - ::UR_RESULT_ERROR_INVALID_QUEUE +/// - ::UR_RESULT_ERROR_INVALID_KERNEL +/// - ::UR_RESULT_ERROR_INVALID_EVENT +/// - ::UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST +/// + `phEventWaitList == NULL && numEventsInWaitList > 0` +/// + `phEventWaitList != NULL && numEventsInWaitList == 0` +/// + If event objects in phEventWaitList are not valid events. +/// - ::UR_RESULT_ERROR_INVALID_WORK_DIMENSION +/// - ::UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE +/// - ::UR_RESULT_ERROR_INVALID_VALUE +/// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY +/// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES +UR_APIEXPORT ur_result_t UR_APICALL +urEnqueueCooperativeKernelLaunchExp( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t workDim, ///< [in] number of dimensions, from 1 to 3, to specify the global and + ///< work-group work-items + const size_t *pGlobalWorkOffset, ///< [in] pointer to an array of workDim unsigned values that specify the + ///< offset used to calculate the global ID of a work-item + const size_t *pGlobalWorkSize, ///< [in] pointer to an array of workDim unsigned values that specify the + ///< number of global work-items in workDim that will execute the kernel + ///< function + const size_t *pLocalWorkSize, ///< [in][optional] pointer to an array of workDim unsigned values that + ///< specify the number of local work-items forming a work-group that will + ///< execute the kernel function. + ///< If nullptr, the runtime implementation will choose the work-group + ///< size. + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t *phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before the kernel execution. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that no wait + ///< event. + ur_event_handle_t *phEvent ///< [out][optional] return an event object that identifies this particular + ///< kernel execution instance. +); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Query the maximum number of work groups for a cooperative kernel +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_DEVICE_LOST +/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC +/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE +/// + `NULL == hKernel` +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// + `NULL == pGroupCountRet` +/// - ::UR_RESULT_ERROR_INVALID_KERNEL +UR_APIEXPORT ur_result_t UR_APICALL +urKernelSuggestMaxCooperativeGroupCountExp( + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups +); + #if !defined(__GNUC__) #pragma endregion #endif @@ -8939,6 +9025,15 @@ typedef struct ur_kernel_set_specialization_constants_params_t { const ur_specialization_constant_info_t **ppSpecConstants; } ur_kernel_set_specialization_constants_params_t; +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function parameters for urKernelSuggestMaxCooperativeGroupCountExp +/// @details Each entry is a pointer to the parameter passed to the function; +/// allowing the callback the ability to modify the parameter's value +typedef struct ur_kernel_suggest_max_cooperative_group_count_exp_params_t { + ur_kernel_handle_t *phKernel; + uint32_t **ppGroupCountRet; +} ur_kernel_suggest_max_cooperative_group_count_exp_params_t; + /////////////////////////////////////////////////////////////////////////////// /// @brief Function parameters for urSamplerCreate /// @details Each entry is a pointer to the parameter passed to the function; @@ -9586,6 +9681,22 @@ typedef struct ur_enqueue_write_host_pipe_params_t { ur_event_handle_t **pphEvent; } ur_enqueue_write_host_pipe_params_t; +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function parameters for urEnqueueCooperativeKernelLaunchExp +/// @details Each entry is a pointer to the parameter passed to the function; +/// allowing the callback the ability to modify the parameter's value +typedef struct ur_enqueue_cooperative_kernel_launch_exp_params_t { + ur_queue_handle_t *phQueue; + ur_kernel_handle_t *phKernel; + uint32_t *pworkDim; + const size_t **ppGlobalWorkOffset; + const size_t **ppGlobalWorkSize; + const size_t **ppLocalWorkSize; + uint32_t *pnumEventsInWaitList; + const ur_event_handle_t **pphEventWaitList; + ur_event_handle_t **pphEvent; +} ur_enqueue_cooperative_kernel_launch_exp_params_t; + /////////////////////////////////////////////////////////////////////////////// /// @brief Function parameters for urQueueGetInfo /// @details Each entry is a pointer to the parameter passed to the function; diff --git a/include/ur_ddi.h b/include/ur_ddi.h index ff3bfca155..ae5edf6371 100644 --- a/include/ur_ddi.h +++ b/include/ur_ddi.h @@ -567,6 +567,39 @@ typedef ur_result_t(UR_APICALL *ur_pfnGetKernelProcAddrTable_t)( ur_api_version_t, ur_kernel_dditable_t *); +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function-pointer for urKernelSuggestMaxCooperativeGroupCountExp +typedef ur_result_t(UR_APICALL *ur_pfnKernelSuggestMaxCooperativeGroupCountExp_t)( + ur_kernel_handle_t, + uint32_t *); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Table of KernelExp functions pointers +typedef struct ur_kernel_exp_dditable_t { + ur_pfnKernelSuggestMaxCooperativeGroupCountExp_t pfnSuggestMaxCooperativeGroupCountExp; +} ur_kernel_exp_dditable_t; + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's KernelExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL +urGetKernelExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_kernel_exp_dditable_t *pDdiTable ///< [in,out] pointer to table of DDI function pointers +); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function-pointer for urGetKernelExpProcAddrTable +typedef ur_result_t(UR_APICALL *ur_pfnGetKernelExpProcAddrTable_t)( + ur_api_version_t, + ur_kernel_exp_dditable_t *); + /////////////////////////////////////////////////////////////////////////////// /// @brief Function-pointer for urSamplerCreate typedef ur_result_t(UR_APICALL *ur_pfnSamplerCreate_t)( @@ -1246,6 +1279,46 @@ typedef ur_result_t(UR_APICALL *ur_pfnGetEnqueueProcAddrTable_t)( ur_api_version_t, ur_enqueue_dditable_t *); +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function-pointer for urEnqueueCooperativeKernelLaunchExp +typedef ur_result_t(UR_APICALL *ur_pfnEnqueueCooperativeKernelLaunchExp_t)( + ur_queue_handle_t, + ur_kernel_handle_t, + uint32_t, + const size_t *, + const size_t *, + const size_t *, + uint32_t, + const ur_event_handle_t *, + ur_event_handle_t *); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Table of EnqueueExp functions pointers +typedef struct ur_enqueue_exp_dditable_t { + ur_pfnEnqueueCooperativeKernelLaunchExp_t pfnCooperativeKernelLaunchExp; +} ur_enqueue_exp_dditable_t; + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's EnqueueExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL +urGetEnqueueExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_enqueue_exp_dditable_t *pDdiTable ///< [in,out] pointer to table of DDI function pointers +); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function-pointer for urGetEnqueueExpProcAddrTable +typedef ur_result_t(UR_APICALL *ur_pfnGetEnqueueExpProcAddrTable_t)( + ur_api_version_t, + ur_enqueue_exp_dditable_t *); + /////////////////////////////////////////////////////////////////////////////// /// @brief Function-pointer for urQueueGetInfo typedef ur_result_t(UR_APICALL *ur_pfnQueueGetInfo_t)( @@ -2154,11 +2227,13 @@ typedef struct ur_dditable_t { ur_event_dditable_t Event; ur_program_dditable_t Program; ur_kernel_dditable_t Kernel; + ur_kernel_exp_dditable_t KernelExp; ur_sampler_dditable_t Sampler; ur_mem_dditable_t Mem; ur_physical_mem_dditable_t PhysicalMem; ur_global_dditable_t Global; ur_enqueue_dditable_t Enqueue; + ur_enqueue_exp_dditable_t EnqueueExp; ur_queue_dditable_t Queue; ur_bindless_images_exp_dditable_t BindlessImagesExp; ur_usm_dditable_t USM; diff --git a/scripts/core/EXP-COOPERATIVE-KERNELS.rst b/scripts/core/EXP-COOPERATIVE-KERNELS.rst new file mode 100644 index 0000000000..c6b64ef669 --- /dev/null +++ b/scripts/core/EXP-COOPERATIVE-KERNELS.rst @@ -0,0 +1,68 @@ +<% + OneApi=tags['$OneApi'] + x=tags['$x'] + X=x.upper() +%> + +.. _experimental-cooperative-kernels: + +================================================================================ +Cooperative Kernels +================================================================================ + +.. warning:: + + Experimental features: + + * May be replaced, updated, or removed at any time. + * Do not require maintaining API/ABI stability of their own additions over + time. + * Do not require conformance testing of their own additions. + + +Motivation +-------------------------------------------------------------------------------- +Cooperative kernels are kernels that use cross-workgroup synchronization +features. All enqueued workgroups must run concurrently for cooperative kernels +to execute without hanging. This experimental feature provides an API for +querying the maximum number of workgroups and launching cooperative kernels. + +Any device can support cooperative kernels by restricting the maximum number of +workgroups to 1. Devices that support cross-workgroup synchronization can +specify a larger maximum for a given cooperative kernel. + +The functions defined here align with those specified in Level Zero. + +API +-------------------------------------------------------------------------------- + +Macros +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +* ${X}_COOPERATIVE_KERNELS_EXTENSION_STRING_EXP + +Functions +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +* ${x}EnqueueCooperativeKernelLaunchExp +* ${x}KernelSuggestMaxCooperativeGroupCountExp + +Changelog +-------------------------------------------------------------------------------- ++-----------+------------------------+ +| Revision | Changes | ++===========+========================+ +| 1.0 | Initial Draft | ++-----------+------------------------+ + +Support +-------------------------------------------------------------------------------- + +Adapters which support this experimental feature *must* return the valid string +defined in ``${X}_COOPERATIVE_KERNELS_EXTENSION_STRING_EXP`` +as one of the options from ${x}DeviceGetInfo when querying for +${X}_DEVICE_INFO_EXTENSIONS. Conversely, before using any of the +functionality defined in this experimental feature the user *must* use the +device query to determine if the adapter supports this feature. + +Contributors +-------------------------------------------------------------------------------- +* Michael Aziz `michael.aziz@intel.com `_ diff --git a/scripts/core/exp-cooperative-kernels.yml b/scripts/core/exp-cooperative-kernels.yml new file mode 100644 index 0000000000..fb2c6b3a4a --- /dev/null +++ b/scripts/core/exp-cooperative-kernels.yml @@ -0,0 +1,85 @@ +# +# Copyright (C) 2023 Intel Corporation +# +# Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. +# See LICENSE.TXT +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# +# See YaML.md for syntax definition +# +--- #-------------------------------------------------------------------------- +type: header +desc: "Intel $OneApi Unified Runtime Experimental APIs for Cooperative Kernels" +ordinal: "99" +--- #-------------------------------------------------------------------------- +type: macro +desc: | + The extension string which defines support for cooperative-kernels + which is returned when querying device extensions. +name: $X_COOPERATIVE_KERNELS_EXTENSION_STRING_EXP +value: "\"$x_exp_cooperative_kernels\"" +--- #-------------------------------------------------------------------------- +type: function +desc: "Enqueue a command to execute a cooperative kernel" +class: $xEnqueue +name: CooperativeKernelLaunchExp +params: + - type: $x_queue_handle_t + name: hQueue + desc: "[in] handle of the queue object" + - type: $x_kernel_handle_t + name: hKernel + desc: "[in] handle of the kernel object" + - type: uint32_t + name: workDim + desc: "[in] number of dimensions, from 1 to 3, to specify the global and work-group work-items" + - type: "const size_t*" + name: pGlobalWorkOffset + desc: "[in] pointer to an array of workDim unsigned values that specify the offset used to calculate the global ID of a work-item" + - type: "const size_t*" + name: pGlobalWorkSize + desc: "[in] pointer to an array of workDim unsigned values that specify the number of global work-items in workDim that will execute the kernel function" + - type: "const size_t*" + name: pLocalWorkSize + desc: | + [in][optional] pointer to an array of workDim unsigned values that specify the number of local work-items forming a work-group that will execute the kernel function. + If nullptr, the runtime implementation will choose the work-group size. + - type: uint32_t + name: numEventsInWaitList + desc: "[in] size of the event wait list" + - type: "const $x_event_handle_t*" + name: phEventWaitList + desc: | + [in][optional][range(0, numEventsInWaitList)] pointer to a list of events that must be complete before the kernel execution. + If nullptr, the numEventsInWaitList must be 0, indicating that no wait event. + - type: $x_event_handle_t* + name: phEvent + desc: | + [out][optional] return an event object that identifies this particular kernel execution instance. +returns: + - $X_RESULT_ERROR_INVALID_QUEUE + - $X_RESULT_ERROR_INVALID_KERNEL + - $X_RESULT_ERROR_INVALID_EVENT + - $X_RESULT_ERROR_INVALID_EVENT_WAIT_LIST: + - "`phEventWaitList == NULL && numEventsInWaitList > 0`" + - "`phEventWaitList != NULL && numEventsInWaitList == 0`" + - "If event objects in phEventWaitList are not valid events." + - $X_RESULT_ERROR_INVALID_WORK_DIMENSION + - $X_RESULT_ERROR_INVALID_WORK_GROUP_SIZE + - $X_RESULT_ERROR_INVALID_VALUE + - $X_RESULT_ERROR_OUT_OF_HOST_MEMORY + - $X_RESULT_ERROR_OUT_OF_RESOURCES +--- #-------------------------------------------------------------------------- +type: function +desc: "Query the maximum number of work groups for a cooperative kernel" +class: $xKernel +name: SuggestMaxCooperativeGroupCountExp +params: + - type: $x_kernel_handle_t + name: hKernel + desc: "[in] handle of the kernel object" + - type: "uint32_t*" + name: "pGroupCountRet" + desc: "[out] pointer to maximum number of groups" +returns: + - $X_RESULT_ERROR_INVALID_KERNEL diff --git a/scripts/core/registry.yml b/scripts/core/registry.yml index 01b4b50334..4924e3e947 100644 --- a/scripts/core/registry.yml +++ b/scripts/core/registry.yml @@ -535,6 +535,12 @@ etors: - name: COMMAND_BUFFER_APPEND_MEM_BUFFER_FILL_EXP desc: Enumerator for $xCommandBufferAppendMemBufferFillExp value: '192' +- name: ENQUEUE_COOPERATIVE_KERNEL_LAUNCH_EXP + desc: Enumerator for $xEnqueueCooperativeKernelLaunchExp + value: '193' +- name: KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP + desc: Enumerator for $xKernelSuggestMaxCooperativeGroupCountExp + value: '194' --- type: enum desc: Defines structure types diff --git a/source/adapters/null/ur_nullddi.cpp b/source/adapters/null/ur_nullddi.cpp index cb87343945..314cb9138e 100644 --- a/source/adapters/null/ur_nullddi.cpp +++ b/source/adapters/null/ur_nullddi.cpp @@ -4927,6 +4927,80 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferEnqueueExp( return exceptionToResult(std::current_exception()); } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueCooperativeKernelLaunchExp +__urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t + workDim, ///< [in] number of dimensions, from 1 to 3, to specify the global and + ///< work-group work-items + const size_t * + pGlobalWorkOffset, ///< [in] pointer to an array of workDim unsigned values that specify the + ///< offset used to calculate the global ID of a work-item + const size_t * + pGlobalWorkSize, ///< [in] pointer to an array of workDim unsigned values that specify the + ///< number of global work-items in workDim that will execute the kernel + ///< function + const size_t * + pLocalWorkSize, ///< [in][optional] pointer to an array of workDim unsigned values that + ///< specify the number of local work-items forming a work-group that will + ///< execute the kernel function. + ///< If nullptr, the runtime implementation will choose the work-group + ///< size. + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before the kernel execution. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that no wait + ///< event. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< kernel execution instance. + ) try { + ur_result_t result = UR_RESULT_SUCCESS; + + // if the driver has created a custom function, then call it instead of using the generic path + auto pfnCooperativeKernelLaunchExp = + d_context.urDdiTable.EnqueueExp.pfnCooperativeKernelLaunchExp; + if (nullptr != pfnCooperativeKernelLaunchExp) { + result = pfnCooperativeKernelLaunchExp( + hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize, + pLocalWorkSize, numEventsInWaitList, phEventWaitList, phEvent); + } else { + // generic implementation + if (nullptr != phEvent) { + *phEvent = reinterpret_cast(d_context.get()); + } + } + + return result; +} catch (...) { + return exceptionToResult(std::current_exception()); +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCountExp +__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp( + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups + ) try { + ur_result_t result = UR_RESULT_SUCCESS; + + // if the driver has created a custom function, then call it instead of using the generic path + auto pfnSuggestMaxCooperativeGroupCountExp = + d_context.urDdiTable.KernelExp.pfnSuggestMaxCooperativeGroupCountExp; + if (nullptr != pfnSuggestMaxCooperativeGroupCountExp) { + result = pfnSuggestMaxCooperativeGroupCountExp(hKernel, pGroupCountRet); + } else { + // generic implementation + } + + return result; +} catch (...) { + return exceptionToResult(std::current_exception()); +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urUSMImportExp __urdlllocal ur_result_t UR_APICALL urUSMImportExp( @@ -5357,6 +5431,37 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable( return exceptionToResult(std::current_exception()); } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's EnqueueExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_enqueue_exp_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers + ) try { + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (driver::d_context.version < version) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + pDdiTable->pfnCooperativeKernelLaunchExp = + driver::urEnqueueCooperativeKernelLaunchExp; + + return result; +} catch (...) { + return exceptionToResult(std::current_exception()); +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Exported function for filling application's Event table /// with current process' addresses @@ -5462,6 +5567,37 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelProcAddrTable( return exceptionToResult(std::current_exception()); } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's KernelExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_kernel_exp_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers + ) try { + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (driver::d_context.version < version) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + pDdiTable->pfnSuggestMaxCooperativeGroupCountExp = + driver::urKernelSuggestMaxCooperativeGroupCountExp; + + return result; +} catch (...) { + return exceptionToResult(std::current_exception()); +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Exported function for filling application's Mem table /// with current process' addresses diff --git a/source/common/ur_params.hpp b/source/common/ur_params.hpp index 04a4636444..4296606987 100644 --- a/source/common/ur_params.hpp +++ b/source/common/ur_params.hpp @@ -1152,6 +1152,14 @@ inline std::ostream &operator<<(std::ostream &os, enum ur_function_t value) { case UR_FUNCTION_COMMAND_BUFFER_APPEND_MEM_BUFFER_FILL_EXP: os << "UR_FUNCTION_COMMAND_BUFFER_APPEND_MEM_BUFFER_FILL_EXP"; break; + + case UR_FUNCTION_ENQUEUE_COOPERATIVE_KERNEL_LAUNCH_EXP: + os << "UR_FUNCTION_ENQUEUE_COOPERATIVE_KERNEL_LAUNCH_EXP"; + break; + + case UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP: + os << "UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP"; + break; default: os << "unknown enumerator"; break; @@ -12952,6 +12960,65 @@ operator<<(std::ostream &os, return os; } +inline std::ostream &operator<<( + std::ostream &os, + const struct ur_enqueue_cooperative_kernel_launch_exp_params_t *params) { + + os << ".hQueue = "; + + ur_params::serializePtr(os, *(params->phQueue)); + + os << ", "; + os << ".hKernel = "; + + ur_params::serializePtr(os, *(params->phKernel)); + + os << ", "; + os << ".workDim = "; + + os << *(params->pworkDim); + + os << ", "; + os << ".pGlobalWorkOffset = "; + + ur_params::serializePtr(os, *(params->ppGlobalWorkOffset)); + + os << ", "; + os << ".pGlobalWorkSize = "; + + ur_params::serializePtr(os, *(params->ppGlobalWorkSize)); + + os << ", "; + os << ".pLocalWorkSize = "; + + ur_params::serializePtr(os, *(params->ppLocalWorkSize)); + + os << ", "; + os << ".numEventsInWaitList = "; + + os << *(params->pnumEventsInWaitList); + + os << ", "; + os << ".phEventWaitList = {"; + for (size_t i = 0; *(params->pphEventWaitList) != NULL && + i < *params->pnumEventsInWaitList; + ++i) { + if (i != 0) { + os << ", "; + } + + ur_params::serializePtr(os, (*(params->pphEventWaitList))[i]); + } + os << "}"; + + os << ", "; + os << ".phEvent = "; + + ur_params::serializePtr(os, *(params->pphEvent)); + + return os; +} + inline std::ostream & operator<<(std::ostream &os, const struct ur_event_get_info_params_t *params) { @@ -13499,6 +13566,23 @@ inline std::ostream &operator<<( return os; } +inline std::ostream &operator<<( + std::ostream &os, + const struct ur_kernel_suggest_max_cooperative_group_count_exp_params_t + *params) { + + os << ".hKernel = "; + + ur_params::serializePtr(os, *(params->phKernel)); + + os << ", "; + os << ".pGroupCountRet = "; + + ur_params::serializePtr(os, *(params->ppGroupCountRet)); + + return os; +} + inline std::ostream &operator<<(std::ostream &os, const struct ur_loader_init_params_t *params) { @@ -15718,6 +15802,10 @@ inline int serializeFunctionParams(std::ostream &os, uint32_t function, case UR_FUNCTION_ENQUEUE_WRITE_HOST_PIPE: { os << (const struct ur_enqueue_write_host_pipe_params_t *)params; } break; + case UR_FUNCTION_ENQUEUE_COOPERATIVE_KERNEL_LAUNCH_EXP: { + os << (const struct ur_enqueue_cooperative_kernel_launch_exp_params_t *) + params; + } break; case UR_FUNCTION_EVENT_GET_INFO: { os << (const struct ur_event_get_info_params_t *)params; } break; @@ -15790,6 +15878,11 @@ inline int serializeFunctionParams(std::ostream &os, uint32_t function, os << (const struct ur_kernel_set_specialization_constants_params_t *) params; } break; + case UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP: { + os << (const struct + ur_kernel_suggest_max_cooperative_group_count_exp_params_t *) + params; + } break; case UR_FUNCTION_LOADER_INIT: { os << (const struct ur_loader_init_params_t *)params; } break; diff --git a/source/loader/layers/tracing/ur_trcddi.cpp b/source/loader/layers/tracing/ur_trcddi.cpp index 64702ebed5..6e4ce75bb8 100644 --- a/source/loader/layers/tracing/ur_trcddi.cpp +++ b/source/loader/layers/tracing/ur_trcddi.cpp @@ -5695,6 +5695,99 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferEnqueueExp( return result; } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueCooperativeKernelLaunchExp +__urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t + workDim, ///< [in] number of dimensions, from 1 to 3, to specify the global and + ///< work-group work-items + const size_t * + pGlobalWorkOffset, ///< [in] pointer to an array of workDim unsigned values that specify the + ///< offset used to calculate the global ID of a work-item + const size_t * + pGlobalWorkSize, ///< [in] pointer to an array of workDim unsigned values that specify the + ///< number of global work-items in workDim that will execute the kernel + ///< function + const size_t * + pLocalWorkSize, ///< [in][optional] pointer to an array of workDim unsigned values that + ///< specify the number of local work-items forming a work-group that will + ///< execute the kernel function. + ///< If nullptr, the runtime implementation will choose the work-group + ///< size. + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before the kernel execution. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that no wait + ///< event. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< kernel execution instance. +) { + auto pfnCooperativeKernelLaunchExp = + context.urDdiTable.EnqueueExp.pfnCooperativeKernelLaunchExp; + + if (nullptr == pfnCooperativeKernelLaunchExp) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + ur_enqueue_cooperative_kernel_launch_exp_params_t params = { + &hQueue, + &hKernel, + &workDim, + &pGlobalWorkOffset, + &pGlobalWorkSize, + &pLocalWorkSize, + &numEventsInWaitList, + &phEventWaitList, + &phEvent}; + uint64_t instance = + context.notify_begin(UR_FUNCTION_ENQUEUE_COOPERATIVE_KERNEL_LAUNCH_EXP, + "urEnqueueCooperativeKernelLaunchExp", ¶ms); + + ur_result_t result = pfnCooperativeKernelLaunchExp( + hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize, + pLocalWorkSize, numEventsInWaitList, phEventWaitList, phEvent); + + context.notify_end(UR_FUNCTION_ENQUEUE_COOPERATIVE_KERNEL_LAUNCH_EXP, + "urEnqueueCooperativeKernelLaunchExp", ¶ms, &result, + instance); + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCountExp +__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp( + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups +) { + auto pfnSuggestMaxCooperativeGroupCountExp = + context.urDdiTable.KernelExp.pfnSuggestMaxCooperativeGroupCountExp; + + if (nullptr == pfnSuggestMaxCooperativeGroupCountExp) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + ur_kernel_suggest_max_cooperative_group_count_exp_params_t params = { + &hKernel, &pGroupCountRet}; + uint64_t instance = context.notify_begin( + UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP, + "urKernelSuggestMaxCooperativeGroupCountExp", ¶ms); + + ur_result_t result = + pfnSuggestMaxCooperativeGroupCountExp(hKernel, pGroupCountRet); + + context.notify_end( + UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP, + "urKernelSuggestMaxCooperativeGroupCountExp", ¶ms, &result, + instance); + + return result; +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urUSMImportExp __urdlllocal ur_result_t UR_APICALL urUSMImportExp( @@ -6247,6 +6340,41 @@ __urdlllocal ur_result_t UR_APICALL urGetEnqueueProcAddrTable( return result; } /////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's EnqueueExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +__urdlllocal ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_enqueue_exp_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_tracing_layer::context.urDdiTable.EnqueueExp; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_tracing_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_tracing_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnCooperativeKernelLaunchExp = + pDdiTable->pfnCooperativeKernelLaunchExp; + pDdiTable->pfnCooperativeKernelLaunchExp = + ur_tracing_layer::urEnqueueCooperativeKernelLaunchExp; + + return result; +} +/////////////////////////////////////////////////////////////////////////////// /// @brief Exported function for filling application's Event table /// with current process' addresses /// @@ -6380,6 +6508,41 @@ __urdlllocal ur_result_t UR_APICALL urGetKernelProcAddrTable( return result; } /////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's KernelExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +__urdlllocal ur_result_t UR_APICALL urGetKernelExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_kernel_exp_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_tracing_layer::context.urDdiTable.KernelExp; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_tracing_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_tracing_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnSuggestMaxCooperativeGroupCountExp = + pDdiTable->pfnSuggestMaxCooperativeGroupCountExp; + pDdiTable->pfnSuggestMaxCooperativeGroupCountExp = + ur_tracing_layer::urKernelSuggestMaxCooperativeGroupCountExp; + + return result; +} +/////////////////////////////////////////////////////////////////////////////// /// @brief Exported function for filling application's Mem table /// with current process' addresses /// @@ -6993,6 +7156,11 @@ ur_result_t context_t::init(ur_dditable_t *dditable, UR_API_VERSION_CURRENT, &dditable->Enqueue); } + if (UR_RESULT_SUCCESS == result) { + result = ur_tracing_layer::urGetEnqueueExpProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->EnqueueExp); + } + if (UR_RESULT_SUCCESS == result) { result = ur_tracing_layer::urGetEventProcAddrTable( UR_API_VERSION_CURRENT, &dditable->Event); @@ -7003,6 +7171,11 @@ ur_result_t context_t::init(ur_dditable_t *dditable, UR_API_VERSION_CURRENT, &dditable->Kernel); } + if (UR_RESULT_SUCCESS == result) { + result = ur_tracing_layer::urGetKernelExpProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->KernelExp); + } + if (UR_RESULT_SUCCESS == result) { result = ur_tracing_layer::urGetMemProcAddrTable(UR_API_VERSION_CURRENT, &dditable->Mem); diff --git a/source/loader/layers/validation/ur_valddi.cpp b/source/loader/layers/validation/ur_valddi.cpp index fe689c8537..e170ba04b5 100644 --- a/source/loader/layers/validation/ur_valddi.cpp +++ b/source/loader/layers/validation/ur_valddi.cpp @@ -7142,6 +7142,106 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferEnqueueExp( return result; } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueCooperativeKernelLaunchExp +__urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t + workDim, ///< [in] number of dimensions, from 1 to 3, to specify the global and + ///< work-group work-items + const size_t * + pGlobalWorkOffset, ///< [in] pointer to an array of workDim unsigned values that specify the + ///< offset used to calculate the global ID of a work-item + const size_t * + pGlobalWorkSize, ///< [in] pointer to an array of workDim unsigned values that specify the + ///< number of global work-items in workDim that will execute the kernel + ///< function + const size_t * + pLocalWorkSize, ///< [in][optional] pointer to an array of workDim unsigned values that + ///< specify the number of local work-items forming a work-group that will + ///< execute the kernel function. + ///< If nullptr, the runtime implementation will choose the work-group + ///< size. + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before the kernel execution. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that no wait + ///< event. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< kernel execution instance. +) { + auto pfnCooperativeKernelLaunchExp = + context.urDdiTable.EnqueueExp.pfnCooperativeKernelLaunchExp; + + if (nullptr == pfnCooperativeKernelLaunchExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + if (context.enableParameterValidation) { + if (NULL == hQueue) { + return UR_RESULT_ERROR_INVALID_NULL_HANDLE; + } + + if (NULL == hKernel) { + return UR_RESULT_ERROR_INVALID_NULL_HANDLE; + } + + if (NULL == pGlobalWorkOffset) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (NULL == pGlobalWorkSize) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (phEventWaitList == NULL && numEventsInWaitList > 0) { + return UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST; + } + + if (phEventWaitList != NULL && numEventsInWaitList == 0) { + return UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST; + } + } + + ur_result_t result = pfnCooperativeKernelLaunchExp( + hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize, + pLocalWorkSize, numEventsInWaitList, phEventWaitList, phEvent); + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCountExp +__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp( + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups +) { + auto pfnSuggestMaxCooperativeGroupCountExp = + context.urDdiTable.KernelExp.pfnSuggestMaxCooperativeGroupCountExp; + + if (nullptr == pfnSuggestMaxCooperativeGroupCountExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + if (context.enableParameterValidation) { + if (NULL == hKernel) { + return UR_RESULT_ERROR_INVALID_NULL_HANDLE; + } + + if (NULL == pGroupCountRet) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + } + + ur_result_t result = + pfnSuggestMaxCooperativeGroupCountExp(hKernel, pGroupCountRet); + + return result; +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urUSMImportExp __urdlllocal ur_result_t UR_APICALL urUSMImportExp( @@ -7727,6 +7827,42 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable( return result; } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's EnqueueExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_enqueue_exp_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_validation_layer::context.urDdiTable.EnqueueExp; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_validation_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_validation_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnCooperativeKernelLaunchExp = + pDdiTable->pfnCooperativeKernelLaunchExp; + pDdiTable->pfnCooperativeKernelLaunchExp = + ur_validation_layer::urEnqueueCooperativeKernelLaunchExp; + + return result; +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Exported function for filling application's Event table /// with current process' addresses @@ -7865,6 +8001,42 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelProcAddrTable( return result; } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's KernelExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_kernel_exp_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_validation_layer::context.urDdiTable.KernelExp; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_validation_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_validation_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnSuggestMaxCooperativeGroupCountExp = + pDdiTable->pfnSuggestMaxCooperativeGroupCountExp; + pDdiTable->pfnSuggestMaxCooperativeGroupCountExp = + ur_validation_layer::urKernelSuggestMaxCooperativeGroupCountExp; + + return result; +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Exported function for filling application's Mem table /// with current process' addresses @@ -8505,6 +8677,11 @@ ur_result_t context_t::init(ur_dditable_t *dditable, UR_API_VERSION_CURRENT, &dditable->Enqueue); } + if (UR_RESULT_SUCCESS == result) { + result = ur_validation_layer::urGetEnqueueExpProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->EnqueueExp); + } + if (UR_RESULT_SUCCESS == result) { result = ur_validation_layer::urGetEventProcAddrTable( UR_API_VERSION_CURRENT, &dditable->Event); @@ -8515,6 +8692,11 @@ ur_result_t context_t::init(ur_dditable_t *dditable, UR_API_VERSION_CURRENT, &dditable->Kernel); } + if (UR_RESULT_SUCCESS == result) { + result = ur_validation_layer::urGetKernelExpProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->KernelExp); + } + if (UR_RESULT_SUCCESS == result) { result = ur_validation_layer::urGetMemProcAddrTable( UR_API_VERSION_CURRENT, &dditable->Mem); diff --git a/source/loader/ur_ldrddi.cpp b/source/loader/ur_ldrddi.cpp index 7094cb0304..528b2d1eba 100644 --- a/source/loader/ur_ldrddi.cpp +++ b/source/loader/ur_ldrddi.cpp @@ -6846,6 +6846,109 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferEnqueueExp( return result; } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueCooperativeKernelLaunchExp +__urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t + workDim, ///< [in] number of dimensions, from 1 to 3, to specify the global and + ///< work-group work-items + const size_t * + pGlobalWorkOffset, ///< [in] pointer to an array of workDim unsigned values that specify the + ///< offset used to calculate the global ID of a work-item + const size_t * + pGlobalWorkSize, ///< [in] pointer to an array of workDim unsigned values that specify the + ///< number of global work-items in workDim that will execute the kernel + ///< function + const size_t * + pLocalWorkSize, ///< [in][optional] pointer to an array of workDim unsigned values that + ///< specify the number of local work-items forming a work-group that will + ///< execute the kernel function. + ///< If nullptr, the runtime implementation will choose the work-group + ///< size. + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before the kernel execution. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that no wait + ///< event. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< kernel execution instance. +) { + ur_result_t result = UR_RESULT_SUCCESS; + + // extract platform's function pointer table + auto dditable = reinterpret_cast(hQueue)->dditable; + auto pfnCooperativeKernelLaunchExp = + dditable->ur.EnqueueExp.pfnCooperativeKernelLaunchExp; + if (nullptr == pfnCooperativeKernelLaunchExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + // convert loader handle to platform handle + hQueue = reinterpret_cast(hQueue)->handle; + + // convert loader handle to platform handle + hKernel = reinterpret_cast(hKernel)->handle; + + // convert loader handles to platform handles + auto phEventWaitListLocal = + std::vector(numEventsInWaitList); + for (size_t i = 0; i < numEventsInWaitList; ++i) { + phEventWaitListLocal[i] = + reinterpret_cast(phEventWaitList[i])->handle; + } + + // forward to device-platform + result = pfnCooperativeKernelLaunchExp( + hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize, + pLocalWorkSize, numEventsInWaitList, phEventWaitListLocal.data(), + phEvent); + + if (UR_RESULT_SUCCESS != result) { + return result; + } + + try { + // convert platform handle to loader handle + if (nullptr != phEvent) { + *phEvent = reinterpret_cast( + ur_event_factory.getInstance(*phEvent, dditable)); + } + } catch (std::bad_alloc &) { + result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCountExp +__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp( + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups +) { + ur_result_t result = UR_RESULT_SUCCESS; + + // extract platform's function pointer table + auto dditable = reinterpret_cast(hKernel)->dditable; + auto pfnSuggestMaxCooperativeGroupCountExp = + dditable->ur.KernelExp.pfnSuggestMaxCooperativeGroupCountExp; + if (nullptr == pfnSuggestMaxCooperativeGroupCountExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + // convert loader handle to platform handle + hKernel = reinterpret_cast(hKernel)->handle; + + // forward to device-platform + result = pfnSuggestMaxCooperativeGroupCountExp(hKernel, pGroupCountRet); + + return result; +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urUSMImportExp __urdlllocal ur_result_t UR_APICALL urUSMImportExp( @@ -7376,6 +7479,61 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable( return result; } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's EnqueueExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_enqueue_exp_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (ur_loader::context->version < version) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + // Load the device-platform DDI tables + for (auto &platform : ur_loader::context->platforms) { + if (platform.initStatus != UR_RESULT_SUCCESS) { + continue; + } + auto getTable = reinterpret_cast( + ur_loader::LibLoader::getFunctionPtr( + platform.handle.get(), "urGetEnqueueExpProcAddrTable")); + if (!getTable) { + continue; + } + platform.initStatus = + getTable(version, &platform.dditable.ur.EnqueueExp); + } + + if (UR_RESULT_SUCCESS == result) { + if (ur_loader::context->platforms.size() != 1 || + ur_loader::context->forceIntercept) { + // return pointers to loader's DDIs + pDdiTable->pfnCooperativeKernelLaunchExp = + ur_loader::urEnqueueCooperativeKernelLaunchExp; + } else { + // return pointers directly to platform's DDIs + *pDdiTable = + ur_loader::context->platforms.front().dditable.ur.EnqueueExp; + } + } + + return result; +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Exported function for filling application's Event table /// with current process' addresses @@ -7506,6 +7664,61 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelProcAddrTable( return result; } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's KernelExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_kernel_exp_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (ur_loader::context->version < version) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + // Load the device-platform DDI tables + for (auto &platform : ur_loader::context->platforms) { + if (platform.initStatus != UR_RESULT_SUCCESS) { + continue; + } + auto getTable = reinterpret_cast( + ur_loader::LibLoader::getFunctionPtr( + platform.handle.get(), "urGetKernelExpProcAddrTable")); + if (!getTable) { + continue; + } + platform.initStatus = + getTable(version, &platform.dditable.ur.KernelExp); + } + + if (UR_RESULT_SUCCESS == result) { + if (ur_loader::context->platforms.size() != 1 || + ur_loader::context->forceIntercept) { + // return pointers to loader's DDIs + pDdiTable->pfnSuggestMaxCooperativeGroupCountExp = + ur_loader::urKernelSuggestMaxCooperativeGroupCountExp; + } else { + // return pointers directly to platform's DDIs + *pDdiTable = + ur_loader::context->platforms.front().dditable.ur.KernelExp; + } + } + + return result; +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Exported function for filling application's Mem table /// with current process' addresses diff --git a/source/loader/ur_libapi.cpp b/source/loader/ur_libapi.cpp index dced46e279..c834293bb5 100644 --- a/source/loader/ur_libapi.cpp +++ b/source/loader/ur_libapi.cpp @@ -7654,6 +7654,103 @@ ur_result_t UR_APICALL urCommandBufferEnqueueExp( return exceptionToResult(std::current_exception()); } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Enqueue a command to execute a cooperative kernel +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_DEVICE_LOST +/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC +/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE +/// + `NULL == hQueue` +/// + `NULL == hKernel` +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// + `NULL == pGlobalWorkOffset` +/// + `NULL == pGlobalWorkSize` +/// - ::UR_RESULT_ERROR_INVALID_QUEUE +/// - ::UR_RESULT_ERROR_INVALID_KERNEL +/// - ::UR_RESULT_ERROR_INVALID_EVENT +/// - ::UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST +/// + `phEventWaitList == NULL && numEventsInWaitList > 0` +/// + `phEventWaitList != NULL && numEventsInWaitList == 0` +/// + If event objects in phEventWaitList are not valid events. +/// - ::UR_RESULT_ERROR_INVALID_WORK_DIMENSION +/// - ::UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE +/// - ::UR_RESULT_ERROR_INVALID_VALUE +/// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY +/// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES +ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t + workDim, ///< [in] number of dimensions, from 1 to 3, to specify the global and + ///< work-group work-items + const size_t * + pGlobalWorkOffset, ///< [in] pointer to an array of workDim unsigned values that specify the + ///< offset used to calculate the global ID of a work-item + const size_t * + pGlobalWorkSize, ///< [in] pointer to an array of workDim unsigned values that specify the + ///< number of global work-items in workDim that will execute the kernel + ///< function + const size_t * + pLocalWorkSize, ///< [in][optional] pointer to an array of workDim unsigned values that + ///< specify the number of local work-items forming a work-group that will + ///< execute the kernel function. + ///< If nullptr, the runtime implementation will choose the work-group + ///< size. + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before the kernel execution. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that no wait + ///< event. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< kernel execution instance. + ) try { + auto pfnCooperativeKernelLaunchExp = + ur_lib::context->urDdiTable.EnqueueExp.pfnCooperativeKernelLaunchExp; + if (nullptr == pfnCooperativeKernelLaunchExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + return pfnCooperativeKernelLaunchExp( + hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize, + pLocalWorkSize, numEventsInWaitList, phEventWaitList, phEvent); +} catch (...) { + return exceptionToResult(std::current_exception()); +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Query the maximum number of work groups for a cooperative kernel +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_DEVICE_LOST +/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC +/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE +/// + `NULL == hKernel` +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// + `NULL == pGroupCountRet` +/// - ::UR_RESULT_ERROR_INVALID_KERNEL +ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp( + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups + ) try { + auto pfnSuggestMaxCooperativeGroupCountExp = + ur_lib::context->urDdiTable.KernelExp + .pfnSuggestMaxCooperativeGroupCountExp; + if (nullptr == pfnSuggestMaxCooperativeGroupCountExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + return pfnSuggestMaxCooperativeGroupCountExp(hKernel, pGroupCountRet); +} catch (...) { + return exceptionToResult(std::current_exception()); +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Import memory into USM /// diff --git a/source/loader/ur_libddi.cpp b/source/loader/ur_libddi.cpp index 912449b54e..9d0a2566f4 100644 --- a/source/loader/ur_libddi.cpp +++ b/source/loader/ur_libddi.cpp @@ -45,6 +45,11 @@ __urdlllocal ur_result_t context_t::urLoaderInit() { &urDdiTable.Enqueue); } + if (UR_RESULT_SUCCESS == result) { + result = urGetEnqueueExpProcAddrTable(UR_API_VERSION_CURRENT, + &urDdiTable.EnqueueExp); + } + if (UR_RESULT_SUCCESS == result) { result = urGetEventProcAddrTable(UR_API_VERSION_CURRENT, &urDdiTable.Event); @@ -55,6 +60,11 @@ __urdlllocal ur_result_t context_t::urLoaderInit() { &urDdiTable.Kernel); } + if (UR_RESULT_SUCCESS == result) { + result = urGetKernelExpProcAddrTable(UR_API_VERSION_CURRENT, + &urDdiTable.KernelExp); + } + if (UR_RESULT_SUCCESS == result) { result = urGetMemProcAddrTable(UR_API_VERSION_CURRENT, &urDdiTable.Mem); } diff --git a/source/ur_api.cpp b/source/ur_api.cpp index 4bc1cc2889..9a66ab92e3 100644 --- a/source/ur_api.cpp +++ b/source/ur_api.cpp @@ -6459,6 +6459,86 @@ ur_result_t UR_APICALL urCommandBufferEnqueueExp( return result; } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Enqueue a command to execute a cooperative kernel +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_DEVICE_LOST +/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC +/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE +/// + `NULL == hQueue` +/// + `NULL == hKernel` +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// + `NULL == pGlobalWorkOffset` +/// + `NULL == pGlobalWorkSize` +/// - ::UR_RESULT_ERROR_INVALID_QUEUE +/// - ::UR_RESULT_ERROR_INVALID_KERNEL +/// - ::UR_RESULT_ERROR_INVALID_EVENT +/// - ::UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST +/// + `phEventWaitList == NULL && numEventsInWaitList > 0` +/// + `phEventWaitList != NULL && numEventsInWaitList == 0` +/// + If event objects in phEventWaitList are not valid events. +/// - ::UR_RESULT_ERROR_INVALID_WORK_DIMENSION +/// - ::UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE +/// - ::UR_RESULT_ERROR_INVALID_VALUE +/// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY +/// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES +ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t + workDim, ///< [in] number of dimensions, from 1 to 3, to specify the global and + ///< work-group work-items + const size_t * + pGlobalWorkOffset, ///< [in] pointer to an array of workDim unsigned values that specify the + ///< offset used to calculate the global ID of a work-item + const size_t * + pGlobalWorkSize, ///< [in] pointer to an array of workDim unsigned values that specify the + ///< number of global work-items in workDim that will execute the kernel + ///< function + const size_t * + pLocalWorkSize, ///< [in][optional] pointer to an array of workDim unsigned values that + ///< specify the number of local work-items forming a work-group that will + ///< execute the kernel function. + ///< If nullptr, the runtime implementation will choose the work-group + ///< size. + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before the kernel execution. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that no wait + ///< event. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< kernel execution instance. +) { + ur_result_t result = UR_RESULT_SUCCESS; + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Query the maximum number of work groups for a cooperative kernel +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_DEVICE_LOST +/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC +/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE +/// + `NULL == hKernel` +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// + `NULL == pGroupCountRet` +/// - ::UR_RESULT_ERROR_INVALID_KERNEL +ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp( + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups +) { + ur_result_t result = UR_RESULT_SUCCESS; + return result; +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Import memory into USM ///