diff --git a/numba_dpex/core/caching.py b/numba_dpex/core/caching.py index 1c46485c6c..807703188e 100644 --- a/numba_dpex/core/caching.py +++ b/numba_dpex/core/caching.py @@ -224,15 +224,18 @@ class LRUCache(AbstractCache): with a dictionary as a lookup table. """ - def __init__(self, capacity=10, pyfunc=None): + def __init__(self, name="cache", capacity=10, pyfunc=None): """Constructor for LRUCache. Args: + name (str, optional): The name of the cache, useful for + debugging. capacity (int, optional): The max capacity of the cache. Defaults to 10. pyfunc (NoneType, optional): A python function to be cached. Defaults to None. """ + self._name = name self._capacity = capacity self._lookup = {} self._evicted = {} @@ -432,8 +435,8 @@ def get(self, key): value = self._cache_file.load(key) if config.DEBUG_CACHE: print( - "[cache]: unpickled an evicted artifact, " - "key: {0:s}.".format(str(key)) + "[{0:s}]: unpickled an evicted artifact, " + "key: {1:s}.".format(self._name, str(key)) ) else: value = self._evicted[key] @@ -442,8 +445,8 @@ def get(self, key): else: if config.DEBUG_CACHE: print( - "[cache] size: {0:d}, loading artifact, key: {1:s}".format( - len(self._lookup), str(key) + "[{0:s}] size: {1:d}, loading artifact, key: {2:s}".format( + self._name, len(self._lookup), str(key) ) ) node = self._lookup[key] @@ -464,8 +467,8 @@ def put(self, key, value): if key in self._lookup: if config.DEBUG_CACHE: print( - "[cache] size: {0:d}, storing artifact, key: {1:s}".format( - len(self._lookup), str(key) + "[{0:s}] size: {1:d}, storing artifact, key: {2:s}".format( + self._name, len(self._lookup), str(key) ) ) self._lookup[key].value = value @@ -480,8 +483,9 @@ def put(self, key, value): if self._cache_file: if config.DEBUG_CACHE: print( - "[cache] size: {0:d}, pickling the LRU item, " - "key: {1:s}, indexed at {2:s}.".format( + "[{0:s}] size: {1:d}, pickling the LRU item, " + "key: {2:s}, indexed at {3:s}.".format( + self._name, len(self._lookup), str(self._head.key), self._cache_file._index_path, @@ -496,8 +500,8 @@ def put(self, key, value): self._lookup.pop(self._head.key) if config.DEBUG_CACHE: print( - "[cache] size: {0:d}, capacity exceeded, evicted".format( - len(self._lookup) + "[{0:s}] size: {1:d}, capacity exceeded, evicted".format( + self._name, len(self._lookup) ), self._head.key, ) @@ -509,7 +513,7 @@ def put(self, key, value): self._append_tail(new_node) if config.DEBUG_CACHE: print( - "[cache] size: {0:d}, saved artifact, key: {1:s}".format( - len(self._lookup), str(key) + "[{0:s}] size: {1:d}, saved artifact, key: {2:s}".format( + self._name, len(self._lookup), str(key) ) ) diff --git a/numba_dpex/core/compiler.py b/numba_dpex/core/compiler.py index 5dd0445ee2..0529382fc1 100644 --- a/numba_dpex/core/compiler.py +++ b/numba_dpex/core/compiler.py @@ -225,7 +225,7 @@ def compile_with_dpex( return_type, target_context, typing_context, - debug=None, + debug=False, is_kernel=True, extra_compile_flags=None, ): @@ -256,7 +256,7 @@ def compile_with_dpex( flags.no_cpython_wrapper = True flags.nrt = False - if debug is not None: + if debug: flags.debuginfo = debug # Run compilation pipeline diff --git a/numba_dpex/core/kernel_interface/dispatcher.py b/numba_dpex/core/kernel_interface/dispatcher.py index 37bc1e8b26..7f8b193b3c 100644 --- a/numba_dpex/core/kernel_interface/dispatcher.py +++ b/numba_dpex/core/kernel_interface/dispatcher.py @@ -89,7 +89,9 @@ def __init__( self._cache = NullCache() elif enable_cache: self._cache = LRUCache( - capacity=config.CACHE_SIZE, pyfunc=self.pyfunc + name="SPIRVKernelCache", + capacity=config.CACHE_SIZE, + pyfunc=self.pyfunc, ) else: self._cache = NullCache() @@ -118,7 +120,9 @@ def __init__( if specialization_sigs: self._has_specializations = True self._specialization_cache = LRUCache( - capacity=config.CACHE_SIZE, pyfunc=self.pyfunc + name="SPIRVKernelSpecializationCache", + capacity=config.CACHE_SIZE, + pyfunc=self.pyfunc, ) for sig in specialization_sigs: self._specialize(sig) diff --git a/numba_dpex/core/kernel_interface/func.py b/numba_dpex/core/kernel_interface/func.py index 42fe2e19c1..36b1393ab4 100644 --- a/numba_dpex/core/kernel_interface/func.py +++ b/numba_dpex/core/kernel_interface/func.py @@ -6,100 +6,242 @@ """ +from numba.core import sigutils, types from numba.core.typing.templates import AbstractTemplate, ConcreteTemplate +from numba_dpex import config +from numba_dpex.core.caching import LRUCache, NullCache, build_key from numba_dpex.core.compiler import compile_with_dpex from numba_dpex.core.descriptor import dpex_target +from numba_dpex.utils import npytypes_array_to_dpex_array -def compile_func(pyfunc, return_type, args, debug=None): - cres = compile_with_dpex( - pyfunc=pyfunc, - pyfunc_name=pyfunc.__name__, - return_type=return_type, - target_context=dpex_target.target_context, - typing_context=dpex_target.typing_context, - args=args, - is_kernel=False, - debug=debug, - ) - func = cres.library.get_function(cres.fndesc.llvm_func_name) - cres.target_context.mark_ocl_device(func) - devfn = DpexFunction(cres) +class DpexFunction(object): + """Class to materialize dpex function + + Helper class to eager compile a specialized `numba_dpex.func` + decorated Python function into a LLVM function with `spir_func` + calling convention. + + A specialized `numba_dpex.func` decorated Python function is one + where the user has specified a signature or a list of signatures + for the function. The function gets compiled as soon as the Python + program is loaded, i.e., eagerly, instead of JIT compilation once + the function is invoked. + """ + + def __init__(self, pyfunc, debug=False): + """Constructor for `DpexFunction` + + Args: + pyfunc (`function`): A python function to be compiled. + debug (`bool`, optional): Debug option for compilation. + Defaults to `False`. + """ + self._pyfunc = pyfunc + self._debug = debug - class _function_template(ConcreteTemplate): - key = devfn - cases = [cres.signature] + def compile(self, arg_types, return_types): + """The actual compilation function. - cres.typing_context.insert_user_function(devfn, _function_template) - libs = [cres.library] - cres.target_context.insert_user_function(devfn, cres.fndesc, libs) - return devfn + Args: + arg_types (`tuple`): Function argument types in a tuple. + return_types (`numba.core.types.scalars.Integer`): + An integer value to specify the return type. + Returns: + `numba.core.compiler.CompileResult`: The compiled result + """ -def compile_func_template(pyfunc, debug=None): - """Compile a DpexFunctionTemplate""" + cres = compile_with_dpex( + pyfunc=self._pyfunc, + pyfunc_name=self._pyfunc.__name__, + return_type=return_types, + target_context=dpex_target.target_context, + typing_context=dpex_target.typing_context, + args=arg_types, + is_kernel=False, + debug=self._debug, + ) + func = cres.library.get_function(cres.fndesc.llvm_func_name) + cres.target_context.mark_ocl_device(func) - dft = DpexFunctionTemplate(pyfunc, debug=debug) + return cres - class _function_template(AbstractTemplate): - key = dft - def generic(self, args, kws): - if kws: - raise AssertionError("No keyword arguments allowed.") - return dft.compile(args) +class DpexFunctionTemplate(object): + """Helper class to compile an unspecialized `numba_dpex.func` + + A helper class to JIT compile an unspecialized `numba_dpex.func` + decorated Python function into an LLVM function with `spir_func` + calling convention. + """ + + def __init__(self, pyfunc, debug=False, enable_cache=True): + """Constructor for `DpexFunctionTemplate` + + Args: + pyfunc (function): A python function to be compiled. + debug (bool, optional): Debug option for compilation. + Defaults to `False`. + enable_cache (bool, optional): Flag to turn on/off caching. + Defaults to `True`. + """ + self._pyfunc = pyfunc + self._debug = debug + self._enable_cache = enable_cache + + if not config.ENABLE_CACHE: + self._cache = NullCache() + elif self._enable_cache: + self._cache = LRUCache( + name="DpexFunctionTemplateCache", + capacity=config.CACHE_SIZE, + pyfunc=self._pyfunc, + ) + else: + self._cache = NullCache() + self._cache_hits = 0 - dpex_target.typing_context.insert_user_function(dft, _function_template) - return dft + @property + def cache(self): + """Cache accessor""" + return self._cache + @property + def cache_hits(self): + """Cache hit count accessor""" + return self._cache_hits -class DpexFunctionTemplate(object): - """Unmaterialized dpex function""" + def compile(self, args): + """Compile a `numba_dpex.func` decorated function - def __init__(self, pyfunc, debug=None): - self.py_func = pyfunc - self.debug = debug - self._compileinfos = {} + Compile a `numba_dpex.func` decorated Python function with the + given argument types. Each signature is compiled once by caching + the compiled function inside this object. - def compile(self, args): - """Compile a dpex.func decorated Python function with the given - argument types. + Args: + args (`tuple`): Function argument types in a tuple. - Each signature is compiled once by caching the compiled function inside - this object. + Returns: + `numba.core.typing.templates.Signature`: Signature of the + compiled result. """ - if args not in self._compileinfos: + + argtypes = [ + dpex_target.typing_context.resolve_argument_type(arg) + for arg in args + ] + key = build_key( + tuple(argtypes), + self._pyfunc, + dpex_target.target_context.codegen(), + ) + cres = self._cache.get(key) + if cres is None: + self._cache_hits += 1 cres = compile_with_dpex( - pyfunc=self.py_func, - pyfunc_name=self.py_func.__name__, + pyfunc=self._pyfunc, + pyfunc_name=self._pyfunc.__name__, return_type=None, target_context=dpex_target.target_context, typing_context=dpex_target.typing_context, args=args, is_kernel=False, - debug=self.debug, + debug=self._debug, ) func = cres.library.get_function(cres.fndesc.llvm_func_name) cres.target_context.mark_ocl_device(func) - first_definition = not self._compileinfos - self._compileinfos[args] = cres libs = [cres.library] - if first_definition: - # First definition - cres.target_context.insert_user_function( - self, cres.fndesc, libs - ) - else: - cres.target_context.add_user_function(self, cres.fndesc, libs) + cres.target_context.insert_user_function(self, cres.fndesc, libs) + self._cache.put(key, cres) + return cres.signature - else: - cres = self._compileinfos[args] - return cres.signature +def compile_func(pyfunc, signature, debug=False): + """Compiles a specialized `numba_dpex.func` + Compiles a specialized `numba_dpex.func` decorated function to native + binary library function and returns the library wrapped inside a + `numba_dpex.core.kernel_interface.func.DpexFunction` object. -class DpexFunction(object): - def __init__(self, cres): - self.cres = cres + Args: + pyfunc (`function`): A python function to be compiled. + signature (`list`): A list of `numba.core.typing.templates.Signature`'s + debug (`bool`, optional): Debug options. Defaults to `False`. + + Returns: + `numba_dpex.core.kernel_interface.func.DpexFunction`: A `DpexFunction` object + """ + + devfn = DpexFunction(pyfunc, debug=debug) + + cres = [] + for sig in signature: + arg_types, return_types = sigutils.normalize_signature(sig) + arg_types = tuple( + [ + npytypes_array_to_dpex_array(ty) + if isinstance(ty, types.npytypes.Array) + else ty + for ty in arg_types + ] + ) + c = devfn.compile(arg_types, return_types) + cres.append(c) + + class _function_template(ConcreteTemplate): + unsafe_casting = False + exact_match_required = True + key = devfn + cases = [c.signature for c in cres] + + cres[0].typing_context.insert_user_function(devfn, _function_template) + + for c in cres: + c.target_context.insert_user_function(devfn, c.fndesc, [c.library]) + + return devfn + + +def compile_func_template(pyfunc, debug=False, enable_cache=True): + """Converts a `numba_dpex.func` function to an `AbstractTemplate` + + Converts a `numba_dpex.func` decorated function to a Numba + `AbstractTemplate` and returns the object wrapped inside a + `numba_dpex.core.kernel_interface.func.DpexFunctionTemplate` + object. + + A `DpexFunctionTemplate` object is an abstract representation for + a native function with `spir_func` calling convention that is to be + JIT compiled once the argument types are resolved. + + Args: + pyfunc (`function`): A python function to be compiled. + debug (`bool`, optional): Debug options. Defaults to `False`. + + Raises: + `AssertionError`: Raised if keyword arguments are supplied in + the inner generic function. + + Returns: + `numba_dpex.core.kernel_interface.func.DpexFunctionTemplate`: + A `DpexFunctionTemplate` object. + """ + + dft = DpexFunctionTemplate(pyfunc, debug=debug, enable_cache=enable_cache) + + class _function_template(AbstractTemplate): + unsafe_casting = False + exact_match_required = True + key = dft + + def generic(self, args, kws): + if kws: + raise AssertionError("No keyword arguments allowed.") + return dft.compile(args) + + dpex_target.typing_context.insert_user_function(dft, _function_template) + + return dft diff --git a/numba_dpex/core/passes/lowerer.py b/numba_dpex/core/passes/lowerer.py index b12cf31f5b..cd98d97128 100644 --- a/numba_dpex/core/passes/lowerer.py +++ b/numba_dpex/core/passes/lowerer.py @@ -55,7 +55,7 @@ def _compile_kernel_parfor( - sycl_queue, kernel_name, func_ir, args, args_with_addrspaces, debug=None + sycl_queue, kernel_name, func_ir, args, args_with_addrspaces, debug=False ): # We only accept numba_dpex.core.types.Array type for arg in args_with_addrspaces: diff --git a/numba_dpex/decorators.py b/numba_dpex/decorators.py index ace5354e97..3c7bd16f20 100644 --- a/numba_dpex/decorators.py +++ b/numba_dpex/decorators.py @@ -4,7 +4,7 @@ import inspect -from numba.core import sigutils, types +from numba.core import sigutils from numba_dpex.core.kernel_interface.dispatcher import ( JitKernel, @@ -14,13 +14,12 @@ compile_func, compile_func_template, ) -from numba_dpex.utils import npytypes_array_to_dpex_array def kernel( func_or_sig=None, access_types=None, - debug=None, + debug=False, enable_cache=True, ): """A decorator to define a kernel function. @@ -55,7 +54,7 @@ def _kernel_dispatcher(pyfunc, sigs=None): elif isinstance(func_or_sig, list) or sigutils.is_signature(func_or_sig): # String signatures are not supported as passing usm_ndarray type as # a string is not possible. Numba's sigutils relies on the type being - # available in Numba's types.__dpct__ and dpex types are not registered + # available in Numba's `types.__dict__` and dpex types are not registered # there yet. if isinstance(func_or_sig, list): for sig in func_or_sig: @@ -94,39 +93,51 @@ def _specialized_kernel_dispatcher(pyfunc): return _kernel_dispatcher(func) -def func(signature=None, debug=None): - if signature is None: - return _func_autojit_wrapper(debug=debug) - elif not sigutils.is_signature(signature): - func = signature - return _func_autojit(func, debug=debug) - else: - return _func_jit(signature, debug=debug) - - -def _func_jit(signature, debug=None): - argtypes, restype = sigutils.normalize_signature(signature) - argtypes = tuple( - [ - npytypes_array_to_dpex_array(ty) - if isinstance(ty, types.npytypes.Array) - else ty - for ty in argtypes - ] - ) - - def _wrapped(pyfunc): - return compile_func(pyfunc, restype, argtypes, debug=debug) +def func(func_or_sig=None, debug=False, enable_cache=True): + """A decorator to define a kernel device function. - return _wrapped + Device functions are functions that can be only invoked from a kernel + and not from a host function. This provides a special decorator + `numba_dpex.func` specifically to implement a device function. + A device function can be invoked from another device function and + unlike a kernel function, a device function can return a value like + normal functions. + """ -def _func_autojit_wrapper(debug=None): - def _func_autojit(pyfunc, debug=debug): - return compile_func_template(pyfunc, debug=debug) + def _func_autojit(pyfunc): + return compile_func_template( + pyfunc, debug=debug, enable_cache=enable_cache + ) - return _func_autojit + if func_or_sig is None: + return _func_autojit + elif isinstance(func_or_sig, str): + raise NotImplementedError( + "Specifying signatures as string is not yet supported by numba-dpex" + ) + elif isinstance(func_or_sig, list) or sigutils.is_signature(func_or_sig): + # String signatures are not supported as passing usm_ndarray type as + # a string is not possible. Numba's sigutils relies on the type being + # available in Numba's types.__dict__ and dpex types are not registered + # there yet. + if isinstance(func_or_sig, list): + for sig in func_or_sig: + if isinstance(sig, str): + raise NotImplementedError( + "Specifying signatures as string is not yet supported " + "by numba-dpex" + ) + # Specialized signatures can either be a single signature or a list. + # In case only one signature is provided convert it to a list + if not isinstance(func_or_sig, list): + func_or_sig = [func_or_sig] + def _wrapped(pyfunc): + return compile_func(pyfunc, func_or_sig, debug=debug) -def _func_autojit(pyfunc, debug=None): - return compile_func_template(pyfunc, debug=debug) + return _wrapped + else: + # no signature + func = func_or_sig + return _func_autojit(func) diff --git a/numba_dpex/examples/kernel/device_func.py b/numba_dpex/examples/kernel/device_func.py index 939a79336b..1c6fe52d39 100644 --- a/numba_dpex/examples/kernel/device_func.py +++ b/numba_dpex/examples/kernel/device_func.py @@ -5,18 +5,33 @@ import dpnp as np import numba_dpex as ndpex +from numba_dpex import float32, int32, int64 # Array size N = 10 -# A device callable function that can be invoked from ``kernel`` and other -# device functions +# A device callable function that can be invoked from +# ``kernel`` and other device functions @ndpex.func def a_device_function(a): return a + 1 +# A device callable function with signature that can be invoked +# from ``kernel`` and other device functions +@ndpex.func(int32(int32)) +def a_device_function_int32(a): + return a + 1 + + +# A device callable function with list signature that can be invoked +# from ``kernel`` and other device functions +@ndpex.func([int32(int32), float32(float32)]) +def a_device_function_int32_float32(a): + return a + 1 + + # A device callable function can call another device function @ndpex.func def another_device_function(a): @@ -30,24 +45,105 @@ def a_kernel_function(a, b): b[i] = another_device_function(a[i]) -# Utility function for printing -def driver(a, b, N): - print("A=", a) - a_kernel_function[N](a, b) - print("B=", b) +# A kernel function that calls the device function +@ndpex.kernel +def a_kernel_function_int32(a, b): + i = ndpex.get_global_id(0) + b[i] = a_device_function_int32(a[i]) -# Main function -def main(): +# A kernel function that calls the device function +@ndpex.kernel +def a_kernel_function_int32_float32(a, b): + i = ndpex.get_global_id(0) + b[i] = a_device_function_int32_float32(a[i]) + + +# test function 1: tests basic +def test1(): a = np.ones(N) b = np.ones(N) print("Using device ...") print(a.device) - driver(a, b, N) + + print("A=", a) + try: + a_kernel_function[N](a, b) + except Exception as err: + print(err) + print("B=", b) + + print("Done...") + + +# test function 2: test device func with signature +def test2(): + a = np.ones(N, dtype=np.int32) + b = np.ones(N, dtype=np.int32) + + print("Using device ...") + print(a.device) + + print("A=", a) + try: + a_kernel_function_int32[N](a, b) + except Exception as err: + print(err) + print("B=", b) + + print("Done...") + + +# test function 3: test device function with list signature +def test3(): + a = np.ones(N, dtype=np.int32) + b = np.ones(N, dtype=np.int32) + + print("Using device ...") + print(a.device) + + print("A=", a) + try: + a_kernel_function_int32_float32[N](a, b) + except Exception as err: + print(err) + print("B=", b) + + # with a different dtype + a = np.ones(N, dtype=np.float32) + b = np.ones(N, dtype=np.float32) + + print("Using device ...") + print(a.device) + + print("A=", a) + try: + a_kernel_function_int32_float32[N](a, b) + except Exception as err: + print(err) + print("B=", b) + + # this will fail, since int64 is not in + # the signature list: [int32(int32), float32(float32)] + a = np.ones(N, dtype=np.int64) + b = np.ones(N, dtype=np.int64) + + print("Using device ...") + print(a.device) + + print("A=", a) + try: + a_kernel_function_int32_float32[N](a, b) + except Exception as err: + print(err) + print("B=", b) print("Done...") +# main function if __name__ == "__main__": - main() + test1() + test2() + test3() diff --git a/numba_dpex/tests/kernel_tests/test_atomic_op.py b/numba_dpex/tests/kernel_tests/test_atomic_op.py index 404365cc71..dcdce68dd8 100644 --- a/numba_dpex/tests/kernel_tests/test_atomic_op.py +++ b/numba_dpex/tests/kernel_tests/test_atomic_op.py @@ -219,7 +219,7 @@ def test_atomic_fp_native( with override_config("NATIVE_FP_ATOMICS", NATIVE_FP_ATOMICS): kernel.compile( args=argtypes, - debug=None, + debug=False, compile_flags=None, target_ctx=dpex_target.target_context, typing_ctx=dpex_target.typing_context, diff --git a/numba_dpex/tests/kernel_tests/test_func_specialization.py b/numba_dpex/tests/kernel_tests/test_func_specialization.py new file mode 100644 index 0000000000..ba0f50df84 --- /dev/null +++ b/numba_dpex/tests/kernel_tests/test_func_specialization.py @@ -0,0 +1,104 @@ +import dpctl.tensor as dpt +import numpy as np +import pytest + +import numba_dpex as dpex +from numba_dpex import float32, int32 + +single_signature = dpex.func(int32(int32)) +list_signature = dpex.func([int32(int32), float32(float32)]) + +# Array size +N = 10 + + +def increment(a): + return a + 1 + + +def test_basic(): + """Basic test with device func""" + + f = dpex.func(increment) + + def kernel_function(a, b): + """Kernel function that applies f() in parallel""" + i = dpex.get_global_id(0) + b[i] = f(a[i]) + + k = dpex.kernel(kernel_function) + + a = dpt.ones(N) + b = dpt.ones(N) + + k[N](a, b) + + assert np.array_equal(dpt.asnumpy(b), dpt.asnumpy(a) + 1) + + +def test_single_signature(): + """Basic test with single signature""" + + fi32 = single_signature(increment) + + def kernel_function(a, b): + """Kernel function that applies fi32() in parallel""" + i = dpex.get_global_id(0) + b[i] = fi32(a[i]) + + k = dpex.kernel(kernel_function) + + # Test with int32, should work + a = dpt.ones(N, dtype=dpt.int32) + b = dpt.ones(N, dtype=dpt.int32) + + k[N](a, b) + + assert np.array_equal(dpt.asnumpy(b), dpt.asnumpy(a) + 1) + + # Test with int64, should fail + a = dpt.ones(N, dtype=dpt.int64) + b = dpt.ones(N, dtype=dpt.int64) + + with pytest.raises(Exception) as e: + k[N](a, b) + + assert " >>> (int64)" in e.value.args[0] + + +def test_list_signature(): + """Basic test with list signature""" + + fi32f32 = list_signature(increment) + + def kernel_function(a, b): + """Kernel function that applies fi32f32() in parallel""" + i = dpex.get_global_id(0) + b[i] = fi32f32(a[i]) + + k = dpex.kernel(kernel_function) + + # Test with int32, should work + a = dpt.ones(N, dtype=dpt.int32) + b = dpt.ones(N, dtype=dpt.int32) + + k[N](a, b) + + assert np.array_equal(dpt.asnumpy(b), dpt.asnumpy(a) + 1) + + # Test with float32, should work + a = dpt.ones(N, dtype=dpt.float32) + b = dpt.ones(N, dtype=dpt.float32) + + k[N](a, b) + + assert np.array_equal(dpt.asnumpy(b), dpt.asnumpy(a) + 1) + + # Test with int64, should fail + a = dpt.ones(N, dtype=dpt.int64) + b = dpt.ones(N, dtype=dpt.int64) + + with pytest.raises(Exception) as e: + k[N](a, b) + + assert " >>> (int64)" in e.value.args[0] diff --git a/numba_dpex/tests/test_debuginfo.py b/numba_dpex/tests/test_debuginfo.py index 4b2eadf15e..2dd8e4abc9 100644 --- a/numba_dpex/tests/test_debuginfo.py +++ b/numba_dpex/tests/test_debuginfo.py @@ -22,7 +22,7 @@ def debug_option(request): return request.param -def get_kernel_ir(fn, sig, debug=None): +def get_kernel_ir(fn, sig, debug=False): kernel = dpex.core.kernel_interface.spirv_kernel.SpirvKernel( fn, fn.__name__ )