From b12c1a38a80e6dd88d492e4af08b494d73be4c50 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sun, 10 Mar 2024 23:16:59 -0500 Subject: [PATCH] Disallow LocalAccessor arguments to RangeType kernels --- numba_dpex/experimental/launcher.py | 30 ++++++++++++++ .../spv_overloads/test_local_accessors.py | 41 +++++++++++++------ 2 files changed, 59 insertions(+), 12 deletions(-) diff --git a/numba_dpex/experimental/launcher.py b/numba_dpex/experimental/launcher.py index 82809a4c9e..44827835e5 100644 --- a/numba_dpex/experimental/launcher.py +++ b/numba_dpex/experimental/launcher.py @@ -25,6 +25,7 @@ ItemType, NdItemType, ) +from numba_dpex.core.types.kernel_api.local_accessor import LocalAccessorType from numba_dpex.core.utils import kernel_launcher as kl from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl from numba_dpex.dpctl_iface.wrappers import wrap_event_reference @@ -42,6 +43,23 @@ class _LLRange(NamedTuple): local_range_extents: list +def _has_a_local_accessor_argument(args): + """Checks if there exists at least one LocalAccessorType object in the + input tuple. + + Args: + args (_type_): A tuple of numba.core.Type objects + + Returns: + bool : True if at least one LocalAccessorType object was found, + otherwise False. + """ + for arg in args: + if isinstance(arg, LocalAccessorType): + return True + return False + + def _wrap_event_reference_tuple(ctx, builder, event1, event2): """Creates tuple data model from two event data models, so it can be boxed to Python.""" @@ -153,6 +171,18 @@ def _submit_kernel( # pylint: disable=too-many-arguments DeprecationWarning, ) + # Validate local accessor arguments are passed only to a kernel that is + # launched with an NdRange index space. Reference section 4.7.6.11. of the + # SYCL 2020 specification: A local_accessor must not be used in a SYCL + # kernel function that is invoked via single_task or via the simple form of + # parallel_for that takes a range parameter. + if _has_a_local_accessor_argument(ty_kernel_args_tuple) and isinstance( + ty_index_space, RangeType + ): + raise TypeError( + "A RangeType kernel cannot have a LocalAccessor argument" + ) + # ty_kernel_fn is type specific to exact function, so we can get function # directly from type and compile it. Thats why we don't need to get it in # codegen diff --git a/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_local_accessors.py b/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_local_accessors.py index 6721c830e3..1ce33c63f5 100644 --- a/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_local_accessors.py +++ b/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_local_accessors.py @@ -5,6 +5,7 @@ import dpnp import pytest +from numba.core.errors import TypingError import numba_dpex as dpex import numba_dpex.experimental as dpex_exp @@ -21,23 +22,24 @@ ) -@pytest.mark.parametrize("supported_dtype", list_of_supported_dtypes) -def test_local_accessor(supported_dtype): - """A test for passing a LocalAccessor object as a kernel argument.""" +@dpex_exp.kernel +def _kernel(nd_item: NdItem, a, slm): + i = nd_item.get_global_linear_id() + j = nd_item.get_local_linear_id() - @dpex_exp.kernel - def _kernel(nd_item: NdItem, a, slm): - i = nd_item.get_global_linear_id() - j = nd_item.get_local_linear_id() + slm[j] = 0 + group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP) - slm[j] = 0 + for m in range(100): + slm[j] += i * m group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP) - for m in range(100): - slm[j] += i * m - group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP) + a[i] = slm[j] - a[i] = slm[j] + +@pytest.mark.parametrize("supported_dtype", list_of_supported_dtypes) +def test_local_accessor(supported_dtype): + """A test for passing a LocalAccessor object as a kernel argument.""" N = 32 a = dpnp.empty(N, dtype=supported_dtype) @@ -52,3 +54,18 @@ def _kernel(nd_item: NdItem, a, slm): for idx in range(N): assert a[idx] == 4950 * idx + + +def test_local_accessor_argument_to_range_kernel(): + """Checks if an exception is raised when passing a local accessor to a + RangeType kernel. + """ + N = 32 + a = dpnp.empty(N) + slm = LocalAccessor((32 * 64), dtype=a.dtype) + + # Passing a local_accessor to a RangeType kernel should raise an exception. + # A TypeError is raised if NUMBA_CAPTURED_ERROR=new_style and a + # numba.TypingError is raised if NUMBA_CAPTURED_ERROR=old_style + with pytest.raises((TypeError, TypingError)): + dpex_exp.call_kernel(_kernel, dpex.Range(N), a, slm)