Skip to content

Commit

Permalink
Use enum values for kernel arg types if dpctl >= 0.17
Browse files Browse the repository at this point in the history
  • Loading branch information
Diptorup Deb committed Mar 11, 2024
1 parent e3854ff commit e9240bd
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions numba_dpex/dpctl_iface/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,26 @@ def numba_type_to_dpctl_typenum(context, ty):
"""

if dpctl_sem_version >= (0, 17, 0):
# FIXME change to imports from a dpctl enum/class rather than
# hard coding these numbers.
from dpctl._sycl_queue import kernel_arg_type as kargty

if ty == types.boolean:
return context.get_constant(types.int32, 1)
return context.get_constant(types.int32, kargty.dpctl_uint8.value)
elif ty == types.int32 or isinstance(ty, types.scalars.IntegerLiteral):
return context.get_constant(types.int32, 4)
return context.get_constant(types.int32, kargty.dpctl_int32.value)
elif ty == types.uint32:
return context.get_constant(types.int32, 5)
return context.get_constant(types.int32, kargty.dpctl_uint32.value)
elif ty == types.int64:
return context.get_constant(types.int32, 6)
return context.get_constant(types.int32, kargty.dpctl_int64.value)
elif ty == types.uint64:
return context.get_constant(types.int32, 7)
return context.get_constant(types.int32, kargty.dpctl_uint64.value)
elif ty == types.float32:
return context.get_constant(types.int32, 8)
return context.get_constant(types.int32, kargty.dpctl_float32.value)
elif ty == types.float64:
return context.get_constant(types.int32, 9)
return context.get_constant(types.int32, kargty.dpctl_float64.value)
elif ty == types.voidptr or isinstance(ty, types.CPointer):
return context.get_constant(types.int32, 10)
return context.get_constant(
types.int32, kargty.dpctl_void_ptr.value
)
else:
raise NotImplementedError
else:
Expand Down

0 comments on commit e9240bd

Please sign in to comment.