Skip to content

Commit

Permalink
Add the mangled_Args property to kernel_api types.
Browse files Browse the repository at this point in the history
  • Loading branch information
Diptorup Deb committed Apr 24, 2024
1 parent 9ef3a7a commit 1379e63
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 8 deletions.
10 changes: 10 additions & 0 deletions numba_dpex/core/types/kernel_api/atomic_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,13 @@ def cast_python_value(self, args):
AtomicRefType throws a NotImplementedError.
"""
raise NotImplementedError

@property
def mangling_args(self):
args = [
self.dtype,
self.memory_order,
self.memory_scope,
self.address_space,
]
return self.__class__.__name__, args
15 changes: 15 additions & 0 deletions numba_dpex/core/types/kernel_api/index_space_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def key(self):
def cast_python_value(self, args):
raise NotImplementedError

@property
def mangling_args(self):
args = [self.ndim]
return self.__class__.__name__, args


class ItemType(types.Type):
"""Numba-dpex type corresponding to :class:`numba_dpex.kernel_api.Item`"""
Expand All @@ -53,6 +58,11 @@ def key(self):
"""Numba type specific overload"""
return self._ndim

@property
def mangling_args(self):
args = [self.ndim]
return self.__class__.__name__, args

def cast_python_value(self, args):
raise NotImplementedError

Expand All @@ -78,5 +88,10 @@ def key(self):
"""Numba type specific overload"""
return self._ndim

@property
def mangling_args(self):
args = [self.ndim]
return self.__class__.__name__, args

def cast_python_value(self, args):
raise NotImplementedError
10 changes: 10 additions & 0 deletions numba_dpex/core/types/kernel_api/ranges.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ def ndim(self):
def key(self):
return self._ndim

@property
def mangling_args(self):
args = [self.ndim]
return self.__class__.__name__, args


class NdRangeType(types.Type):
"""Numba-dpex type corresponding to
Expand All @@ -49,3 +54,8 @@ def ndim(self):
@property
def key(self):
return self._ndim

@property
def mangling_args(self):
args = [self.ndim]
return self.__class__.__name__, args
14 changes: 6 additions & 8 deletions numba_dpex/kernel_api_impl/spirv/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import dpnp
from llvmlite import binding as ll
from llvmlite import ir as llvmir
from numba.core import cgutils, funcdesc
from numba.core import cgutils
from numba.core import types as nb_types
from numba.core import typing
from numba.core.base import BaseContext
Expand All @@ -23,6 +23,7 @@
from numba_dpex.core.datamodel.models import _init_kernel_data_model_manager
from numba_dpex.core.types import IntEnumLiteral
from numba_dpex.core.typing import dpnpdecl
from numba_dpex.core.utils import itanium_mangler
from numba_dpex.kernel_api.flag_enum import FlagEnum
from numba_dpex.kernel_api.memory_enums import AddressSpace as address_space
from numba_dpex.kernel_api_impl.spirv import printimpl
Expand Down Expand Up @@ -188,7 +189,7 @@ def _generate_spir_kernel_wrapper(self, func, argtypes):
wrapper_module = self._internal_codegen.create_empty_spirv_module(
"dpex.kernel.wrapper"
)
wrappername = func.name.replace("dpex_fn", "dpex_kernel")
wrappername = func.name + ("dpex_kernel")
argtys = list(arginfo.argument_types)
fnty = llvmir.FunctionType(
llvmir.IntType(32),
Expand Down Expand Up @@ -319,12 +320,9 @@ def target_data(self):

def mangler(self, name, types, *, abi_tags=(), uid=None):
"""
Generates a name for a function by appending \"dpex_fn\" to the
name of the function before calling Numba's default function name
mangler."""
return funcdesc.default_mangler(
name + "dpex_fn", types, abi_tags=abi_tags, uid=uid
)
Generates a mangled function name using numba_dpex's itanium mangler.
"""
return itanium_mangler.mangle(name, types, abi_tags=abi_tags, uid=uid)

def prepare_spir_kernel(self, func, argtypes):
"""Generates a wrapper function with \"spir_kernel\" calling conv that
Expand Down

0 comments on commit 1379e63

Please sign in to comment.