diff --git a/numba_dpex/core/descriptor.py b/numba_dpex/core/descriptor.py index 369838ed25..f8637103c1 100644 --- a/numba_dpex/core/descriptor.py +++ b/numba_dpex/core/descriptor.py @@ -48,6 +48,8 @@ class DpexTargetOptions(CPUTargetOptions): no_compile = _option_mapping("no_compile") inline_threshold = _option_mapping("inline_threshold") _compilation_mode = _option_mapping("_compilation_mode") + # TODO: create separate parfor kernel target + _parfor_body_args = _option_mapping("_parfor_body_args") def finalize(self, flags, options): super().finalize(flags, options) @@ -63,6 +65,7 @@ def finalize(self, flags, options): _inherit_if_not_set( flags, options, "_compilation_mode", CompilationMode.KERNEL ) + _inherit_if_not_set(flags, options, "_parfor_body_args", None) class DpexKernelTarget(TargetDescriptor): diff --git a/numba_dpex/core/parfors/kernel_builder.py b/numba_dpex/core/parfors/kernel_builder.py index 1eda1664ae..6f059c8452 100644 --- a/numba_dpex/core/parfors/kernel_builder.py +++ b/numba_dpex/core/parfors/kernel_builder.py @@ -26,8 +26,17 @@ from numba.parfors import parfor from numba_dpex.core import config +from numba_dpex.core.decorators import kernel +from numba_dpex.core.parfors.parfor_sentinel_replace_pass import ( + ParforBodyArguments, +) from numba_dpex.core.types.kernel_api.index_space_ids import ItemType +from numba_dpex.core.utils.call_kernel_builder import SPIRVKernelModule from numba_dpex.kernel_api_impl.spirv import spirv_generator +from numba_dpex.kernel_api_impl.spirv.dispatcher import ( + SPIRVKernelDispatcher, + _SPIRVKernelCompileResult, +) from ..descriptor import dpex_kernel_target from ..types import DpnpNdArray @@ -38,79 +47,19 @@ class ParforKernel: def __init__( self, - name, - kernel, signature, kernel_args, kernel_arg_types, - queue: dpctl.SyclQueue, local_accessors=None, work_group_size=None, + kernel_module=None, ): - self.name = name - self.kernel = kernel self.signature = signature self.kernel_args = kernel_args 
self.kernel_arg_types = kernel_arg_types - self.queue = queue self.local_accessors = local_accessors self.work_group_size = work_group_size - - -def _print_block(block): - for i, inst in enumerate(block.body): - print(" ", i, inst) - - -def _print_body(body_dict): - """Pretty-print a set of IR blocks.""" - for label, block in body_dict.items(): - print("label: ", label) - _print_block(block) - - -def _compile_kernel_parfor( - sycl_queue, kernel_name, func_ir, argtypes, debug=False -): - with target_override(dpex_kernel_target.target_context.target_name): - cres = compile_numba_ir_with_dpex( - pyfunc=func_ir, - pyfunc_name=kernel_name, - args=argtypes, - return_type=None, - debug=debug, - is_kernel=True, - typing_context=dpex_kernel_target.typing_context, - target_context=dpex_kernel_target.target_context, - extra_compile_flags=None, - ) - cres.library.inline_threshold = config.INLINE_THRESHOLD - cres.library._optimize_final_module() - func = cres.library.get_function(cres.fndesc.llvm_func_name) - kernel = dpex_kernel_target.target_context.prepare_spir_kernel( - func, cres.signature.args - ) - spirv_module = spirv_generator.llvm_to_spirv( - dpex_kernel_target.target_context, - kernel.module.__str__(), - kernel.module.as_bitcode(), - ) - - dpctl_create_program_from_spirv_flags = [] - if debug or config.DPEX_OPT == 0: - # if debug is ON we need to pass additional flags to igc. 
- dpctl_create_program_from_spirv_flags = ["-g", "-cl-opt-disable"] - - # create a sycl::kernel_bundle - kernel_bundle = dpctl_prog.create_program_from_spirv( - sycl_queue, - spirv_module, - " ".join(dpctl_create_program_from_spirv_flags), - ) - # create a sycl::kernel - sycl_kernel = kernel_bundle.get_sycl_kernel(kernel.name) - - return sycl_kernel + self.kernel_module = kernel_module def _legalize_names_with_typemap(names, typemap): @@ -189,76 +138,11 @@ def _replace_var_with_array(vars, loop_body, typemap, calltypes): typemap[v] = types.npytypes.Array(el_typ, 1, "C") -def _find_setitems_block(setitems, block, typemap): - for inst in block.body: - if isinstance(inst, ir.StaticSetItem) or isinstance(inst, ir.SetItem): - setitems.add(inst.target.name) - elif isinstance(inst, parfor.Parfor): - _find_setitems_block(setitems, inst.init_block, typemap) - _find_setitems_body(setitems, inst.loop_body, typemap) - - -def _find_setitems_body(setitems, loop_body, typemap): - """ - Find the arrays that are written into (goes into setitems) - """ - for label, block in loop_body.items(): - _find_setitems_block(setitems, block, typemap) - - -def _replace_sentinel_with_parfor_body(kernel_ir, sentinel_name, loop_body): - # new label for splitting sentinel block - new_label = max(loop_body.keys()) + 1 - - # Search all the block in the kernel function for the sentinel assignment. - for label, block in kernel_ir.blocks.items(): - for i, inst in enumerate(block.body): - if ( - isinstance(inst, ir.Assign) - and inst.target.name == sentinel_name - ): - # We found the sentinel assignment. - loc = inst.loc - scope = block.scope - # split block across __sentinel__ - # A new block is allocated for the statements prior to the - # sentinel but the new block maintains the current block label. - prev_block = ir.Block(scope, loc) - prev_block.body = block.body[:i] - - # The current block is used for statements after the sentinel. 
- block.body = block.body[i + 1 :] # noqa: E203 - # But the current block gets a new label. - body_first_label = min(loop_body.keys()) - - # The previous block jumps to the minimum labelled block of the - # parfor body. - prev_block.append(ir.Jump(body_first_label, loc)) - # Add all the parfor loop body blocks to the kernel function's - # IR. - for loop, b in loop_body.items(): - kernel_ir.blocks[loop] = b - body_last_label = max(loop_body.keys()) - kernel_ir.blocks[new_label] = block - kernel_ir.blocks[label] = prev_block - # Add a jump from the last parfor body block to the block - # containing statements after the sentinel. - kernel_ir.blocks[body_last_label].append( - ir.Jump(new_label, loc) - ) - break - else: - continue - break - - def create_kernel_for_parfor( lowerer, parfor_node, typemap, - flags, loop_ranges, - has_aliases, races, parfor_outputs, ) -> ParforKernel: @@ -367,120 +251,38 @@ def create_kernel_for_parfor( loop_ranges=loop_ranges, param_dict=param_dict, ) - kernel_ir = kernel_template.kernel_ir - if config.DEBUG_ARRAY_OPT: - print("kernel_ir dump ", type(kernel_ir)) - kernel_ir.dump() - print("loop_body dump ", type(loop_body)) - _print_body(loop_body) - - # rename all variables in kernel_ir afresh - var_table = get_name_var_table(kernel_ir.blocks) - new_var_dict = {} - reserved_names = ( - [sentinel_name] + list(param_dict.values()) + legal_loop_indices + kernel_dispatcher: SPIRVKernelDispatcher = kernel( + kernel_template.py_func, + _parfor_body_args=ParforBodyArguments( + loop_body=loop_body, + param_dict=param_dict, + legal_loop_indices=legal_loop_indices, + ), ) - for name, var in var_table.items(): - if not (name in reserved_names): - new_var_dict[name] = mk_unique_var(name) - replace_var_names(kernel_ir.blocks, new_var_dict) - if config.DEBUG_ARRAY_OPT: - print("kernel_ir dump after renaming ") - kernel_ir.dump() - - kernel_param_types = param_types - if config.DEBUG_ARRAY_OPT: - print( - "kernel_param_types = ", - 
type(kernel_param_types), - "\n", - kernel_param_types, - ) - - kernel_stub_last_label = max(kernel_ir.blocks.keys()) + 1 - - # Add kernel stub last label to each parfor.loop_body label to prevent - # label conflicts. - loop_body = add_offset_to_labels(loop_body, kernel_stub_last_label) - - _replace_sentinel_with_parfor_body(kernel_ir, sentinel_name, loop_body) - - if config.DEBUG_ARRAY_OPT: - print("kernel_ir last dump before renaming") - kernel_ir.dump() - - kernel_ir.blocks = rename_labels(kernel_ir.blocks) - remove_dels(kernel_ir.blocks) - - old_alias = flags.noalias - if not has_aliases: - if config.DEBUG_ARRAY_OPT: - print("No aliases found so adding noalias flag.") - flags.noalias = True - - remove_dead(kernel_ir.blocks, kernel_ir.arg_names, kernel_ir, typemap) - - if config.DEBUG_ARRAY_OPT: - print("kernel_ir after remove dead") - kernel_ir.dump() - - # The first argument to a range kernel is a kernel_api.Item object. The - # ``Item`` object is used by the kernel_api.spirv backend to generate the + # The first argument to a range kernel is a kernel_api.NdItem object. The + # ``NdItem`` object is used by the kernel_api.spirv backend to generate the # correct SPIR-V indexing instructions. Since, the argument is not something # available originally in the kernel_param_types, we add it at this point to # make sure the kernel signature matches the actual generated code. ty_item = ItemType(parfor_dim) - kernel_param_types = (ty_item, *kernel_param_types) + kernel_param_types = (ty_item, *param_types) kernel_sig = signature(types.none, *kernel_param_types) - if config.DEBUG_ARRAY_OPT: - sys.stdout.flush() - - if config.DEBUG_ARRAY_OPT: - print("after DUFunc inline".center(80, "-")) - kernel_ir.dump() - - # The ParforLegalizeCFD pass has already ensured that the LHS and RHS - # arrays are on same device. We can take the queue from the first input - # array and use that to compile the kernel. 
- - exec_queue: dpctl.SyclQueue = None - - for arg in parfor_args: - obj = typemap[arg] - if isinstance(obj, DpnpNdArray): - filter_string = obj.queue.sycl_device - # FIXME: A better design is required so that we do not have to - # create a queue every time. - exec_queue = dpctl.get_device_cached_queue(filter_string) - - if not exec_queue: - raise AssertionError( - "No execution found for parfor. No way to compile the kernel!" - ) - - sycl_kernel = _compile_kernel_parfor( - exec_queue, - kernel_name, - kernel_ir, - kernel_param_types, - debug=flags.debuginfo, + kcres: _SPIRVKernelCompileResult = kernel_dispatcher.get_compile_result( + types.void(*kernel_param_types) # kernel signature ) - - flags.noalias = old_alias + kernel_module: SPIRVKernelModule = kcres.kernel_device_ir_module if config.DEBUG_ARRAY_OPT: print("kernel_sig = ", kernel_sig) return ParforKernel( - name=kernel_name, - kernel=sycl_kernel, signature=kernel_sig, kernel_args=parfor_args, kernel_arg_types=func_arg_types, - queue=exec_queue, + kernel_module=kernel_module, ) diff --git a/numba_dpex/core/parfors/kernel_templates/kernel_template_iface.py b/numba_dpex/core/parfors/kernel_templates/kernel_template_iface.py index 453017dbbb..9bb7457218 100644 --- a/numba_dpex/core/parfors/kernel_templates/kernel_template_iface.py +++ b/numba_dpex/core/parfors/kernel_templates/kernel_template_iface.py @@ -30,13 +30,9 @@ def _generate_kernel_ir(self): def dump_kernel_string(self): raise NotImplementedError - @abc.abstractmethod - def dump_kernel_ir(self): - raise NotImplementedError - @property @abc.abstractmethod - def kernel_ir(self): + def py_func(self): raise NotImplementedError @property diff --git a/numba_dpex/core/parfors/kernel_templates/range_kernel_template.py b/numba_dpex/core/parfors/kernel_templates/range_kernel_template.py index ecbfa5df4c..f5dd4304e1 100644 --- a/numba_dpex/core/parfors/kernel_templates/range_kernel_template.py +++ 
b/numba_dpex/core/parfors/kernel_templates/range_kernel_template.py @@ -5,7 +5,6 @@ import sys import dpnp -from numba.core import compiler import numba_dpex as dpex @@ -51,7 +50,7 @@ def __init__( self._param_dict = param_dict self._kernel_txt = self._generate_kernel_stub_as_string() - self._kernel_ir = self._generate_kernel_ir() + self._py_func = self._generate_kernel_ir() def _generate_kernel_stub_as_string(self): """Generates a stub dpex kernel for the parfor as a string. @@ -109,17 +108,15 @@ def _generate_kernel_ir(self): globls = {"dpnp": dpnp, "dpex": dpex} locls = {} exec(self._kernel_txt, globls, locls) - kernel_fn = locls[self._kernel_name] - - return compiler.run_frontend(kernel_fn) + return locls[self._kernel_name] @property - def kernel_ir(self): - """Returns the Numba IR generated for a RangeKernelTemplate. - - Returns: The Numba functionIR object for the compiled kernel_txt string. + def py_func(self): + """Returns the python function generated for a + TreeReduceIntermediateKernelTemplate. + Returns: The python function object for the compiled kernel_txt string. 
""" - return self._kernel_ir + return self._py_func @property def kernel_string(self): @@ -134,7 +131,3 @@ def dump_kernel_string(self): """Helper to print the kernel function string.""" print(self._kernel_txt) sys.stdout.flush() - - def dump_kernel_ir(self): - """Helper to dump the Numba IR for the RangeKernelTemplate.""" - self._kernel_ir.dump() diff --git a/numba_dpex/core/parfors/kernel_templates/reduction_template.py b/numba_dpex/core/parfors/kernel_templates/reduction_template.py index d78f8cd449..199a3f2b41 100644 --- a/numba_dpex/core/parfors/kernel_templates/reduction_template.py +++ b/numba_dpex/core/parfors/kernel_templates/reduction_template.py @@ -6,7 +6,6 @@ import sys import dpnp -from numba.core import compiler import numba_dpex.kernel_api as kapi @@ -48,7 +47,7 @@ def __init__( self._typemap = typemap self._kernel_txt = self._generate_kernel_stub_as_string() - self._kernel_ir = self._generate_kernel_ir() + self._py_func = self._generate_kernel_ir() def _generate_kernel_stub_as_string(self): """Generate reduction main kernel template""" @@ -161,18 +160,15 @@ def _generate_kernel_ir(self): globls = {"dpnp": dpnp, "kapi": kapi} locls = {} exec(self._kernel_txt, globls, locls) - kernel_fn = locls[self._kernel_name] - - return compiler.run_frontend(kernel_fn) + return locls[self._kernel_name] @property - def kernel_ir(self): - """Returns the Numba IR generated for a + def py_func(self): + """Returns the python function generated for a TreeReduceIntermediateKernelTemplate. - - Returns: The Numba functionIR object for the compiled kernel_txt string. + Returns: The python function object for the compiled kernel_txt string. 
""" - return self._kernel_ir + return self._py_func @property def kernel_string(self): @@ -190,11 +186,6 @@ def dump_kernel_string(self): print(self._kernel_txt) sys.stdout.flush() - def dump_kernel_ir(self): - """Helper to dump the Numba IR for a - TreeReduceIntermediateKernelTemplate.""" - self._kernel_ir.dump() - class RemainderReduceIntermediateKernelTemplate(KernelTemplateInterface): """The class to build reduction remainder kernel_txt template and @@ -234,7 +225,7 @@ def __init__( self._reductionKernelVar = reductionKernelVar self._kernel_txt = self._generate_kernel_stub_as_string() - self._kernel_ir = self._generate_kernel_ir() + self._py_func = self._generate_kernel_ir() def _generate_kernel_stub_as_string(self): """Generate reduction remainder kernel template""" @@ -320,18 +311,15 @@ def _generate_kernel_ir(self): globls = {"dpnp": dpnp, "kapi": kapi} locls = {} exec(self._kernel_txt, globls, locls) - kernel_fn = locls[self._kernel_name] - - return compiler.run_frontend(kernel_fn) + return locls[self._kernel_name] @property - def kernel_ir(self): - """Returns the Numba IR generated for a - RemainderReduceIntermediateKernelTemplate. - - Returns: The Numba functionIR object for the compiled kernel_txt string. + def py_func(self): + """Returns the python function generated for a + TreeReduceIntermediateKernelTemplate. + Returns: The python function object for the compiled kernel_txt string. 
""" - return self._kernel_ir + return self._py_func @property def kernel_string(self): @@ -349,9 +337,3 @@ def dump_kernel_string(self): print(self._kernel_txt) sys.stdout.flush() - - def dump_kernel_ir(self): - """Helper to dump the Numba IR for the - RemainderReduceIntermediateKernelTemplate.""" - - self._kernel_ir.dump() diff --git a/numba_dpex/core/parfors/parfor_lowerer.py b/numba_dpex/core/parfors/parfor_lowerer.py index 41bf86868f..3eb39b2cf5 100644 --- a/numba_dpex/core/parfors/parfor_lowerer.py +++ b/numba_dpex/core/parfors/parfor_lowerer.py @@ -31,9 +31,6 @@ create_reduction_remainder_kernel_for_parfor, ) -# A global list of kernels to keep the objects alive indefinitely. -keep_alive_kernels = [] - def _getvar(lowerer, x): """Returns the LLVM Value corresponding to a Numba IR variable. @@ -154,20 +151,16 @@ def _submit_parfor_kernel( kernel_fn: ParforKernel, global_range, local_range, + debug=False, ): """ Adds a call to submit a kernel function into the function body of the current Numba JIT compiled function. 
""" - # Ensure that the Python arguments are kept alive for the duration of - # the kernel execution - keep_alive_kernels.append(kernel_fn.kernel) kl_builder = KernelLaunchIRBuilder( lowerer.context, lowerer.builder, kernel_dmm ) - queue_ref = kl_builder.get_queue(exec_queue=kernel_fn.queue) - kernel_args = [] for i, arg in enumerate(kernel_fn.kernel_args): if ( @@ -188,24 +181,23 @@ def _submit_parfor_kernel( else: kernel_args.append(_getvar(lowerer, arg)) - kernel_ref_addr = kernel_fn.kernel.addressof_ref() - kernel_ref = lowerer.builder.inttoptr( - lowerer.context.get_constant(types.uintp, kernel_ref_addr), - cgutils.voidptr_t, - ) - - kl_builder.set_kernel(kernel_ref) - kl_builder.set_queue(queue_ref) kl_builder.set_range(global_range, local_range) kl_builder.set_arguments( kernel_fn.kernel_arg_types, kernel_args=kernel_args ) + kl_builder.set_queue_from_arguments() kl_builder.set_dependent_events([]) + kl_builder.set_kernel_from_spirv( + kernel_fn.kernel_module, + debug=debug, + ) + event_ref = kl_builder.submit() sycl.dpctl_event_wait(lowerer.builder, event_ref) sycl.dpctl_event_delete(lowerer.builder, event_ref) - sycl.dpctl_queue_delete(lowerer.builder, queue_ref) + + return kl_builder.arguments.sycl_queue_ref def _reduction_codegen( self, @@ -263,8 +255,6 @@ def _reduction_codegen( loop_ranges, parfor, typemap, - flags, - bool(alias_map), reductionKernelVar, parfor_reddict, ) @@ -278,13 +268,12 @@ def _reduction_codegen( parfor_kernel, global_range, local_range, + debug=flags.debuginfo, ) parfor_kernel = create_reduction_remainder_kernel_for_parfor( parfor, typemap, - flags, - bool(alias_map), reductionKernelVar, parfor_reddict, reductionHelperList, @@ -292,14 +281,16 @@ def _reduction_codegen( global_range, local_range = self._remainder_ranges(lowerer) - self._submit_parfor_kernel( + # TODO: find better way to pass queue + queue_ref = self._submit_parfor_kernel( lowerer, parfor_kernel, global_range, local_range, + debug=flags.debuginfo, ) - 
reductionKernelVar.copy_final_sum_to_host(parfor_kernel) + reductionKernelVar.copy_final_sum_to_host(queue_ref) def _lower_parfor_as_kernel(self, lowerer, parfor): """Lowers a parfor node created by the dpjit compiler to a @@ -400,9 +391,7 @@ def _lower_parfor_as_kernel(self, lowerer, parfor): lowerer, parfor, typemap, - flags, loop_ranges, - bool(alias_map), parfor.races, parfor_output_arrays, ) @@ -418,10 +407,9 @@ def _lower_parfor_as_kernel(self, lowerer, parfor): parfor_kernel, global_range, local_range, + debug=flags.debuginfo, ) - # TODO: free the kernel at this point - # Restore the original typemap of the function that was replaced # temporarily at the beginning of this function. lowerer.fndesc.typemap = orig_typemap diff --git a/numba_dpex/core/parfors/parfor_pass.py b/numba_dpex/core/parfors/parfor_pass.py index c2149bf0cf..3901af8e7c 100644 --- a/numba_dpex/core/parfors/parfor_pass.py +++ b/numba_dpex/core/parfors/parfor_pass.py @@ -17,8 +17,8 @@ from numba.core import config, errors, ir, types, typing from numba.core.compiler_machinery import register_pass from numba.core.ir_utils import ( + convert_size_to_var, dprint_func_ir, - mk_alloc, mk_unique_var, next_label, ) @@ -43,6 +43,7 @@ ) from numba.stencils.stencilparfor import StencilPass +from numba_dpex.core.types.dpnp_ndarray_type import DpnpNdArray from numba_dpex.core.typing import dpnpdecl @@ -58,6 +59,37 @@ class ConvertDPNPPass(ConvertNumpyPass): def __init__(self, pass_states): super().__init__(pass_states) + def _get_queue(self, queue_type, expr: tuple): + """ + Extracts queue from the input arguments of the array operation. 
+ """ + pass_states = self.pass_states + typemap: map[str, any] = pass_states.typemap + + var_with_queue = None + + for var in expr[1]: + if isinstance(var, tuple): + res = self._get_queue(queue_type, var) + if res is not None: + return res + + continue + + if not isinstance(var, ir.Var): + continue + + _type = typemap[var.name] + if not isinstance(_type, DpnpNdArray): + continue + if queue_type != _type.queue: + continue + + var_with_queue = var + break + + return ir.Expr.getattr(var_with_queue, "sycl_queue", var_with_queue.loc) + def _arrayexpr_to_parfor(self, equiv_set, lhs, arrayexpr, avail_vars): """generate parfor from arrayexpr node, which is essentially a map with recursive tree. @@ -77,6 +109,10 @@ def _arrayexpr_to_parfor(self, equiv_set, lhs, arrayexpr, avail_vars): pass_states.typemap, size_vars, scope, loc ) + # Expr is a tuple + ir_queue = self._get_queue(arr_typ.queue, expr) + assert ir_queue is not None + # generate init block and body init_block = ir.Block(scope, loc) init_block.body = mk_alloc( @@ -89,6 +125,7 @@ def _arrayexpr_to_parfor(self, equiv_set, lhs, arrayexpr, avail_vars): scope, loc, pass_states.typemap[lhs.name], + queue_ir_val=ir_queue, ) body_label = next_label() body_block = ir.Block(scope, loc) @@ -469,3 +506,66 @@ def _arrayexpr_tree_to_ir( typemap.pop(expr_out_var.name, None) typemap[expr_out_var.name] = el_typ return out_ir + + +def mk_alloc( + typingctx, + typemap, + calltypes, + lhs, + size_var, + dtype, + scope, + loc, + lhs_typ, + **kws, +): + """generate an array allocation with np.empty() and return list of nodes. + size_var can be an int variable or tuple of int variables. + lhs_typ is the type of the array being allocated. 
+ + Taken from numba, added kws argument to pass it to __allocate__ + """ + out = [] + ndims = 1 + size_typ = types.intp + if isinstance(size_var, tuple): + if len(size_var) == 1: + size_var = size_var[0] + size_var = convert_size_to_var(size_var, typemap, scope, loc, out) + else: + # tuple_var = build_tuple([size_var...]) + ndims = len(size_var) + tuple_var = ir.Var(scope, mk_unique_var("$tuple_var"), loc) + if typemap: + typemap[tuple_var.name] = types.containers.UniTuple( + types.intp, ndims + ) + # constant sizes need to be assigned to vars + new_sizes = [ + convert_size_to_var(s, typemap, scope, loc, out) + for s in size_var + ] + tuple_call = ir.Expr.build_tuple(new_sizes, loc) + tuple_assign = ir.Assign(tuple_call, tuple_var, loc) + out.append(tuple_assign) + size_var = tuple_var + size_typ = types.containers.UniTuple(types.intp, ndims) + if hasattr(lhs_typ, "__allocate__"): + return lhs_typ.__allocate__( + typingctx, + typemap, + calltypes, + lhs, + size_var, + dtype, + scope, + loc, + lhs_typ, + size_typ, + out, + **kws, + ) + + # Unused numba's code.. + assert False diff --git a/numba_dpex/core/parfors/parfor_sentinel_replace_pass.py b/numba_dpex/core/parfors/parfor_sentinel_replace_pass.py new file mode 100644 index 0000000000..b9c5310b01 --- /dev/null +++ b/numba_dpex/core/parfors/parfor_sentinel_replace_pass.py @@ -0,0 +1,110 @@ +# SPDX-FileCopyrightText: 2024 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import NamedTuple, Union + +from numba.core import config +from numba.core.compiler_machinery import FunctionPass, register_pass +from numba.core.ir_utils import ( + add_offset_to_labels, + get_name_var_table, + get_unused_var_name, + mk_unique_var, + remove_dels, + rename_labels, + replace_var_names, +) + + +class ParforBodyArguments(NamedTuple): + """ + Arguments containing information to inject parfor code inside kernel. 
+ """ + + loop_body: any + param_dict: any + legal_loop_indices: any + + +def _print_block(block): + for i, inst in enumerate(block.body): + print(" ", i, inst) + + +def _print_body(body_dict): + """Pretty-print a set of IR blocks.""" + for label, block in body_dict.items(): + print("label: ", label) + _print_block(block) + + +@register_pass(mutates_CFG=True, analysis_only=False) +class ParforSentinelReplacePass(FunctionPass): + _name = "sentinel_inject" + + def __init__(self): + FunctionPass.__init__(self) + + def _get_parfor_body_args(self, flags) -> Union[ParforBodyArguments, None]: + if not hasattr(flags, "_parfor_body_args"): + return None + + return flags._parfor_body_args + + def run_pass(self, state): + flags = state["flags"] + + args = self._get_parfor_body_args(flags) + + if args is None: + return True + + # beginning + kernel_ir = state["func_ir"] + loop_body = args.loop_body + + if config.DEBUG_ARRAY_OPT: + print("kernel_ir dump ", type(kernel_ir)) + kernel_ir.dump() + print("loop_body dump ", type(loop_body)) + _print_body(loop_body) + + # Determine the unique names of the scheduling and kernel functions. 
+ loop_body_var_table = get_name_var_table(loop_body) + sentinel_name = get_unused_var_name("__sentinel__", loop_body_var_table) + + # rename all variables in kernel_ir afresh + var_table = get_name_var_table(kernel_ir.blocks) + new_var_dict = {} + reserved_names = ( + [sentinel_name] + + list(args.param_dict.values()) + + args.legal_loop_indices + ) + for name, _ in var_table.items(): + if not (name in reserved_names): + new_var_dict[name] = mk_unique_var(name) + + replace_var_names(kernel_ir.blocks, new_var_dict) + + kernel_stub_last_label = max(kernel_ir.blocks.keys()) + 1 + loop_body = add_offset_to_labels(loop_body, kernel_stub_last_label) + + # new label for splitting sentinel block + new_label = max(loop_body.keys()) + 1 + + from .kernel_builder import update_sentinel # circular + + update_sentinel(kernel_ir, sentinel_name, loop_body, new_label) + + # FIXME: Why rename and remove dels causes the partial_sum array update + # instructions to be removed. + kernel_ir.blocks = rename_labels(kernel_ir.blocks) + remove_dels(kernel_ir.blocks) + + if config.DEBUG_ARRAY_OPT: + print("kernel_ir after remove dead") + kernel_ir.dump() + + return True diff --git a/numba_dpex/core/parfors/reduction_helper.py b/numba_dpex/core/parfors/reduction_helper.py index 552575fbfe..cb9e189977 100644 --- a/numba_dpex/core/parfors/reduction_helper.py +++ b/numba_dpex/core/parfors/reduction_helper.py @@ -275,12 +275,14 @@ def __init__( param_dict = _legalize_names_with_typemap(parfor_params, typemap) ind_dict = _legalize_names_with_typemap(loop_indices, typemap) self._parfor_reddict = parfor_reddict + # output of reduction self._parfor_redvars = parfor_redvars self._redvars_legal_dict = legalize_names(parfor_redvars) # Compute a new list of legal loop index names. 
legal_loop_indices = [ind_dict[v] for v in loop_indices] tmp1 = [] + # output of reduction computed on device self._final_sum_names = [] self._parfor_redvars_to_redarrs = {} for ele1 in reductionHelperList: @@ -397,16 +399,8 @@ def lowerer(self): def work_group_size(self): return self._work_group_size - def copy_final_sum_to_host(self, parfor_kernel): + def copy_final_sum_to_host(self, queue_ref): lowerer = self.lowerer - kl_builder = KernelLaunchIRBuilder( - lowerer.context, lowerer.builder, kernel_dmm - ) - - # Create a local variable storing a pointer to a DPCTLSyclQueueRef - # pointer. - queue_ref = kl_builder.get_queue(exec_queue=parfor_kernel.queue) - builder = lowerer.builder context = lowerer.context @@ -448,5 +442,3 @@ def copy_final_sum_to_host(self, parfor_kernel): event_ref = sycl.dpctl_queue_memcpy(builder, *args) sycl.dpctl_event_wait(builder, event_ref) sycl.dpctl_event_delete(builder, event_ref) - - sycl.dpctl_queue_delete(builder, queue_ref) diff --git a/numba_dpex/core/parfors/reduction_kernel_builder.py b/numba_dpex/core/parfors/reduction_kernel_builder.py index 6180ff5ef1..90002f6321 100644 --- a/numba_dpex/core/parfors/reduction_kernel_builder.py +++ b/numba_dpex/core/parfors/reduction_kernel_builder.py @@ -4,33 +4,29 @@ import warnings -import dpctl from numba.core import types from numba.core.errors import NumbaParallelSafetyWarning from numba.core.ir_utils import ( - add_offset_to_labels, get_name_var_table, get_unused_var_name, legalize_names, - mk_unique_var, - remove_dels, - rename_labels, - replace_var_names, ) from numba.core.typing import signature +from numba_dpex.core.decorators import kernel +from numba_dpex.core.parfors.parfor_sentinel_replace_pass import ( + ParforBodyArguments, +) from numba_dpex.core.parfors.reduction_helper import ReductionKernelVariables -from numba_dpex.core.types import DpctlSyclQueue from numba_dpex.core.types.kernel_api.index_space_ids import NdItemType from numba_dpex.core.types.kernel_api.local_accessor 
import LocalAccessorType - -from .kernel_builder import _print_body # saved for debug -from .kernel_builder import ( - ParforKernel, - _compile_kernel_parfor, - _to_scalar_from_0d, - update_sentinel, +from numba_dpex.core.utils.call_kernel_builder import SPIRVKernelModule +from numba_dpex.kernel_api_impl.spirv.dispatcher import ( + SPIRVKernelDispatcher, + _SPIRVKernelCompileResult, ) + +from .kernel_builder import ParforKernel, _to_scalar_from_0d from .kernel_templates.reduction_template import ( RemainderReduceIntermediateKernelTemplate, TreeReduceIntermediateKernelTemplate, @@ -41,8 +37,6 @@ def create_reduction_main_kernel_for_parfor( loop_ranges, parfor_node, typemap, - flags, - has_aliases, reductionKernelVar: ReductionKernelVariables, parfor_reddict=None, ): @@ -111,7 +105,6 @@ def create_reduction_main_kernel_for_parfor( local_accessors_dict=local_accessors_dict, typemap=typemap, ) - kernel_ir = kernel_template.kernel_ir for i, name in enumerate(reductionKernelVar.parfor_params): try: @@ -120,40 +113,14 @@ def create_reduction_main_kernel_for_parfor( except KeyError: pass - # rename all variables in kernel_ir afresh - var_table = get_name_var_table(kernel_ir.blocks) - new_var_dict = {} - reserved_names = ( - [sentinel_name] - + list(reductionKernelVar.param_dict.values()) - + reductionKernelVar.legal_loop_indices - ) - for name, _ in var_table.items(): - if not (name in reserved_names): - new_var_dict[name] = mk_unique_var(name) - - replace_var_names(kernel_ir.blocks, new_var_dict) - kernel_param_types = parfor_param_types - kernel_stub_last_label = max(kernel_ir.blocks.keys()) + 1 - # Add kernel stub last label to each parfor.loop_body label to prevent - # label conflicts. 
- loop_body = add_offset_to_labels( - reductionKernelVar.loop_body, kernel_stub_last_label + kernel_dispatcher: SPIRVKernelDispatcher = kernel( + kernel_template.py_func, + _parfor_body_args=ParforBodyArguments( + loop_body=reductionKernelVar.loop_body, + param_dict=reductionKernelVar.param_dict, + legal_loop_indices=reductionKernelVar.legal_loop_indices, + ), ) - # new label for splitting sentinel block - new_label = max(loop_body.keys()) + 1 - - update_sentinel(kernel_ir, sentinel_name, loop_body, new_label) - - # FIXME: Why rename and remove dels causes the partial_sum array update - # instructions to be removed. - kernel_ir.blocks = rename_labels(kernel_ir.blocks) - remove_dels(kernel_ir.blocks) - - old_alias = flags.noalias - - if not has_aliases: - flags.noalias = True # The first argument to a range kernel is a kernel_api.NdItem object. The # ``NdItem`` object is used by the kernel_api.spirv backend to generate the @@ -161,25 +128,13 @@ def create_reduction_main_kernel_for_parfor( # available originally in the kernel_param_types, we add it at this point to # make sure the kernel signature matches the actual generated code. ty_item = NdItemType(parfor_dim) - kernel_param_types = (ty_item, *kernel_param_types) + kernel_param_types = (ty_item, *parfor_param_types) kernel_sig = signature(types.none, *kernel_param_types) - # FIXME: A better design is required so that we do not have to create a - # queue every time. 
- ty_queue: DpctlSyclQueue = typemap[ - reductionKernelVar.parfor_params[0] - ].queue - exec_queue = dpctl.get_device_cached_queue(ty_queue.sycl_device) - - sycl_kernel = _compile_kernel_parfor( - exec_queue, - kernel_name, - kernel_ir, - kernel_param_types, - debug=flags.debuginfo, + kcres: _SPIRVKernelCompileResult = kernel_dispatcher.get_compile_result( + types.void(*kernel_param_types) # kernel signature ) - - flags.noalias = old_alias + kernel_module: SPIRVKernelModule = kcres.kernel_device_ir_module parfor_params = ( reductionKernelVar.parfor_params.copy() @@ -187,22 +142,18 @@ def create_reduction_main_kernel_for_parfor( ) return ParforKernel( - name=kernel_name, - kernel=sycl_kernel, signature=kernel_sig, kernel_args=parfor_params, kernel_arg_types=parfor_param_types, - queue=exec_queue, local_accessors=set(local_accessors_dict.values()), work_group_size=reductionKernelVar.work_group_size, + kernel_module=kernel_module, ) def create_reduction_remainder_kernel_for_parfor( parfor_node, typemap, - flags, - has_aliases, reductionKernelVar, parfor_reddict, reductionHelperList, @@ -288,19 +239,6 @@ def create_reduction_remainder_kernel_for_parfor( final_sum_var_name=final_sum_var_legal_name, reductionKernelVar=reductionKernelVar, ) - kernel_ir = kernel_template.kernel_ir - - var_table = get_name_var_table(kernel_ir.blocks) - new_var_dict = {} - reserved_names = ( - [sentinel_name] - + list(reductionKernelVar.param_dict.values()) - + reductionKernelVar.legal_loop_indices - ) - for name, _ in var_table.items(): - if not (name in reserved_names): - new_var_dict[name] = mk_unique_var(name) - replace_var_names(kernel_ir.blocks, new_var_dict) for i, _ in enumerate(reductionKernelVar.parfor_redvars): if reductionHelperList[i].global_size_var is not None: @@ -353,48 +291,27 @@ def create_reduction_remainder_kernel_for_parfor( _to_scalar_from_0d(typemap[final_sum_var_name[i]]) ) - kernel_param_types = reductionKernelVar.param_types - - kernel_stub_last_label = 
max(kernel_ir.blocks.keys()) + 1 - - # Add kernel stub last label to each parfor.loop_body label to prevent - # label conflicts. - loop_body = add_offset_to_labels( - reductionKernelVar.loop_body, kernel_stub_last_label + kernel_dispatcher: SPIRVKernelDispatcher = kernel( + kernel_template.py_func, + _parfor_body_args=ParforBodyArguments( + loop_body=reductionKernelVar.loop_body, + param_dict=reductionKernelVar.param_dict, + legal_loop_indices=reductionKernelVar.legal_loop_indices, + ), ) - # new label for splitting sentinel block - new_label = max(loop_body.keys()) + 1 - - update_sentinel(kernel_ir, sentinel_name, loop_body, new_label) - old_alias = flags.noalias - if not has_aliases: - flags.noalias = True + kernel_param_types = reductionKernelVar.param_types kernel_sig = signature(types.none, *kernel_param_types) - # FIXME: A better design is required so that we do not have to create a - # queue every time. - ty_queue: DpctlSyclQueue = typemap[ - reductionKernelVar.parfor_params[0] - ].queue - exec_queue = dpctl.get_device_cached_queue(ty_queue.sycl_device) - - sycl_kernel = _compile_kernel_parfor( - exec_queue, - kernel_name, - kernel_ir, - kernel_param_types, - debug=flags.debuginfo, + kcres: _SPIRVKernelCompileResult = kernel_dispatcher.get_compile_result( + types.void(*kernel_param_types) # kernel signature ) - - flags.noalias = old_alias + kernel_module: SPIRVKernelModule = kcres.kernel_device_ir_module return ParforKernel( - name=kernel_name, - kernel=sycl_kernel, signature=kernel_sig, kernel_args=reductionKernelVar.parfor_params, kernel_arg_types=reductionKernelVar.func_arg_types, - queue=exec_queue, + kernel_module=kernel_module, ) diff --git a/numba_dpex/core/pipelines/kernel_compiler.py b/numba_dpex/core/pipelines/kernel_compiler.py index 7cb21d3817..c34a85b875 100644 --- a/numba_dpex/core/pipelines/kernel_compiler.py +++ b/numba_dpex/core/pipelines/kernel_compiler.py @@ -8,8 +8,12 @@ IRLegalization, NoPythonSupportedFeatureValidation, ) +from 
numba.core.untyped_passes import InlineClosureLikes from numba_dpex.core.exceptions import UnsupportedCompilationModeError +from numba_dpex.core.parfors.parfor_sentinel_replace_pass import ( + ParforSentinelReplacePass, +) from numba_dpex.core.passes.passes import ( NoPythonBackend, QualNameDisambiguationLowering, @@ -57,6 +61,8 @@ def define_nopython_pipeline(state, name="dpex_kernel_nopython"): pm = PassManager(name) untyped_passes = ndpb.define_untyped_pipeline(state) pm.passes.extend(untyped_passes.passes) + # TODO: create separate parfor kernel pass + pm.add_pass_after(ParforSentinelReplacePass, InlineClosureLikes) typed_passes = ndpb.define_typed_pipeline(state) pm.passes.extend(typed_passes.passes) diff --git a/numba_dpex/core/types/dpnp_ndarray_type.py b/numba_dpex/core/types/dpnp_ndarray_type.py index fb09598243..6d6622202c 100644 --- a/numba_dpex/core/types/dpnp_ndarray_type.py +++ b/numba_dpex/core/types/dpnp_ndarray_type.py @@ -2,14 +2,28 @@ # # SPDX-License-Identifier: Apache-2.0 +from functools import partial + import dpnp from numba.core import ir, types from numba.core.ir_utils import get_np_ufunc_typ, mk_unique_var -from numba.core.pythonapi import NativeValue, PythonAPI, box, unbox from .usm_ndarray_type import USMNdArray +def partialclass(cls, *args, **kwds): + """Creates a factory class of the original class with preset initialization + arguments.""" + cls0 = partial(cls, *args, **kwds) + new_cls = type( + cls.__name__ + "Partial", + (cls,), + {"__new__": lambda cls, *args, **kwds: cls0(*args, **kwds)}, + ) + + return new_cls + + class DpnpNdArray(USMNdArray): """ The Numba type to represent an dpnp.ndarray. The type has the same
""" if method == "__call__": - if not all( - ( - isinstance(inp, DpnpNdArray) - or isinstance(inp, types.abstract.Number) - ) - for inp in inputs - ): + dpnp_type = None + + for inp in inputs: + if isinstance(inp, DpnpNdArray): + dpnp_type = inp + continue + if isinstance(inp, types.abstract.Number): + continue + return NotImplemented - return DpnpNdArray + + assert dpnp_type is not None + + return partialclass( + DpnpNdArray, queue=dpnp_type.queue, usm_type=dpnp_type.usm_type + ) else: return @@ -71,6 +92,8 @@ def __allocate__( lhs_typ, size_typ, out, + # dpex specific argument: + queue_ir_val=None, ): """Generates the Numba typed IR representing the allocation of a new DpnpNdArray using the dpnp.ndarray overload. @@ -94,6 +117,10 @@ def __allocate__( Returns: The IR Value for the allocated array """ + # TODO: it looks like it is being called only for parfor allocations, + # so can we rely on it? We can grab information from input arguments + # from rhs, but doc does not set any restriction on parfor use only. 
+ assert queue_ir_val is not None g_np_var = ir.Var(scope, mk_unique_var("$np_g_var"), loc) if typemap: typemap[g_np_var.name] = types.misc.Module(dpnp) @@ -132,11 +159,13 @@ def __allocate__( usm_typ_var = ir.Var(scope, mk_unique_var("$np_usm_type_var"), loc) # A default device string arg added as a placeholder device_typ_var = ir.Var(scope, mk_unique_var("$np_device_var"), loc) + queue_typ_var = ir.Var(scope, mk_unique_var("$np_queue_var"), loc) if typemap: typemap[layout_var.name] = types.literal(lhs_typ.layout) typemap[usm_typ_var.name] = types.literal(lhs_typ.usm_type) - typemap[device_typ_var.name] = types.literal(lhs_typ.device) + typemap[device_typ_var.name] = types.none + typemap[queue_typ_var.name] = lhs_typ.queue layout_var_assign = ir.Assign( ir.Const(lhs_typ.layout, loc), layout_var, loc @@ -145,16 +174,29 @@ def __allocate__( ir.Const(lhs_typ.usm_type, loc), usm_typ_var, loc ) device_typ_var_assign = ir.Assign( - ir.Const(lhs_typ.device, loc), device_typ_var, loc + ir.Const(None, loc), device_typ_var, loc ) + queue_typ_var_assign = ir.Assign(queue_ir_val, queue_typ_var, loc) out.extend( - [layout_var_assign, usm_typ_var_assign, device_typ_var_assign] + [ + layout_var_assign, + usm_typ_var_assign, + device_typ_var_assign, + queue_typ_var_assign, + ] ) alloc_call = ir.Expr.call( attr_var, - [size_var, typ_var, layout_var, device_typ_var, usm_typ_var], + [ + size_var, + typ_var, + layout_var, + device_typ_var, + usm_typ_var, + queue_typ_var, + ], (), loc, ) @@ -170,6 +212,7 @@ def __allocate__( layout_var, device_typ_var, usm_typ_var, + queue_typ_var, ] ], {}, diff --git a/numba_dpex/kernel_api_impl/spirv/dispatcher.py b/numba_dpex/kernel_api_impl/spirv/dispatcher.py index 99e6de8b6c..92224d9c28 100644 --- a/numba_dpex/kernel_api_impl/spirv/dispatcher.py +++ b/numba_dpex/kernel_api_impl/spirv/dispatcher.py @@ -26,7 +26,7 @@ from numba.core.types import void from numba.core.typing.typeof import Purpose, typeof -from numba_dpex import config +from 
numba_dpex.core import config from numba_dpex.core.descriptor import dpex_kernel_target from numba_dpex.core.exceptions import ( ExecutionQueueInferenceError,