Skip to content

Commit

Permalink
Fix KernelHasReturnValueError inside KernelDispatcher.
Browse files Browse the repository at this point in the history
  • Loading branch information
Diptorup Deb committed Mar 20, 2024
1 parent 68b1f39 commit ad9955c
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 7 deletions.
17 changes: 10 additions & 7 deletions numba_dpex/kernel_api_impl/spirv/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,13 +415,6 @@ def cb_llvm(dur):
except ExecutionQueueInferenceError as eqie:
raise eqie

# A function being compiled in the KERNEL compilation mode
# cannot have a non-void return value
if return_type and return_type != void:
raise KernelHasReturnValueError(
kernel_name=None, return_type=return_type, sig=sig
)

# Don't recompile if signature already exists
existing = self.overloads.get(tuple(args))
if existing is not None:
Expand All @@ -444,6 +437,16 @@ def cb_llvm(dur):
kcres: _SPIRVKernelCompileResult = compiler.compile(
args, return_type
)
if (
self.targetoptions["_compilation_mode"]
== CompilationMode.KERNEL
and kcres.signature.return_type is not None
and kcres.signature.return_type != types.void
):
raise KernelHasReturnValueError(
kernel_name=self.py_func.__name__,
return_type=kcres.signature.return_type,
)
except errors.ForceLiteralArg as err:

def folded(args, kws):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# SPDX-FileCopyrightText: 2020 - 2024 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

import dpnp
import pytest
from numba.core.errors import TypingError

import numba_dpex.experimental as dpex
from numba_dpex import int32, usm_ndarray
from numba_dpex.core.exceptions import KernelHasReturnValueError
from numba_dpex.core.types.kernel_api.index_space_ids import ItemType

i32arrty = usm_ndarray(ndim=1, dtype=int32, layout="C")
item_ty = ItemType(ndim=1)


def f(item, a):
return a


list_of_sig = [
None,
(i32arrty(item_ty, i32arrty)),
]


@pytest.fixture(params=list_of_sig)
def sig(request):
return request.param


def test_return(sig):
a = dpnp.arange(1024, dtype=dpnp.int32)

with pytest.raises((TypingError, KernelHasReturnValueError)) as excinfo:
kernel_fn = dpex.kernel(sig)(f)
dpex.call_kernel(kernel_fn, dpex.Range(a.size), a)

if isinstance(excinfo.type, TypingError):
assert "KernelHasReturnValueError" in excinfo.value.args[0]

0 comments on commit ad9955c

Please sign in to comment.