Skip to content

Commit

Permalink
Migrate parfor to SPIRVKernelDispatcher
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Apr 11, 2024
1 parent 1336f76 commit dc00ece
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 21 deletions.
16 changes: 16 additions & 0 deletions numba_dpex/core/parfors/kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,14 @@
from numba.parfors import parfor

from numba_dpex.core import config
from numba_dpex.core.decorators import kernel
from numba_dpex.core.types.kernel_api.index_space_ids import ItemType
from numba_dpex.core.utils.call_kernel_builder import SPIRVKernelModule
from numba_dpex.kernel_api_impl.spirv import spirv_generator
from numba_dpex.kernel_api_impl.spirv.dispatcher import (
SPIRVKernelDispatcher,
_SPIRVKernelCompileResult,
)

from ..descriptor import dpex_kernel_target
from ..types import DpnpNdArray
Expand All @@ -46,6 +52,7 @@ def __init__(
queue: dpctl.SyclQueue,
local_accessors=None,
work_group_size=None,
kernel_module=None,
):
self.name = name
self.kernel = kernel
Expand All @@ -55,6 +62,7 @@ def __init__(
self.queue = queue
self.local_accessors = local_accessors
self.work_group_size = work_group_size
self.kernel_module = kernel_module


def _print_block(block):
Expand Down Expand Up @@ -369,6 +377,8 @@ def create_kernel_for_parfor(
)
kernel_ir = kernel_template.kernel_ir

kernel_dispatcher: SPIRVKernelDispatcher = kernel(kernel_template._py_func)

if config.DEBUG_ARRAY_OPT:
print("kernel_ir dump ", type(kernel_ir))
kernel_ir.dump()
Expand Down Expand Up @@ -469,6 +479,11 @@ def create_kernel_for_parfor(
debug=flags.debuginfo,
)

kcres: _SPIRVKernelCompileResult = kernel_dispatcher.get_compile_result(
types.void(*kernel_param_types) # kernel signature
)
kernel_module: SPIRVKernelModule = kcres.kernel_device_ir_module

flags.noalias = old_alias

if config.DEBUG_ARRAY_OPT:
Expand All @@ -481,6 +496,7 @@ def create_kernel_for_parfor(
kernel_args=parfor_args,
kernel_arg_types=func_arg_types,
queue=exec_queue,
kernel_module=kernel_module,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
self._param_dict = param_dict

self._kernel_txt = self._generate_kernel_stub_as_string()
self._kernel_ir = self._generate_kernel_ir()
self._py_func, self._kernel_ir = self._generate_kernel_ir()

def _generate_kernel_stub_as_string(self):
"""Generates a stub dpex kernel for the parfor as a string.
Expand Down Expand Up @@ -111,7 +111,7 @@ def _generate_kernel_ir(self):
exec(self._kernel_txt, globls, locls)
kernel_fn = locls[self._kernel_name]

return compiler.run_frontend(kernel_fn)
return kernel_fn, compiler.run_frontend(kernel_fn)

@property
def kernel_ir(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
self._typemap = typemap

self._kernel_txt = self._generate_kernel_stub_as_string()
self._kernel_ir = self._generate_kernel_ir()
self._py_func, self._kernel_ir = self._generate_kernel_ir()

def _generate_kernel_stub_as_string(self):
"""Generate reduction main kernel template"""
Expand Down Expand Up @@ -163,7 +163,7 @@ def _generate_kernel_ir(self):
exec(self._kernel_txt, globls, locls)
kernel_fn = locls[self._kernel_name]

return compiler.run_frontend(kernel_fn)
return kernel_fn, compiler.run_frontend(kernel_fn)

@property
def kernel_ir(self):
Expand Down Expand Up @@ -234,7 +234,7 @@ def __init__(
self._reductionKernelVar = reductionKernelVar

self._kernel_txt = self._generate_kernel_stub_as_string()
self._kernel_ir = self._generate_kernel_ir()
self._py_func, self._kernel_ir = self._generate_kernel_ir()

def _generate_kernel_stub_as_string(self):
"""Generate reduction remainder kernel template"""
Expand Down Expand Up @@ -322,7 +322,7 @@ def _generate_kernel_ir(self):
exec(self._kernel_txt, globls, locls)
kernel_fn = locls[self._kernel_name]

return compiler.run_frontend(kernel_fn)
return kernel_fn, compiler.run_frontend(kernel_fn)

@property
def kernel_ir(self):
Expand Down
23 changes: 9 additions & 14 deletions numba_dpex/core/parfors/parfor_lowerer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@
create_reduction_remainder_kernel_for_parfor,
)

# A global list of kernels to keep the objects alive indefinitely.
keep_alive_kernels = []


def _getvar(lowerer, x):
"""Returns the LLVM Value corresponding to a Numba IR variable.
Expand Down Expand Up @@ -154,14 +150,12 @@ def _submit_parfor_kernel(
kernel_fn: ParforKernel,
global_range,
local_range,
debug=False,
):
"""
Adds a call to submit a kernel function into the function body of the
current Numba JIT compiled function.
"""
# Ensure that the Python arguments are kept alive for the duration of
# the kernel execution
keep_alive_kernels.append(kernel_fn.kernel)
kl_builder = KernelLaunchIRBuilder(
lowerer.context, lowerer.builder, kernel_dmm
)
Expand All @@ -188,19 +182,17 @@ def _submit_parfor_kernel(
else:
kernel_args.append(_getvar(lowerer, arg))

kernel_ref_addr = kernel_fn.kernel.addressof_ref()
kernel_ref = lowerer.builder.inttoptr(
lowerer.context.get_constant(types.uintp, kernel_ref_addr),
cgutils.voidptr_t,
)

kl_builder.set_kernel(kernel_ref)
kl_builder.set_queue(queue_ref)
kl_builder.set_range(global_range, local_range)
kl_builder.set_arguments(
kernel_fn.kernel_arg_types, kernel_args=kernel_args
)
kl_builder.set_dependent_events([])
kl_builder.set_kernel_from_spirv(
kernel_fn.kernel_module,
debug=debug,
)

event_ref = kl_builder.submit()

sycl.dpctl_event_wait(lowerer.builder, event_ref)
Expand Down Expand Up @@ -278,6 +270,7 @@ def _reduction_codegen(
parfor_kernel,
global_range,
local_range,
debug=flags.debuginfo,
)

parfor_kernel = create_reduction_remainder_kernel_for_parfor(
Expand All @@ -297,6 +290,7 @@ def _reduction_codegen(
parfor_kernel,
global_range,
local_range,
debug=flags.debuginfo,
)

reductionKernelVar.copy_final_sum_to_host(parfor_kernel)
Expand Down Expand Up @@ -418,6 +412,7 @@ def _lower_parfor_as_kernel(self, lowerer, parfor):
parfor_kernel,
global_range,
local_range,
debug=flags.debuginfo,
)

# TODO: free the kernel at this point
Expand Down
22 changes: 22 additions & 0 deletions numba_dpex/core/parfors/reduction_kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,16 @@
)
from numba.core.typing import signature

from numba_dpex.core.decorators import kernel
from numba_dpex.core.parfors.reduction_helper import ReductionKernelVariables
from numba_dpex.core.types import DpctlSyclQueue
from numba_dpex.core.types.kernel_api.index_space_ids import NdItemType
from numba_dpex.core.types.kernel_api.local_accessor import LocalAccessorType
from numba_dpex.core.utils.call_kernel_builder import SPIRVKernelModule
from numba_dpex.kernel_api_impl.spirv.dispatcher import (
SPIRVKernelDispatcher,
_SPIRVKernelCompileResult,
)

from .kernel_builder import _print_body # saved for debug
from .kernel_builder import (
Expand Down Expand Up @@ -113,6 +119,8 @@ def create_reduction_main_kernel_for_parfor(
)
kernel_ir = kernel_template.kernel_ir

kernel_dispatcher: SPIRVKernelDispatcher = kernel(kernel_template._py_func)

for i, name in enumerate(reductionKernelVar.parfor_params):
try:
tmp = reductionKernelVar.parfor_redvars_to_redarrs[name][0]
Expand Down Expand Up @@ -171,6 +179,11 @@ def create_reduction_main_kernel_for_parfor(
].queue
exec_queue = dpctl.get_device_cached_queue(ty_queue.sycl_device)

kcres: _SPIRVKernelCompileResult = kernel_dispatcher.get_compile_result(
types.void(*kernel_param_types) # kernel signature
)
kernel_module: SPIRVKernelModule = kcres.kernel_device_ir_module

sycl_kernel = _compile_kernel_parfor(
exec_queue,
kernel_name,
Expand All @@ -195,6 +208,7 @@ def create_reduction_main_kernel_for_parfor(
queue=exec_queue,
local_accessors=set(local_accessors_dict.values()),
work_group_size=reductionKernelVar.work_group_size,
kernel_module=kernel_module,
)


Expand Down Expand Up @@ -290,6 +304,8 @@ def create_reduction_remainder_kernel_for_parfor(
)
kernel_ir = kernel_template.kernel_ir

kernel_dispatcher: SPIRVKernelDispatcher = kernel(kernel_template._py_func)

var_table = get_name_var_table(kernel_ir.blocks)
new_var_dict = {}
reserved_names = (
Expand Down Expand Up @@ -388,6 +404,11 @@ def create_reduction_remainder_kernel_for_parfor(
debug=flags.debuginfo,
)

kcres: _SPIRVKernelCompileResult = kernel_dispatcher.get_compile_result(
types.void(*kernel_param_types) # kernel signature
)
kernel_module: SPIRVKernelModule = kcres.kernel_device_ir_module

flags.noalias = old_alias

return ParforKernel(
Expand All @@ -397,4 +418,5 @@ def create_reduction_remainder_kernel_for_parfor(
kernel_args=reductionKernelVar.parfor_params,
kernel_arg_types=reductionKernelVar.func_arg_types,
queue=exec_queue,
kernel_module=kernel_module,
)
2 changes: 1 addition & 1 deletion numba_dpex/kernel_api_impl/spirv/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from numba.core.types import void
from numba.core.typing.typeof import Purpose, typeof

from numba_dpex import config
from numba_dpex.core import config
from numba_dpex.core.descriptor import dpex_kernel_target
from numba_dpex.core.exceptions import (
ExecutionQueueInferenceError,
Expand Down

0 comments on commit dc00ece

Please sign in to comment.