diff --git a/numba_dpex/core/descriptor.py b/numba_dpex/core/descriptor.py index dc99bb2d7b..50527f8733 100644 --- a/numba_dpex/core/descriptor.py +++ b/numba_dpex/core/descriptor.py @@ -48,7 +48,7 @@ class DpexTargetOptions(CPUTargetOptions): no_compile = _option_mapping("no_compile") inline_threshold = _option_mapping("inline_threshold") _compilation_mode = _option_mapping("_compilation_mode") - _reduction_kernel_variables = _option_mapping("_reduction_kernel_variables") + _parfor_args = _option_mapping("_parfor_args") def finalize(self, flags, options): super().finalize(flags, options) @@ -64,7 +64,7 @@ def finalize(self, flags, options): _inherit_if_not_set( flags, options, "_compilation_mode", CompilationMode.KERNEL ) - _inherit_if_not_set(flags, options, "_reduction_kernel_variables", None) + _inherit_if_not_set(flags, options, "_parfor_args", None) class DpexKernelTarget(TargetDescriptor): diff --git a/numba_dpex/core/parfors/kernel_builder.py b/numba_dpex/core/parfors/kernel_builder.py index 3d00bcc154..c0766752a6 100644 --- a/numba_dpex/core/parfors/kernel_builder.py +++ b/numba_dpex/core/parfors/kernel_builder.py @@ -27,6 +27,7 @@ from numba_dpex.core import config from numba_dpex.core.decorators import kernel +from numba_dpex.core.parfors.kernel_parfor_pass import ParforArguments from numba_dpex.core.types.kernel_api.index_space_ids import ItemType from numba_dpex.core.utils.call_kernel_builder import SPIRVKernelModule from numba_dpex.kernel_api_impl.spirv import spirv_generator @@ -44,83 +45,21 @@ class ParforKernel: def __init__( self, - name, - kernel, signature, kernel_args, kernel_arg_types, - queue: dpctl.SyclQueue, local_accessors=None, work_group_size=None, kernel_module=None, ): - self.name = name - self.kernel = kernel self.signature = signature self.kernel_args = kernel_args self.kernel_arg_types = kernel_arg_types - self.queue = queue self.local_accessors = local_accessors self.work_group_size = work_group_size self.kernel_module = kernel_module -def _print_block(block): - for i, inst in enumerate(block.body): - print(" ", i, inst) - - -def _print_body(body_dict): - """Pretty-print a set of IR blocks.""" - for label, block in body_dict.items(): - print("label: ", label) - _print_block(block) - - -def _compile_kernel_parfor( - sycl_queue, kernel_name, func_ir, argtypes, debug=False -): - with target_override(dpex_kernel_target.target_context.target_name): - cres = compile_numba_ir_with_dpex( - pyfunc=func_ir, - pyfunc_name=kernel_name, - args=argtypes, - return_type=None, - debug=debug, - is_kernel=True, - typing_context=dpex_kernel_target.typing_context, - target_context=dpex_kernel_target.target_context, - extra_compile_flags=None, - ) - cres.library.inline_threshold = config.INLINE_THRESHOLD - cres.library._optimize_final_module() - func = cres.library.get_function(cres.fndesc.llvm_func_name) - kernel = dpex_kernel_target.target_context.prepare_spir_kernel( - func, cres.signature.args - ) - spirv_module = spirv_generator.llvm_to_spirv( - dpex_kernel_target.target_context, - kernel.module.__str__(), - kernel.module.as_bitcode(), - ) - - dpctl_create_program_from_spirv_flags = [] - if debug or config.DPEX_OPT == 0: - # if debug is ON we need to pass additional flags to igc. - dpctl_create_program_from_spirv_flags = ["-g", "-cl-opt-disable"] - - # create a sycl::kernel_bundle - kernel_bundle = dpctl_prog.create_program_from_spirv( - sycl_queue, - spirv_module, - " ".join(dpctl_create_program_from_spirv_flags), - ) - # create a sycl::kernel - sycl_kernel = kernel_bundle.get_sycl_kernel(kernel.name) - - return sycl_kernel - - def _legalize_names_with_typemap(names, typemap): """Replace illegal characters in Numba IR var names. @@ -197,69 +136,6 @@ def _replace_var_with_array(vars, loop_body, typemap, calltypes): typemap[v] = types.npytypes.Array(el_typ, 1, "C") -def _find_setitems_block(setitems, block, typemap): - for inst in block.body: - if isinstance(inst, ir.StaticSetItem) or isinstance(inst, ir.SetItem): - setitems.add(inst.target.name) - elif isinstance(inst, parfor.Parfor): - _find_setitems_block(setitems, inst.init_block, typemap) - _find_setitems_body(setitems, inst.loop_body, typemap) - - -def _find_setitems_body(setitems, loop_body, typemap): - """ - Find the arrays that are written into (goes into setitems) - """ - for label, block in loop_body.items(): - _find_setitems_block(setitems, block, typemap) - - -def _replace_sentinel_with_parfor_body(kernel_ir, sentinel_name, loop_body): - # new label for splitting sentinel block - new_label = max(loop_body.keys()) + 1 - - # Search all the block in the kernel function for the sentinel assignment. - for label, block in kernel_ir.blocks.items(): - for i, inst in enumerate(block.body): - if ( - isinstance(inst, ir.Assign) - and inst.target.name == sentinel_name - ): - # We found the sentinel assignment. - loc = inst.loc - scope = block.scope - # split block across __sentinel__ - # A new block is allocated for the statements prior to the - # sentinel but the new block maintains the current block label. - prev_block = ir.Block(scope, loc) - prev_block.body = block.body[:i] - - # The current block is used for statements after the sentinel. - block.body = block.body[i + 1 :] # noqa: E203 - # But the current block gets a new label. - body_first_label = min(loop_body.keys()) - - # The previous block jumps to the minimum labelled block of the - # parfor body. - prev_block.append(ir.Jump(body_first_label, loc)) - # Add all the parfor loop body blocks to the kernel function's - # IR. - for loop, b in loop_body.items(): - kernel_ir.blocks[loop] = b - body_last_label = max(loop_body.keys()) - kernel_ir.blocks[new_label] = block - kernel_ir.blocks[label] = prev_block - # Add a jump from the last parfor body block to the block - # containing statements after the sentinel. - kernel_ir.blocks[body_last_label].append( - ir.Jump(new_label, loc) - ) - break - else: - continue - break - - def create_kernel_for_parfor( lowerer, parfor_node, @@ -375,127 +251,28 @@ def create_kernel_for_parfor( loop_ranges=loop_ranges, param_dict=param_dict, ) - kernel_ir = kernel_template.kernel_ir - kernel_dispatcher: SPIRVKernelDispatcher = kernel(kernel_template._py_func) - - if config.DEBUG_ARRAY_OPT: - print("kernel_ir dump ", type(kernel_ir)) - kernel_ir.dump() - print("loop_body dump ", type(loop_body)) - _print_body(loop_body) - - # rename all variables in kernel_ir afresh - var_table = get_name_var_table(kernel_ir.blocks) - new_var_dict = {} - reserved_names = ( - [sentinel_name] + list(param_dict.values()) + legal_loop_indices + kernel_dispatcher: SPIRVKernelDispatcher = kernel( + kernel_template._py_func, + _parfor_args=ParforArguments(loop_body=loop_body), ) - for name, var in var_table.items(): - if not (name in reserved_names): - new_var_dict[name] = mk_unique_var(name) - replace_var_names(kernel_ir.blocks, new_var_dict) - if config.DEBUG_ARRAY_OPT: - print("kernel_ir dump after renaming ") - kernel_ir.dump() - kernel_param_types = param_types - - if config.DEBUG_ARRAY_OPT: - print( - "kernel_param_types = ", - type(kernel_param_types), - "\n", - kernel_param_types, - ) - - kernel_stub_last_label = max(kernel_ir.blocks.keys()) + 1 - - # Add kernel stub last label to each parfor.loop_body label to prevent - # label conflicts. - loop_body = add_offset_to_labels(loop_body, kernel_stub_last_label) - - _replace_sentinel_with_parfor_body(kernel_ir, sentinel_name, loop_body) - - if config.DEBUG_ARRAY_OPT: - print("kernel_ir last dump before renaming") - kernel_ir.dump() - - kernel_ir.blocks = rename_labels(kernel_ir.blocks) - remove_dels(kernel_ir.blocks) - - old_alias = flags.noalias - if not has_aliases: - if config.DEBUG_ARRAY_OPT: - print("No aliases found so adding noalias flag.") - flags.noalias = True - - remove_dead(kernel_ir.blocks, kernel_ir.arg_names, kernel_ir, typemap) - - if config.DEBUG_ARRAY_OPT: - print("kernel_ir after remove dead") - kernel_ir.dump() - - # The first argument to a range kernel is a kernel_api.Item object. The - # ``Item`` object is used by the kernel_api.spirv backend to generate the - # correct SPIR-V indexing instructions. Since, the argument is not something - # available originally in the kernel_param_types, we add it at this point to - # make sure the kernel signature matches the actual generated code. ty_item = ItemType(parfor_dim) - kernel_param_types = (ty_item, *kernel_param_types) + kernel_param_types = (ty_item, *param_types) kernel_sig = signature(types.none, *kernel_param_types) - if config.DEBUG_ARRAY_OPT: - sys.stdout.flush() - - if config.DEBUG_ARRAY_OPT: - print("after DUFunc inline".center(80, "-")) - kernel_ir.dump() - - # The ParforLegalizeCFD pass has already ensured that the LHS and RHS - # arrays are on same device. We can take the queue from the first input - # array and use that to compile the kernel. - - exec_queue: dpctl.SyclQueue = None - - for arg in parfor_args: - obj = typemap[arg] - if isinstance(obj, DpnpNdArray): - filter_string = obj.queue.sycl_device - # FIXME: A better design is required so that we do not have to - # create a queue every time. - exec_queue = dpctl.get_device_cached_queue(filter_string) - - if not exec_queue: - raise AssertionError( - "No execution found for parfor. No way to compile the kernel!" - ) - - sycl_kernel = _compile_kernel_parfor( - exec_queue, - kernel_name, - kernel_ir, - kernel_param_types, - debug=flags.debuginfo, - ) - kcres: _SPIRVKernelCompileResult = kernel_dispatcher.get_compile_result( types.void(*kernel_param_types) # kernel signature ) kernel_module: SPIRVKernelModule = kcres.kernel_device_ir_module - flags.noalias = old_alias - if config.DEBUG_ARRAY_OPT: print("kernel_sig = ", kernel_sig) return ParforKernel( - name=kernel_name, - kernel=sycl_kernel, signature=kernel_sig, kernel_args=parfor_args, kernel_arg_types=func_arg_types, - queue=exec_queue, kernel_module=kernel_module, ) diff --git a/numba_dpex/core/parfors/kernel_parfor_pass.py b/numba_dpex/core/parfors/kernel_parfor_pass.py index 7f1eb358fc..6bc8ec1f09 100644 --- a/numba_dpex/core/parfors/kernel_parfor_pass.py +++ b/numba_dpex/core/parfors/kernel_parfor_pass.py @@ -1,4 +1,6 @@ -from numba.core import config, postproc +from typing import NamedTuple, Union + +from numba.core import config from numba.core.compiler_machinery import FunctionPass, register_pass from numba.core.ir_utils import ( add_offset_to_labels, @@ -11,6 +13,23 @@ from numba_dpex.core.parfors.reduction_helper import ReductionKernelVariables +class ParforArguments(NamedTuple): + loop_body: any = None + reduction_vars: any = None + + +def _print_block(block): + for i, inst in enumerate(block.body): + print(" ", i, inst) + + +def _print_body(body_dict): + """Pretty-print a set of IR blocks.""" + for label, block in body_dict.items(): + print("label: ", label) + _print_block(block) + + @register_pass(mutates_CFG=True, analysis_only=False) class SentinelInjectPass(FunctionPass): _name = "sentinel_inject" @@ -18,26 +37,15 @@ class SentinelInjectPass(FunctionPass): def __init__(self): FunctionPass.__init__(self) - def run_pass(self, state): - flags = state["flags"] + def _get_parfor_args(self, flags) -> Union[ParforArguments, None]: + if not hasattr(flags, "_parfor_args"): + return None - if not hasattr(flags, "_reduction_kernel_variables"): - return True - - reductionKernelVar: ReductionKernelVariables = ( - flags._reduction_kernel_variables - ) - - if reductionKernelVar is None: - return True - - # Determine the unique names of the scheduling and kernel functions. - loop_body_var_table = get_name_var_table(reductionKernelVar.loop_body) - sentinel_name = get_unused_var_name("__sentinel__", loop_body_var_table) - - # beginning - kernel_ir = state["func_ir"] + return flags._parfor_args + def _apply_reduction_vars( + self, kernel_ir, reductionKernelVar, sentinel_name + ): for i, name in enumerate(reductionKernelVar.parfor_params): try: tmp = reductionKernelVar.parfor_redvars_to_redarrs[name][0] @@ -58,12 +66,39 @@ def run_pass(self, state): new_var_dict[name] = mk_unique_var(name) replace_var_names(kernel_ir.blocks, new_var_dict) + + def run_pass(self, state): + flags = state["flags"] + + args = self._get_parfor_args(flags) + + if args is None: + return True + + # beginning + kernel_ir = state["func_ir"] + loop_body = args.loop_body + if loop_body is None: + loop_body = args.reduction_vars.loop_body + + if config.DEBUG_ARRAY_OPT: + print("kernel_ir dump ", type(kernel_ir)) + kernel_ir.dump() + print("loop_body dump ", type(loop_body)) + _print_body(loop_body) + + # Determine the unique names of the scheduling and kernel functions. + loop_body_var_table = get_name_var_table(loop_body) + sentinel_name = get_unused_var_name("__sentinel__", loop_body_var_table) + + if args.reduction_vars is not None: + self._apply_reduction_vars( + kernel_ir, args.reduction_vars, sentinel_name + ) + kernel_stub_last_label = max(kernel_ir.blocks.keys()) + 1 - # Add kernel stub last label to each parfor.loop_body label to prevent - # label conflicts. - loop_body = add_offset_to_labels( - reductionKernelVar.loop_body, kernel_stub_last_label - ) + loop_body = add_offset_to_labels(loop_body, kernel_stub_last_label) + # new label for splitting sentinel block new_label = max(loop_body.keys()) + 1 @@ -75,7 +110,11 @@ def run_pass(self, state): # FIXME: Why rename and remove dels causes the partial_sum array update # instructions to be removed. # TODO: do we need it now? - # kernel_ir.blocks = rename_labels(kernel_ir.blocks) - # remove_dels(kernel_ir.blocks) + # kernel_ir.blocks = rename_labels(kernel_ir.blocks) #noqa: E800 + # remove_dels(kernel_ir.blocks) #noqa: E800 + + if config.DEBUG_ARRAY_OPT: + print("kernel_ir after remove dead") + kernel_ir.dump() return True diff --git a/numba_dpex/core/parfors/kernel_templates/range_kernel_template.py b/numba_dpex/core/parfors/kernel_templates/range_kernel_template.py index 86bce5dc14..0277f0801e 100644 --- a/numba_dpex/core/parfors/kernel_templates/range_kernel_template.py +++ b/numba_dpex/core/parfors/kernel_templates/range_kernel_template.py @@ -5,7 +5,6 @@ import sys import dpnp -from numba.core import compiler import numba_dpex as dpex @@ -111,7 +110,7 @@ def _generate_kernel_ir(self): exec(self._kernel_txt, globls, locls) kernel_fn = locls[self._kernel_name] - return kernel_fn, compiler.run_frontend(kernel_fn) + return kernel_fn, None @property def kernel_ir(self): diff --git a/numba_dpex/core/parfors/kernel_templates/reduction_template.py b/numba_dpex/core/parfors/kernel_templates/reduction_template.py index da8cf73775..52b75c8603 100644 --- a/numba_dpex/core/parfors/kernel_templates/reduction_template.py +++ b/numba_dpex/core/parfors/kernel_templates/reduction_template.py @@ -6,7 +6,6 @@ import sys import dpnp -from numba.core import compiler import numba_dpex.kernel_api as kapi @@ -163,7 +162,7 @@ def _generate_kernel_ir(self): exec(self._kernel_txt, globls, locls) kernel_fn = locls[self._kernel_name] - return kernel_fn, compiler.run_frontend(kernel_fn) + return kernel_fn, None @property def kernel_ir(self): @@ -322,7 +321,7 @@ def _generate_kernel_ir(self): exec(self._kernel_txt, globls, locls) kernel_fn = locls[self._kernel_name] - return kernel_fn, compiler.run_frontend(kernel_fn) + return kernel_fn, None @property def kernel_ir(self): diff --git a/numba_dpex/core/parfors/parfor_lowerer.py b/numba_dpex/core/parfors/parfor_lowerer.py index 9ea3e572d6..ddd073d08c 100644 --- a/numba_dpex/core/parfors/parfor_lowerer.py +++ b/numba_dpex/core/parfors/parfor_lowerer.py @@ -197,6 +197,8 @@ def _submit_parfor_kernel( sycl.dpctl_event_wait(lowerer.builder, event_ref) sycl.dpctl_event_delete(lowerer.builder, event_ref) + return kl_builder.arguments.sycl_queue_ref + def _reduction_codegen( self, parfor, @@ -283,7 +285,8 @@ def _reduction_codegen( global_range, local_range = self._remainder_ranges(lowerer) - self._submit_parfor_kernel( + # TODO: find better way to pass queue + queue_ref = self._submit_parfor_kernel( lowerer, parfor_kernel, global_range, @@ -291,7 +294,7 @@ def _reduction_codegen( debug=flags.debuginfo, ) - reductionKernelVar.copy_final_sum_to_host(parfor_kernel) + reductionKernelVar.copy_final_sum_to_host(queue_ref) def _lower_parfor_as_kernel(self, lowerer, parfor): """Lowers a parfor node created by the dpjit compiler to a diff --git a/numba_dpex/core/parfors/parfor_pass.py b/numba_dpex/core/parfors/parfor_pass.py index c2149bf0cf..f30f1a5b14 100644 --- a/numba_dpex/core/parfors/parfor_pass.py +++ b/numba_dpex/core/parfors/parfor_pass.py @@ -17,8 +17,8 @@ from numba.core import config, errors, ir, types, typing from numba.core.compiler_machinery import register_pass from numba.core.ir_utils import ( + convert_size_to_var, dprint_func_ir, - mk_alloc, mk_unique_var, next_label, ) @@ -43,6 +43,7 @@ ) from numba.stencils.stencilparfor import StencilPass +from numba_dpex.core.types.dpnp_ndarray_type import DpnpNdArray from numba_dpex.core.typing import dpnpdecl @@ -58,6 +59,30 @@ class ConvertDPNPPass(ConvertNumpyPass): def __init__(self, pass_states): super().__init__(pass_states) + def _get_queue(self, avail_ir_vars: list[ir.Var]): + """ + Extracts queue from the input arguments of the array operation. + """ + pass_states = self.pass_states + typemap: map[str, any] = pass_states.typemap + return_type: DpnpNdArray = pass_states.return_type + + var_with_queue = None + + for var in avail_ir_vars: + _type = typemap[var.name] + if not isinstance(_type, DpnpNdArray): + continue + if return_type.queue != _type.queue: + continue + + var_with_queue = var + break + + assert var_with_queue is not None + + return ir.Expr.getattr(var_with_queue, "sycl_queue", var_with_queue.loc) + def _arrayexpr_to_parfor(self, equiv_set, lhs, arrayexpr, avail_vars): """generate parfor from arrayexpr node, which is essentially a map with recursive tree. @@ -77,6 +102,9 @@ def _arrayexpr_to_parfor(self, equiv_set, lhs, arrayexpr, avail_vars): pass_states.typemap, size_vars, scope, loc ) + # Expr is a tuple + ir_queue = self._get_queue(expr[1]) + # generate init block and body init_block = ir.Block(scope, loc) init_block.body = mk_alloc( @@ -89,6 +117,7 @@ def _arrayexpr_to_parfor(self, equiv_set, lhs, arrayexpr, avail_vars): scope, loc, pass_states.typemap[lhs.name], + queue_ir_val=ir_queue, ) body_label = next_label() body_block = ir.Block(scope, loc) @@ -469,3 +498,66 @@ def _arrayexpr_tree_to_ir( typemap.pop(expr_out_var.name, None) typemap[expr_out_var.name] = el_typ return out_ir + + +def mk_alloc( + typingctx, + typemap, + calltypes, + lhs, + size_var, + dtype, + scope, + loc, + lhs_typ, + **kws, +): + """generate an array allocation with np.empty() and return list of nodes. + size_var can be an int variable or tuple of int variables. + lhs_typ is the type of the array being allocated. + + Taken from numba, added kws argument to pass it to __allocate__ + """ + out = [] + ndims = 1 + size_typ = types.intp + if isinstance(size_var, tuple): + if len(size_var) == 1: + size_var = size_var[0] + size_var = convert_size_to_var(size_var, typemap, scope, loc, out) + else: + # tuple_var = build_tuple([size_var...]) + ndims = len(size_var) + tuple_var = ir.Var(scope, mk_unique_var("$tuple_var"), loc) + if typemap: + typemap[tuple_var.name] = types.containers.UniTuple( + types.intp, ndims + ) + # constant sizes need to be assigned to vars + new_sizes = [ + convert_size_to_var(s, typemap, scope, loc, out) + for s in size_var + ] + tuple_call = ir.Expr.build_tuple(new_sizes, loc) + tuple_assign = ir.Assign(tuple_call, tuple_var, loc) + out.append(tuple_assign) + size_var = tuple_var + size_typ = types.containers.UniTuple(types.intp, ndims) + if hasattr(lhs_typ, "__allocate__"): + return lhs_typ.__allocate__( + typingctx, + typemap, + calltypes, + lhs, + size_var, + dtype, + scope, + loc, + lhs_typ, + size_typ, + out, + **kws, + ) + + # Unused numba's code.. + assert False diff --git a/numba_dpex/core/parfors/reduction_helper.py b/numba_dpex/core/parfors/reduction_helper.py index 552575fbfe..cb9e189977 100644 --- a/numba_dpex/core/parfors/reduction_helper.py +++ b/numba_dpex/core/parfors/reduction_helper.py @@ -275,12 +275,14 @@ def __init__( param_dict = _legalize_names_with_typemap(parfor_params, typemap) ind_dict = _legalize_names_with_typemap(loop_indices, typemap) self._parfor_reddict = parfor_reddict + # output of reduction self._parfor_redvars = parfor_redvars self._redvars_legal_dict = legalize_names(parfor_redvars) # Compute a new list of legal loop index names. legal_loop_indices = [ind_dict[v] for v in loop_indices] tmp1 = [] + # output of reduction computed on device self._final_sum_names = [] self._parfor_redvars_to_redarrs = {} for ele1 in reductionHelperList: @@ -397,16 +399,8 @@ def lowerer(self): def work_group_size(self): return self._work_group_size - def copy_final_sum_to_host(self, parfor_kernel): + def copy_final_sum_to_host(self, queue_ref): lowerer = self.lowerer - kl_builder = KernelLaunchIRBuilder( - lowerer.context, lowerer.builder, kernel_dmm - ) - - # Create a local variable storing a pointer to a DPCTLSyclQueueRef - # pointer. - queue_ref = kl_builder.get_queue(exec_queue=parfor_kernel.queue) - builder = lowerer.builder context = lowerer.context @@ -448,5 +442,3 @@ def copy_final_sum_to_host(self, parfor_kernel): event_ref = sycl.dpctl_queue_memcpy(builder, *args) sycl.dpctl_event_wait(builder, event_ref) sycl.dpctl_event_delete(builder, event_ref) - - sycl.dpctl_queue_delete(builder, queue_ref) diff --git a/numba_dpex/core/parfors/reduction_kernel_builder.py b/numba_dpex/core/parfors/reduction_kernel_builder.py index 8127ab8ee4..f0ed70ce42 100644 --- a/numba_dpex/core/parfors/reduction_kernel_builder.py +++ b/numba_dpex/core/parfors/reduction_kernel_builder.py @@ -4,24 +4,18 @@ import warnings -import dpctl from numba.core import types from numba.core.errors import NumbaParallelSafetyWarning from numba.core.ir_utils import ( - add_offset_to_labels, get_name_var_table, get_unused_var_name, legalize_names, - mk_unique_var, - remove_dels, - rename_labels, - replace_var_names, ) from numba.core.typing import signature from numba_dpex.core.decorators import kernel +from numba_dpex.core.parfors.kernel_parfor_pass import ParforArguments from numba_dpex.core.parfors.reduction_helper import ReductionKernelVariables -from numba_dpex.core.types import DpctlSyclQueue from numba_dpex.core.types.kernel_api.index_space_ids import NdItemType from numba_dpex.core.types.kernel_api.local_accessor import LocalAccessorType from numba_dpex.core.utils.call_kernel_builder import SPIRVKernelModule @@ -30,13 +24,7 @@ _SPIRVKernelCompileResult, ) -from .kernel_builder import _print_body # saved for debug -from .kernel_builder import ( - ParforKernel, - _compile_kernel_parfor, - _to_scalar_from_0d, - update_sentinel, -) +from .kernel_builder import ParforKernel, _to_scalar_from_0d from .kernel_templates.reduction_template import ( RemainderReduceIntermediateKernelTemplate, TreeReduceIntermediateKernelTemplate, @@ -117,98 +105,35 @@ def create_reduction_main_kernel_for_parfor( local_accessors_dict=local_accessors_dict, typemap=typemap, ) - kernel_ir = kernel_template.kernel_ir kernel_dispatcher: SPIRVKernelDispatcher = kernel( kernel_template._py_func, - _reduction_kernel_variables=reductionKernelVar, + _parfor_args=ParforArguments(reduction_vars=reductionKernelVar), ) - for i, name in enumerate(reductionKernelVar.parfor_params): - try: - tmp = reductionKernelVar.parfor_redvars_to_redarrs[name][0] - reductionKernelVar.parfor_params[i] = tmp - except KeyError: - pass - - # rename all variables in kernel_ir afresh - var_table = get_name_var_table(kernel_ir.blocks) - new_var_dict = {} - reserved_names = ( - [sentinel_name] - + list(reductionKernelVar.param_dict.values()) - + reductionKernelVar.legal_loop_indices - ) - for name, _ in var_table.items(): - if not (name in reserved_names): - new_var_dict[name] = mk_unique_var(name) - - replace_var_names(kernel_ir.blocks, new_var_dict) - kernel_param_types = parfor_param_types - kernel_stub_last_label = max(kernel_ir.blocks.keys()) + 1 - # Add kernel stub last label to each parfor.loop_body label to prevent - # label conflicts. - loop_body = add_offset_to_labels( - reductionKernelVar.loop_body, kernel_stub_last_label - ) - # new label for splitting sentinel block - new_label = max(loop_body.keys()) + 1 - - update_sentinel(kernel_ir, sentinel_name, loop_body, new_label) - - # FIXME: Why rename and remove dels causes the partial_sum array update - # instructions to be removed. - kernel_ir.blocks = rename_labels(kernel_ir.blocks) - remove_dels(kernel_ir.blocks) - - old_alias = flags.noalias - - if not has_aliases: - flags.noalias = True - # The first argument to a range kernel is a kernel_api.NdItem object. The # ``NdItem`` object is used by the kernel_api.spirv backend to generate the # correct SPIR-V indexing instructions. Since, the argument is not something # available originally in the kernel_param_types, we add it at this point to # make sure the kernel signature matches the actual generated code. ty_item = NdItemType(parfor_dim) - kernel_param_types = (ty_item, *kernel_param_types) + kernel_param_types = (ty_item, *parfor_param_types) kernel_sig = signature(types.none, *kernel_param_types) - # FIXME: A better design is required so that we do not have to create a - # queue every time. - ty_queue: DpctlSyclQueue = typemap[ - reductionKernelVar.parfor_params[0] - ].queue - exec_queue = dpctl.get_device_cached_queue(ty_queue.sycl_device) - kcres: _SPIRVKernelCompileResult = kernel_dispatcher.get_compile_result( types.void(*kernel_param_types) # kernel signature ) kernel_module: SPIRVKernelModule = kcres.kernel_device_ir_module - sycl_kernel = _compile_kernel_parfor( - exec_queue, - kernel_name, - kernel_ir, - kernel_param_types, - debug=flags.debuginfo, - ) - - flags.noalias = old_alias - parfor_params = ( reductionKernelVar.parfor_params.copy() + parfor_params[len(reductionKernelVar.parfor_params) :] # noqa: $203 ) return ParforKernel( - name=kernel_name, - kernel=sycl_kernel, signature=kernel_sig, kernel_args=parfor_params, kernel_arg_types=parfor_param_types, - queue=exec_queue, local_accessors=set(local_accessors_dict.values()), work_group_size=reductionKernelVar.work_group_size, kernel_module=kernel_module, @@ -305,24 +230,6 @@ def create_reduction_remainder_kernel_for_parfor( final_sum_var_name=final_sum_var_legal_name, reductionKernelVar=reductionKernelVar, ) - kernel_ir = kernel_template.kernel_ir - - kernel_dispatcher: SPIRVKernelDispatcher = kernel( - kernel_template._py_func, - _reduction_kernel_variables=reductionKernelVar, - ) - - var_table = get_name_var_table(kernel_ir.blocks) - new_var_dict = {} - reserved_names = ( - [sentinel_name] - + list(reductionKernelVar.param_dict.values()) - + reductionKernelVar.legal_loop_indices - ) - for name, _ in var_table.items(): - if not (name in reserved_names): - new_var_dict[name] = mk_unique_var(name) - replace_var_names(kernel_ir.blocks, new_var_dict) for i, _ in enumerate(reductionKernelVar.parfor_redvars): if reductionHelperList[i].global_size_var is not None: @@ -375,54 +282,23 @@ def create_reduction_remainder_kernel_for_parfor( _to_scalar_from_0d(typemap[final_sum_var_name[i]]) ) - kernel_param_types = reductionKernelVar.param_types - - kernel_stub_last_label = max(kernel_ir.blocks.keys()) + 1 - - # Add kernel stub last label to each parfor.loop_body label to prevent - # label conflicts. - loop_body = add_offset_to_labels( - reductionKernelVar.loop_body, kernel_stub_last_label + kernel_dispatcher: SPIRVKernelDispatcher = kernel( + kernel_template._py_func, + _parfor_args=ParforArguments(reduction_vars=reductionKernelVar), ) - # new label for splitting sentinel block - new_label = max(loop_body.keys()) + 1 - - update_sentinel(kernel_ir, sentinel_name, loop_body, new_label) - old_alias = flags.noalias - if not has_aliases: - flags.noalias = True + kernel_param_types = reductionKernelVar.param_types kernel_sig = signature(types.none, *kernel_param_types) - # FIXME: A better design is required so that we do not have to create a - # queue every time. - ty_queue: DpctlSyclQueue = typemap[ - reductionKernelVar.parfor_params[0] - ].queue - exec_queue = dpctl.get_device_cached_queue(ty_queue.sycl_device) - - sycl_kernel = _compile_kernel_parfor( - exec_queue, - kernel_name, - kernel_ir, - kernel_param_types, - debug=flags.debuginfo, - ) - kcres: _SPIRVKernelCompileResult = kernel_dispatcher.get_compile_result( types.void(*kernel_param_types) # kernel signature ) kernel_module: SPIRVKernelModule = kcres.kernel_device_ir_module - flags.noalias = old_alias - return ParforKernel( - name=kernel_name, - kernel=sycl_kernel, signature=kernel_sig, kernel_args=reductionKernelVar.parfor_params, kernel_arg_types=reductionKernelVar.func_arg_types, - queue=exec_queue, kernel_module=kernel_module, ) diff --git a/numba_dpex/core/types/dpnp_ndarray_type.py b/numba_dpex/core/types/dpnp_ndarray_type.py index fb09598243..93d39576bd 100644 --- a/numba_dpex/core/types/dpnp_ndarray_type.py +++ b/numba_dpex/core/types/dpnp_ndarray_type.py @@ -2,14 +2,28 @@ # # SPDX-License-Identifier: Apache-2.0 +from functools import partial + import dpnp from numba.core import ir, types from numba.core.ir_utils import get_np_ufunc_typ, mk_unique_var -from numba.core.pythonapi import NativeValue, PythonAPI, box, unbox from .usm_ndarray_type import USMNdArray +def partialclass(cls, *args, **kwds): + """Creates fabric class of the original class with preset initialization + arguments.""" + cls0 = partial(cls, *args, **kwds) + new_cls = type( + cls.__name__ + "Partial", + (cls,), + {"__new__": lambda cls, *args, **kwds: cls0(*args, **kwds)}, + ) + + return new_cls + + class DpnpNdArray(USMNdArray): """ The Numba type to represent an dpnp.ndarray. The type has the same @@ -48,7 +62,10 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): for inp in inputs ): return NotImplemented - return DpnpNdArray + + return partialclass( + DpnpNdArray, queue=inputs[0].queue, usm_type=inputs[0].usm_type + ) else: return @@ -71,6 +88,8 @@ def __allocate__( lhs_typ, size_typ, out, + # dpex specific argument: + queue_ir_val=None, ): """Generates the Numba typed IR representing the allocation of a new DpnpNdArray using the dpnp.ndarray overload. @@ -94,6 +113,10 @@ def __allocate__( Returns: The IR Value for the allocated array """ + # TODO: it looks like it is being called only for parfor allocations, + # so can we rely on it? We can grab information from input arguments + # from rhs, but doc does not set any restriction on parfor use only. + assert queue_ir_val is not None g_np_var = ir.Var(scope, mk_unique_var("$np_g_var"), loc) if typemap: typemap[g_np_var.name] = types.misc.Module(dpnp) @@ -132,11 +155,13 @@ def __allocate__( usm_typ_var = ir.Var(scope, mk_unique_var("$np_usm_type_var"), loc) # A default device string arg added as a placeholder device_typ_var = ir.Var(scope, mk_unique_var("$np_device_var"), loc) + queue_typ_var = ir.Var(scope, mk_unique_var("$np_queue_var"), loc) if typemap: typemap[layout_var.name] = types.literal(lhs_typ.layout) typemap[usm_typ_var.name] = types.literal(lhs_typ.usm_type) - typemap[device_typ_var.name] = types.literal(lhs_typ.device) + typemap[device_typ_var.name] = types.none + typemap[queue_typ_var.name] = lhs_typ.queue layout_var_assign = ir.Assign( ir.Const(lhs_typ.layout, loc), layout_var, loc @@ -145,16 +170,29 @@ def __allocate__( ir.Const(lhs_typ.usm_type, loc), usm_typ_var, loc ) device_typ_var_assign = ir.Assign( - ir.Const(lhs_typ.device, loc), device_typ_var, loc + ir.Const(None, loc), device_typ_var, loc ) + queue_typ_var_assign = ir.Assign(queue_ir_val, queue_typ_var, loc) out.extend( - [layout_var_assign, usm_typ_var_assign, device_typ_var_assign] + [ + layout_var_assign, + usm_typ_var_assign, + device_typ_var_assign, + queue_typ_var_assign, + ] ) alloc_call = ir.Expr.call( attr_var, - [size_var, typ_var, layout_var, device_typ_var, usm_typ_var], + [ + size_var, + typ_var, + layout_var, + device_typ_var, + usm_typ_var, + queue_typ_var, + ], (), loc, ) @@ -170,6 +208,7 @@ def __allocate__( layout_var, device_typ_var, usm_typ_var, + queue_typ_var, ] ], {},