From 554e52d52f0c5d804ebf33c946e10c6d87334724 Mon Sep 17 00:00:00 2001 From: Yevhenii Havrylko Date: Wed, 21 Feb 2024 19:44:45 -0500 Subject: [PATCH] Overload generic item's attribute 'dimensions' --- .../_index_space_id_overloads.py | 23 ++++++++++- .../experimental/test_index_space_ids.py | 41 +++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py b/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py index d574f6aced..a99781e53f 100644 --- a/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py +++ b/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py @@ -9,7 +9,7 @@ import llvmlite.ir as llvmir from numba.core import cgutils, types from numba.core.errors import TypingError -from numba.extending import intrinsic, overload_method +from numba.extending import intrinsic, overload_attribute, overload_method from numba_dpex.core.types.kernel_api.index_space_ids import ( GroupType, @@ -248,3 +248,24 @@ def ol_nd_item_get_group_impl(nd_item): return _intrinsic_get_group(nd_item) return ol_nd_item_get_group_impl + + +@overload_attribute(GroupType, "dimensions", target=DPEX_KERNEL_EXP_TARGET_NAME) +@overload_attribute(ItemType, "dimensions", target=DPEX_KERNEL_EXP_TARGET_NAME) +@overload_attribute( + NdItemType, "dimensions", target=DPEX_KERNEL_EXP_TARGET_NAME +) +def ol_nd_item_dimensions(item): + """ + SPIR-V overload for :meth:`numba_dpex.kernel_api..dimensions`. + + Generates the same LLVM IR instruction as dpcpp for the + `sycl::::dimensions` attribute. + """ + dimensions = item.ndim + + # pylint: disable=unused-argument + def ol_nd_item_get_group_impl(item): + return dimensions + + return ol_nd_item_get_group_impl diff --git a/numba_dpex/tests/experimental/test_index_space_ids.py b/numba_dpex/tests/experimental/test_index_space_ids.py index 2d1edb54f2..887ce6584e 100644 --- a/numba_dpex/tests/experimental/test_index_space_ids.py +++ b/numba_dpex/tests/experimental/test_index_space_ids.py @@ -63,6 +63,24 @@ def set_local_ones_nd_item(nd_item: NdItem, a): a[i] = 1 +@dpex_exp.kernel +def set_dimensions_item(item: Item, a): + i = item.get_id(0) + a[i] = item.dimensions + + +@dpex_exp.kernel +def set_dimensions_nd_item(nd_item: NdItem, a): + i = nd_item.get_global_id(0) + a[i] = nd_item.dimensions + + +@dpex_exp.kernel +def set_dimensions_group(nd_item: NdItem, a): + i = nd_item.get_global_id(0) + a[i] = nd_item.get_group().dimensions + + def _get_group_id_driver(nditem: NdItem, a): i = nditem.get_global_id(0) g = nditem.get_group() @@ -149,6 +167,29 @@ def test_nd_item_get_local_id(): ) +@pytest.mark.parametrize("dims", [1, 2, 3]) +def test_item_dimensions(dims): + a = dpnp.zeros(_SIZE, dtype=dpnp.float32) + rng = [1] * dims + rng[0] = a.size + dpex_exp.call_kernel(set_dimensions_item, dpex.Range(*rng), a) + + assert np.array_equal(a.asnumpy(), dims * np.ones(a.size, dtype=np.float32)) + + +@pytest.mark.parametrize("dims", [1, 2, 3]) +@pytest.mark.parametrize( + "kernel", [set_dimensions_nd_item, set_dimensions_group] +) +def test_nd_item_dimensions(dims, kernel): + a = dpnp.zeros(_SIZE, dtype=dpnp.float32) + rng, grp = [1] * dims, [1] * dims + rng[0], grp[0] = a.size, _GROUP_SIZE + dpex_exp.call_kernel(kernel, dpex.NdRange(rng, grp), a) + + assert np.array_equal(a.asnumpy(), dims * np.ones(a.size, dtype=np.float32)) + + def test_error_item_get_global_id(): a = dpnp.zeros(_SIZE, dtype=dpnp.float32)