From 90ade4c388877f552270954ea72daf6bf99e5957 Mon Sep 17 00:00:00 2001 From: Yevhenii Havrylko Date: Tue, 19 Mar 2024 12:45:42 -0400 Subject: [PATCH] Add local accessor multidimentional tests --- .../spv_overloads/test_local_accessors.py | 66 ++++++++++++++++--- 1 file changed, 57 insertions(+), 9 deletions(-) diff --git a/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_local_accessors.py b/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_local_accessors.py index 1ce33c63f5..d6cfb98a2f 100644 --- a/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_local_accessors.py +++ b/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_local_accessors.py @@ -23,34 +23,82 @@ @dpex_exp.kernel -def _kernel(nd_item: NdItem, a, slm): +def _kernel1(nd_item: NdItem, a, slm): i = nd_item.get_global_linear_id() - j = nd_item.get_local_linear_id() - slm[j] = 0 + # TODO: overload nd_item.get_local_id() + j = (nd_item.get_local_id(0),) + + slm[*j] = 0 group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP) for m in range(100): - slm[j] += i * m + slm[*j] += i * m group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP) - a[i] = slm[j] + a[i] = slm[*j] + + +@dpex_exp.kernel +def _kernel2(nd_item: NdItem, a, slm): + i = nd_item.get_global_linear_id() + + # TODO: overload nd_item.get_local_id() + j = (nd_item.get_local_id(0), nd_item.get_local_id(1)) + + slm[*j] = 0 + group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP) + + for m in range(100): + slm[*j] += i * m + group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP) + + a[i] = slm[*j] + + +@dpex_exp.kernel +def _kernel3(nd_item: NdItem, a, slm): + i = nd_item.get_global_linear_id() + + # TODO: overload nd_item.get_local_id() + j = ( + nd_item.get_local_id(0), + nd_item.get_local_id(1), + nd_item.get_local_id(2), + ) + + slm[*j] = 0 + group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP) + + for m in range(100): + slm[*j] += i * m + group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP) + + a[i] = slm[*j] @pytest.mark.parametrize("supported_dtype", list_of_supported_dtypes) -def test_local_accessor(supported_dtype): +@pytest.mark.parametrize( + "nd_range, _kernel", + [ + (dpex.NdRange((32,), (32,)), _kernel1), + (dpex.NdRange((32, 1), (32, 1)), _kernel2), + (dpex.NdRange((1, 32, 1), (1, 32, 1)), _kernel3), + ], +) +def test_local_accessor(supported_dtype, nd_range: dpex.NdRange, _kernel): """A test for passing a LocalAccessor object as a kernel argument.""" N = 32 a = dpnp.empty(N, dtype=supported_dtype) - slm = LocalAccessor((32 * 64), dtype=a.dtype) + slm = LocalAccessor(nd_range.local_range, dtype=a.dtype) # A single work group with 32 work items is launched. Each work item # computes the sum of (0..99) * its get_global_linear_id i.e., # `4950 * get_global_linear_id` and stores it into the work groups local # memory. The local memory is of size 32*64 elements of the requested dtype. # The result is then stored into `a` in global memory - dpex_exp.call_kernel(_kernel, dpex.NdRange((N,), (32,)), a, slm) + dpex_exp.call_kernel(_kernel, nd_range, a, slm) for idx in range(N): assert a[idx] == 4950 * idx @@ -68,4 +116,4 @@ def test_local_accessor_argument_to_range_kernel(): # A TypeError is raised if NUMBA_CAPTURED_ERROR=new_style and a # numba.TypingError is raised if NUMBA_CAPTURED_ERROR=old_style with pytest.raises((TypeError, TypingError)): - dpex_exp.call_kernel(_kernel, dpex.Range(N), a, slm) + dpex_exp.call_kernel(_kernel1, dpex.Range(N), a, slm)