Skip to content

Commit

Permalink
Port the parfor range kernel template to new API.
Browse files Browse the repository at this point in the history
  • Loading branch information
Diptorup Deb authored and ZzEeKkAa committed Apr 1, 2024
1 parent 5ff654f commit 09a1cca
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 14 deletions.
4 changes: 4 additions & 0 deletions numba_dpex/core/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ def typing_context(self):
"""
return self._toplevel_typing_context

@property
def target_name(self):
return self._target_name


class DpexTarget(TargetDescriptor):
"""
Expand Down
33 changes: 21 additions & 12 deletions numba_dpex/core/parfors/kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@
rename_labels,
replace_var_names,
)
from numba.core.target_extension import target_override
from numba.core.typing import signature
from numba.parfors import parfor

from numba_dpex.core import config
from numba_dpex.core.types.kernel_api.index_space_ids import ItemType
from numba_dpex.kernel_api_impl.spirv import spirv_generator

from ..descriptor import dpex_kernel_target
Expand Down Expand Up @@ -66,18 +68,18 @@ def _print_body(body_dict):
def _compile_kernel_parfor(
sycl_queue, kernel_name, func_ir, argtypes, debug=False
):

cres = compile_numba_ir_with_dpex(
pyfunc=func_ir,
pyfunc_name=kernel_name,
args=argtypes,
return_type=None,
debug=debug,
is_kernel=True,
typing_context=dpex_kernel_target.typing_context,
target_context=dpex_kernel_target.target_context,
extra_compile_flags=None,
)
with target_override(dpex_kernel_target.target_context.target_name):
cres = compile_numba_ir_with_dpex(
pyfunc=func_ir,
pyfunc_name=kernel_name,
args=argtypes,
return_type=None,
debug=debug,
is_kernel=True,
typing_context=dpex_kernel_target.typing_context,
target_context=dpex_kernel_target.target_context,
extra_compile_flags=None,
)
cres.library.inline_threshold = config.INLINE_THRESHOLD
cres.library._optimize_final_module()
func = cres.library.get_function(cres.fndesc.llvm_func_name)
Expand Down Expand Up @@ -420,6 +422,13 @@ def create_kernel_for_parfor(
print("kernel_ir after remove dead")
kernel_ir.dump()

# The first argument to a range kernel is a kernel_api.Item object. The
# ``Item`` object is used by the kernel_api.spirv backend to generate the
# correct SPIR-V indexing instructions. Since, the argument is not something
# available originally in the kernel_param_types, we add it at this point to
# make sure the kernel signature matches the actual generated code.
ty_item = ItemType(parfor_dim)
kernel_param_types = (ty_item, *kernel_param_types)
kernel_sig = signature(types.none, *kernel_param_types)

if config.DEBUG_ARRAY_OPT:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,15 @@ def _generate_kernel_stub_as_string(self):

# Create the dpex kernel function.
kernel_txt += "def " + self._kernel_name
kernel_txt += "(" + (", ".join(self._kernel_params)) + "):\n"
kernel_txt += "(item, " + (", ".join(self._kernel_params)) + "):\n"
global_id_dim = 0
for_loop_dim = self._kernel_rank
global_id_dim = self._kernel_rank

for dim in range(global_id_dim):
dimstr = str(dim)
kernel_txt += (
f" {self._ivar_names[dim]} = dpex.get_global_id({dimstr})\n"
f" {self._ivar_names[dim]} = item.get_id({dimstr})\n"
)

for dim in range(global_id_dim, for_loop_dim):
Expand Down

0 comments on commit 09a1cca

Please sign in to comment.