diff --git a/numba_dpex/kernel_api_impl/spirv/dispatcher.py b/numba_dpex/kernel_api_impl/spirv/dispatcher.py index 9aac39edb4..fca4c90fe2 100644 --- a/numba_dpex/kernel_api_impl/spirv/dispatcher.py +++ b/numba_dpex/kernel_api_impl/spirv/dispatcher.py @@ -415,13 +415,6 @@ def cb_llvm(dur): except ExecutionQueueInferenceError as eqie: raise eqie - # A function being compiled in the KERNEL compilation mode - # cannot have a non-void return value - if return_type and return_type != void: - raise KernelHasReturnValueError( - kernel_name=None, return_type=return_type, sig=sig - ) - # Don't recompile if signature already exists existing = self.overloads.get(tuple(args)) if existing is not None: @@ -444,6 +437,16 @@ def cb_llvm(dur): kcres: _SPIRVKernelCompileResult = compiler.compile( args, return_type ) + if ( + self.targetoptions["_compilation_mode"] + == CompilationMode.KERNEL + and kcres.signature.return_type is not None + and kcres.signature.return_type != types.void + ): + raise KernelHasReturnValueError( + kernel_name=self.py_func.__name__, + return_type=kcres.signature.return_type, + ) except errors.ForceLiteralArg as err: def folded(args, kws): diff --git a/numba_dpex/tests/experimental/test_kernel_has_return_value_error.py b/numba_dpex/tests/experimental/test_kernel_has_return_value_error.py new file mode 100644 index 0000000000..739c82ec4a --- /dev/null +++ b/numba_dpex/tests/experimental/test_kernel_has_return_value_error.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: 2020 - 2024 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +import dpnp +import pytest +from numba.core.errors import TypingError + +import numba_dpex.experimental as dpex +from numba_dpex import int32, usm_ndarray +from numba_dpex.core.exceptions import KernelHasReturnValueError +from numba_dpex.core.types.kernel_api.index_space_ids import ItemType + +i32arrty = usm_ndarray(ndim=1, dtype=int32, layout="C") +item_ty = ItemType(ndim=1) + + +def f(item, a): + return a + + +list_of_sig = [ + None, + (i32arrty(item_ty, i32arrty)), +] + + +@pytest.fixture(params=list_of_sig) +def sig(request): + return request.param + + +def test_return(sig): + a = dpnp.arange(1024, dtype=dpnp.int32) + + with pytest.raises((TypingError, KernelHasReturnValueError)) as excinfo: + kernel_fn = dpex.kernel(sig)(f) + dpex.call_kernel(kernel_fn, dpex.Range(a.size), a) + + if isinstance(excinfo.type, TypingError): + assert "KernelHasReturnValueError" in excinfo.value.args[0]