Skip to content

Commit

Permalink
Merge pull request #1585 from IntelPython/fix/kernel_arg_type
Browse files Browse the repository at this point in the history
Start kernel_arg_type enums from 0
  • Loading branch information
Diptorup Deb authored Mar 8, 2024
2 parents 545dff2 + 5ef035e commit 13f4443
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 19 deletions.
149 changes: 149 additions & 0 deletions dpctl/_sycl_queue.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,155 @@ __all__ = [
_logger = logging.getLogger(__name__)


cdef class kernel_arg_type_attribute:
cdef str parent_name
cdef str attr_name
cdef int attr_value

def __cinit__(self, str parent, str name, int value):
self.parent_name = parent
self.attr_name = name
self.attr_value = value

def __repr__(self):
return f"<{self.parent_name}.{self.attr_name}: {self.attr_value}>"

def __str__(self):
return f"<{self.parent_name}.{self.attr_name}: {self.attr_value}>"

@property
def name(self):
return self.attr_name

@property
def value(self):
return self.attr_value


cdef class _kernel_arg_type:
"""
An enumeration of supported kernel argument types in
:func:`dpctl.SyclQueue.submit`
"""
cdef str _name

def __cinit__(self):
self._name = "kernel_arg_type"


@property
def __name__(self):
return self._name

def __repr__(self):
return "<enum 'kernel_arg_type'>"

def __str__(self):
return "<enum 'kernel_arg_type'>"

@property
def dpctl_int8(self):
cdef str p_name = "dpctl_int8"
return kernel_arg_type_attribute(
self._name,
p_name,
_arg_data_type._INT8_T
)

@property
def dpctl_uint8(self):
cdef str p_name = "dpctl_uint8"
return kernel_arg_type_attribute(
self._name,
p_name,
_arg_data_type._UINT8_T
)

@property
def dpctl_int16(self):
cdef str p_name = "dpctl_int16"
return kernel_arg_type_attribute(
self._name,
p_name,
_arg_data_type._INT16_T
)

@property
def dpctl_uint16(self):
cdef str p_name = "dpctl_uint16"
return kernel_arg_type_attribute(
self._name,
p_name,
_arg_data_type._UINT16_T
)

@property
def dpctl_int32(self):
cdef str p_name = "dpctl_int32"
return kernel_arg_type_attribute(
self._name,
p_name,
_arg_data_type._INT32_T
)

@property
def dpctl_uint32(self):
cdef str p_name = "dpctl_uint32"
return kernel_arg_type_attribute(
self._name,
p_name,
_arg_data_type._UINT32_T
)

@property
def dpctl_int64(self):
cdef str p_name = "dpctl_int64"
return kernel_arg_type_attribute(
self._name,
p_name,
_arg_data_type._INT64_T
)

@property
def dpctl_uint64(self):
cdef str p_name = "dpctl_uint64"
return kernel_arg_type_attribute(
self._name,
p_name,
_arg_data_type._UINT64_T
)

@property
def dpctl_float32(self):
cdef str p_name = "dpctl_float32"
return kernel_arg_type_attribute(
self._name,
p_name,
_arg_data_type._FLOAT
)

@property
def dpctl_float64(self):
cdef str p_name = "dpctl_float64"
return kernel_arg_type_attribute(
self._name,
p_name,
_arg_data_type._DOUBLE
)

@property
def dpctl_void_ptr(self):
cdef str p_name = "dpctl_void_ptr"
return kernel_arg_type_attribute(
self._name,
p_name,
_arg_data_type._VOID_PTR
)


kernel_arg_type = _kernel_arg_type()


cdef class SyclKernelSubmitError(Exception):
"""
A SyclKernelSubmitError exception is raised when the provided
Expand Down
19 changes: 0 additions & 19 deletions dpctl/enum_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,22 +113,3 @@ class global_mem_cache_type(Enum):
none = auto()
read_only = auto()
read_write = auto()


class kernel_arg_type(Enum):
"""
An enumeration of supported kernel argument types in
:func:`dpctl.SyclQueue.submit`
"""

dpctl_int8 = auto()
dpctl_uint8 = auto()
dpctl_int16 = auto()
dpctl_uint16 = auto()
dpctl_int32 = auto()
dpctl_uint32 = auto()
dpctl_int64 = auto()
dpctl_uint64 = auto()
dpctl_float32 = auto()
dpctl_float64 = auto()
dpctl_void_ptr = auto()
30 changes: 30 additions & 0 deletions dpctl/tests/test_sycl_kernel_submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import dpctl.memory as dpctl_mem
import dpctl.program as dpctl_prog
import dpctl.tensor as dpt
from dpctl._sycl_queue import kernel_arg_type


@pytest.mark.parametrize(
Expand Down Expand Up @@ -244,3 +245,32 @@ def test_submit_async():
Xref[2, i] = min(Xref[0, i], Xref[1, i])

assert np.array_equal(Xnp[:, :n], Xref[:, :n])


def _check_kernel_arg_type_instance(kati):
assert isinstance(kati.name, str)
assert isinstance(kati.value, int)
assert isinstance(repr(kati), str)
assert isinstance(str(kati), str)


def test_kernel_arg_type():
"""
Check that enum values for kernel_arg_type start at 0,
as numba_dpex expects. The next enumerated type must
have next value.
"""
assert isinstance(kernel_arg_type.__name__, str)
assert isinstance(repr(kernel_arg_type), str)
assert isinstance(str(kernel_arg_type), str)
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_int8)
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_uint8)
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_int16)
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_uint16)
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_int32)
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_uint32)
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_int64)
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_uint64)
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_float32)
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_float64)
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_void_ptr)

0 comments on commit 13f4443

Please sign in to comment.