diff --git a/numba_dpex/dpctl_iface/_helpers.py b/numba_dpex/dpctl_iface/_helpers.py index be8354b459..f46915eaf0 100644 --- a/numba_dpex/dpctl_iface/_helpers.py +++ b/numba_dpex/dpctl_iface/_helpers.py @@ -14,25 +14,26 @@ def numba_type_to_dpctl_typenum(context, ty): """ if dpctl_sem_version >= (0, 17, 0): - # FIXME change to imports from a dpctl enum/class rather than - # hard coding these numbers. + from dpctl._sycl_queue import kernel_arg_type as kargty if ty == types.boolean: - return context.get_constant(types.int32, 1) + return context.get_constant(types.int32, kargty.dpctl_uint8.value) elif ty == types.int32 or isinstance(ty, types.scalars.IntegerLiteral): - return context.get_constant(types.int32, 4) + return context.get_constant(types.int32, kargty.dpctl_int32.value) elif ty == types.uint32: - return context.get_constant(types.int32, 5) + return context.get_constant(types.int32, kargty.dpctl_uint32.value) elif ty == types.int64: - return context.get_constant(types.int32, 6) + return context.get_constant(types.int32, kargty.dpctl_int64.value) elif ty == types.uint64: - return context.get_constant(types.int32, 7) + return context.get_constant(types.int32, kargty.dpctl_uint64.value) elif ty == types.float32: - return context.get_constant(types.int32, 8) + return context.get_constant(types.int32, kargty.dpctl_float32.value) elif ty == types.float64: - return context.get_constant(types.int32, 9) + return context.get_constant(types.int32, kargty.dpctl_float64.value) elif ty == types.voidptr or isinstance(ty, types.CPointer): - return context.get_constant(types.int32, 10) + return context.get_constant( + types.int32, kargty.dpctl_void_ptr.value + ) else: raise NotImplementedError else: