diff --git a/numba_dpex/core/types/kernel_api/atomic_ref.py b/numba_dpex/core/types/kernel_api/atomic_ref.py index cdc2eb1890..6524833d7b 100644 --- a/numba_dpex/core/types/kernel_api/atomic_ref.py +++ b/numba_dpex/core/types/kernel_api/atomic_ref.py @@ -80,3 +80,13 @@ def cast_python_value(self, args): AtomicRefType throws a NotImplementedError. """ raise NotImplementedError + + @property + def mangling_args(self): + args = [ + self.dtype, + self.memory_order, + self.memory_scope, + self.address_space, + ] + return self.__class__.__name__, args diff --git a/numba_dpex/core/types/kernel_api/index_space_ids.py b/numba_dpex/core/types/kernel_api/index_space_ids.py index 798aceeb2e..ea57fd05a4 100644 --- a/numba_dpex/core/types/kernel_api/index_space_ids.py +++ b/numba_dpex/core/types/kernel_api/index_space_ids.py @@ -31,6 +31,11 @@ def key(self): def cast_python_value(self, args): raise NotImplementedError + @property + def mangling_args(self): + args = [self.ndim] + return self.__class__.__name__, args + class ItemType(types.Type): """Numba-dpex type corresponding to :class:`numba_dpex.kernel_api.Item`""" @@ -53,6 +58,11 @@ def key(self): """Numba type specific overload""" return self._ndim + @property + def mangling_args(self): + args = [self.ndim] + return self.__class__.__name__, args + def cast_python_value(self, args): raise NotImplementedError @@ -78,5 +88,10 @@ def key(self): """Numba type specific overload""" return self._ndim + @property + def mangling_args(self): + args = [self.ndim] + return self.__class__.__name__, args + def cast_python_value(self, args): raise NotImplementedError diff --git a/numba_dpex/core/types/kernel_api/ranges.py b/numba_dpex/core/types/kernel_api/ranges.py index af2f0125a7..9dc13e48a3 100644 --- a/numba_dpex/core/types/kernel_api/ranges.py +++ b/numba_dpex/core/types/kernel_api/ranges.py @@ -28,6 +28,11 @@ def ndim(self): def key(self): return self._ndim + @property + def mangling_args(self): + args = [self.ndim] + return self.__class__.__name__, args + class NdRangeType(types.Type): """Numba-dpex type corresponding to @@ -49,3 +54,8 @@ def ndim(self): @property def key(self): return self._ndim + + @property + def mangling_args(self): + args = [self.ndim] + return self.__class__.__name__, args diff --git a/numba_dpex/kernel_api_impl/spirv/target.py b/numba_dpex/kernel_api_impl/spirv/target.py index e1e4bcbc65..8de0185605 100644 --- a/numba_dpex/kernel_api_impl/spirv/target.py +++ b/numba_dpex/kernel_api_impl/spirv/target.py @@ -11,7 +11,7 @@ import dpnp from llvmlite import binding as ll from llvmlite import ir as llvmir -from numba.core import cgutils, funcdesc +from numba.core import cgutils from numba.core import types as nb_types from numba.core import typing from numba.core.base import BaseContext @@ -23,6 +23,7 @@ from numba_dpex.core.datamodel.models import _init_kernel_data_model_manager from numba_dpex.core.types import IntEnumLiteral from numba_dpex.core.typing import dpnpdecl +from numba_dpex.core.utils import itanium_mangler from numba_dpex.kernel_api.flag_enum import FlagEnum from numba_dpex.kernel_api.memory_enums import AddressSpace as address_space from numba_dpex.kernel_api_impl.spirv import printimpl @@ -188,7 +189,7 @@ def _generate_spir_kernel_wrapper(self, func, argtypes): wrapper_module = self._internal_codegen.create_empty_spirv_module( "dpex.kernel.wrapper" ) - wrappername = func.name.replace("dpex_fn", "dpex_kernel") + wrappername = func.name + ("dpex_kernel") argtys = list(arginfo.argument_types) fnty = llvmir.FunctionType( llvmir.IntType(32), @@ -319,12 +320,9 @@ def target_data(self): def mangler(self, name, types, *, abi_tags=(), uid=None): """ - Generates a name for a function by appending \"dpex_fn\" to the - name of the function before calling Numba's default function name - mangler.""" - return funcdesc.default_mangler( - name + "dpex_fn", types, abi_tags=abi_tags, uid=uid - ) + Generates a mangled function name using numba_dpex's itanium mangler. + """ + return itanium_mangler.mangle(name, types, abi_tags=abi_tags, uid=uid) def prepare_spir_kernel(self, func, argtypes): """Generates a wrapper function with \"spir_kernel\" calling conv that