From e9240bd8b93f21589eef012f3370ddec17152a7e Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Wed, 6 Mar 2024 21:51:46 -0600 Subject: [PATCH] Use enum values for kernel arg types if dpctl >= 0.17 --- numba_dpex/dpctl_iface/_helpers.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/numba_dpex/dpctl_iface/_helpers.py b/numba_dpex/dpctl_iface/_helpers.py index be8354b459..f46915eaf0 100644 --- a/numba_dpex/dpctl_iface/_helpers.py +++ b/numba_dpex/dpctl_iface/_helpers.py @@ -14,25 +14,26 @@ def numba_type_to_dpctl_typenum(context, ty): """ if dpctl_sem_version >= (0, 17, 0): - # FIXME change to imports from a dpctl enum/class rather than - # hard coding these numbers. + from dpctl._sycl_queue import kernel_arg_type as kargty if ty == types.boolean: - return context.get_constant(types.int32, 1) + return context.get_constant(types.int32, kargty.dpctl_uint8.value) elif ty == types.int32 or isinstance(ty, types.scalars.IntegerLiteral): - return context.get_constant(types.int32, 4) + return context.get_constant(types.int32, kargty.dpctl_int32.value) elif ty == types.uint32: - return context.get_constant(types.int32, 5) + return context.get_constant(types.int32, kargty.dpctl_uint32.value) elif ty == types.int64: - return context.get_constant(types.int32, 6) + return context.get_constant(types.int32, kargty.dpctl_int64.value) elif ty == types.uint64: - return context.get_constant(types.int32, 7) + return context.get_constant(types.int32, kargty.dpctl_uint64.value) elif ty == types.float32: - return context.get_constant(types.int32, 8) + return context.get_constant(types.int32, kargty.dpctl_float32.value) elif ty == types.float64: - return context.get_constant(types.int32, 9) + return context.get_constant(types.int32, kargty.dpctl_float64.value) elif ty == types.voidptr or isinstance(ty, types.CPointer): - return context.get_constant(types.int32, 10) + return context.get_constant( + types.int32, kargty.dpctl_void_ptr.value + ) else: raise NotImplementedError else: