Skip to content

Commit

Permalink
Support local_accessor kernel arguments.
Browse files Browse the repository at this point in the history
    - Adds support in libsyclinterface:: dpctl_sycl_queue_interface for
      sycl::local_accessor as kernel arguments.
    - Refactoring to get rid of compiler warnings.
  • Loading branch information
Diptorup Deb committed Mar 6, 2024
1 parent 56374ee commit e6f4f55
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 8 deletions.
2 changes: 2 additions & 0 deletions dpctl/enum_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"device_type",
"backend_type",
"event_status_type",
"kernel_arg_type",
]


Expand Down Expand Up @@ -132,3 +133,4 @@ class kernel_arg_type(Enum):
dpctl_float32 = auto()
dpctl_float64 = auto()
dpctl_void_ptr = auto()
dpctl_local_accessor = auto()
2 changes: 1 addition & 1 deletion libsyclinterface/helper/include/dpctl_error_handlers.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
///
/// \file
/// A functor to use for passing an error handler callback function to sycl
/// context and queue contructors.
/// context and queue constructors.
//===----------------------------------------------------------------------===//

#pragma once
Expand Down
1 change: 1 addition & 0 deletions libsyclinterface/include/dpctl_sycl_enum_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ typedef enum
DPCTL_FLOAT32_T,
DPCTL_FLOAT64_T,
DPCTL_VOID_PTR,
DPCTL_LOCAL_ACCESSOR,
DPCTL_UNSUPPORTED_KERNEL_ARG
} DPCTLKernelArgType;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ _GetKernel_ze_impl(const kernel_bundle<bundle_state::executable> &kb,
else {
error_handler("Kernel named " + std::string(kernel_name) +
" could not be found.",
__FILE__, __func__, __LINE__);
__FILE__, __func__, __LINE__, error_level::error);
return nullptr;
}
}
Expand All @@ -541,7 +541,7 @@ bool _HasKernel_ze_impl(const kernel_bundle<bundle_state::executable> &kb,
auto zeKernelCreateFn = get_zeKernelCreate();
if (zeKernelCreateFn == nullptr) {
error_handler("Could not load zeKernelCreate function.", __FILE__,
__func__, __LINE__);
__func__, __LINE__, error_level::error);
return false;
}

Expand All @@ -564,7 +564,7 @@ bool _HasKernel_ze_impl(const kernel_bundle<bundle_state::executable> &kb,
if (ze_status != ZE_RESULT_ERROR_INVALID_KERNEL_NAME) {
error_handler("zeKernelCreate failed: " +
_GetErrorCode_ze_impl(ze_status),
__FILE__, __func__, __LINE__);
__FILE__, __func__, __LINE__, error_level::error);
return false;
}
}
Expand Down
116 changes: 112 additions & 4 deletions libsyclinterface/source/dpctl_sycl_queue_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,76 @@

using namespace sycl;

#define SET_LOCAL_ACCESSOR_ARG(CGH, NDIM, ARGTY, R, IDX) \
do { \
switch ((ARGTY)) { \
case DPCTL_INT8_T: \
{ \
auto la = local_accessor<int8_t, NDIM>(R, CGH); \
CGH.set_arg(IDX, la); \
return true; \
} \
case DPCTL_UINT8_T: \
{ \
auto la = local_accessor<uint8_t, NDIM>(R, CGH); \
CGH.set_arg(IDX, la); \
return true; \
} \
case DPCTL_INT16_T: \
{ \
auto la = local_accessor<int16_t, NDIM>(R, CGH); \
CGH.set_arg(IDX, la); \
return true; \
} \
case DPCTL_UINT16_T: \
{ \
auto la = local_accessor<uint16_t, NDIM>(R, CGH); \
CGH.set_arg(IDX, la); \
return true; \
} \
case DPCTL_INT32_T: \
{ \
auto la = local_accessor<int32_t, NDIM>(R, CGH); \
CGH.set_arg(IDX, la); \
return true; \
} \
case DPCTL_UINT32_T: \
{ \
auto la = local_accessor<uint32_t, NDIM>(R, CGH); \
CGH.set_arg(IDX, la); \
return true; \
} \
case DPCTL_INT64_T: \
{ \
auto la = local_accessor<int64_t, NDIM>(R, CGH); \
CGH.set_arg(IDX, la); \
return true; \
} \
case DPCTL_UINT64_T: \
{ \
auto la = local_accessor<uint64_t, NDIM>(R, CGH); \
CGH.set_arg(IDX, la); \
return true; \
} \
case DPCTL_FLOAT32_T: \
{ \
auto la = local_accessor<float, NDIM>(R, CGH); \
CGH.set_arg(IDX, la); \
return true; \
} \
case DPCTL_FLOAT64_T: \
{ \
auto la = local_accessor<double, NDIM>(R, CGH); \
CGH.set_arg(IDX, la); \
return true; \
} \
default: \
error_handler("Kernel argument could not be created.", __FILE__, \
__func__, __LINE__, error_level::error); \
return false; \
} \
} while (0);

namespace
{
static_assert(__SYCL_COMPILER_VERSION >= __SYCL_COMPILER_VERSION_REQUIRED,
Expand All @@ -51,6 +121,15 @@ typedef struct complex
uint64_t imag;
} complexNumber;

typedef struct MDLocalAccessorTy
{
size_t ndim;
DPCTLKernelArgType dpctl_type_id;
size_t dim0;
size_t dim1;
size_t dim2;
} MDLocalAccessor;

void set_dependent_events(handler &cgh,
__dpctl_keep const DPCTLSyclEventRef *DepEvents,
size_t NDepEvents)
Expand All @@ -62,11 +141,39 @@ void set_dependent_events(handler &cgh,
}
}

bool set_local_accessor_arg(handler &cgh,
size_t idx,
const MDLocalAccessor *mdstruct)
{
switch (mdstruct->ndim) {
case 1:
{
auto r = range<1>(mdstruct->dim0);
SET_LOCAL_ACCESSOR_ARG(cgh, 1, mdstruct->dpctl_type_id, r, idx);
}
case 2:
{
auto r = range<2>(mdstruct->dim0, mdstruct->dim1);
SET_LOCAL_ACCESSOR_ARG(cgh, 2, mdstruct->dpctl_type_id, r, idx);
}
case 3:
{
auto r = range<3>(mdstruct->dim0, mdstruct->dim1, mdstruct->dim2);
SET_LOCAL_ACCESSOR_ARG(cgh, 3, mdstruct->dpctl_type_id, r, idx);
}
default:
return false;
}
}
/*!
* @brief Set the kernel arg object
*
* @param cgh My Param doc
* @param Arg My Param doc
* @param cgh SYCL command group handler using which a kernel is going to
* be submitted.
* @param idx The position of the argument in the list of arguments passed
* to a kernel.
* @param Arg A void* representing a kernel argument.
* @param Argty A typeid specifying the C++ type of the Arg parameter.
*/
bool set_kernel_arg(handler &cgh,
size_t idx,
Expand Down Expand Up @@ -109,10 +216,11 @@ bool set_kernel_arg(handler &cgh,
case DPCTL_VOID_PTR:
cgh.set_arg(idx, Arg);
break;
case DPCTL_LOCAL_ACCESSOR:
arg_set = set_local_accessor_arg(cgh, idx, (MDLocalAccessor *)Arg);
break;
default:
arg_set = false;
error_handler("Kernel argument could not be created.", __FILE__,
__func__, __LINE__);
break;
}
return arg_set;
Expand Down

0 comments on commit e6f4f55

Please sign in to comment.