Skip to content

Commit

Permalink
Add local accessor multidimentional tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Mar 19, 2024
1 parent c086ae6 commit 90ade4c
Showing 1 changed file with 57 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,34 +23,82 @@


@dpex_exp.kernel
def _kernel(nd_item: NdItem, a, slm):
def _kernel1(nd_item: NdItem, a, slm):
i = nd_item.get_global_linear_id()
j = nd_item.get_local_linear_id()

slm[j] = 0
# TODO: overload nd_item.get_local_id()
j = (nd_item.get_local_id(0),)

slm[*j] = 0
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)

for m in range(100):
slm[j] += i * m
slm[*j] += i * m
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)

a[i] = slm[j]
a[i] = slm[*j]


@dpex_exp.kernel
def _kernel2(nd_item: NdItem, a, slm):
i = nd_item.get_global_linear_id()

# TODO: overload nd_item.get_local_id()
j = (nd_item.get_local_id(0), nd_item.get_local_id(1))

slm[*j] = 0
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)

for m in range(100):
slm[*j] += i * m
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)

a[i] = slm[*j]


@dpex_exp.kernel
def _kernel3(nd_item: NdItem, a, slm):
i = nd_item.get_global_linear_id()

# TODO: overload nd_item.get_local_id()
j = (
nd_item.get_local_id(0),
nd_item.get_local_id(1),
nd_item.get_local_id(2),
)

slm[*j] = 0
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)

for m in range(100):
slm[*j] += i * m
group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP)

a[i] = slm[*j]


@pytest.mark.parametrize("supported_dtype", list_of_supported_dtypes)
def test_local_accessor(supported_dtype):
@pytest.mark.parametrize(
"nd_range, _kernel",
[
(dpex.NdRange((32,), (32,)), _kernel1),
(dpex.NdRange((32, 1), (32, 1)), _kernel2),
(dpex.NdRange((1, 32, 1), (1, 32, 1)), _kernel3),
],
)
def test_local_accessor(supported_dtype, nd_range: dpex.NdRange, _kernel):
"""A test for passing a LocalAccessor object as a kernel argument."""

N = 32
a = dpnp.empty(N, dtype=supported_dtype)
slm = LocalAccessor((32 * 64), dtype=a.dtype)
slm = LocalAccessor(nd_range.local_range, dtype=a.dtype)

# A single work group with 32 work items is launched. Each work item
# computes the sum of (0..99) * its get_global_linear_id i.e.,
# `4950 * get_global_linear_id` and stores it into the work groups local
# memory. The local memory is of size 32*64 elements of the requested dtype.
# The result is then stored into `a` in global memory
dpex_exp.call_kernel(_kernel, dpex.NdRange((N,), (32,)), a, slm)
dpex_exp.call_kernel(_kernel, nd_range, a, slm)

for idx in range(N):
assert a[idx] == 4950 * idx
Expand All @@ -68,4 +116,4 @@ def test_local_accessor_argument_to_range_kernel():
# A TypeError is raised if NUMBA_CAPTURED_ERROR=new_style and a
# numba.TypingError is raised if NUMBA_CAPTURED_ERROR=old_style
with pytest.raises((TypeError, TypingError)):
dpex_exp.call_kernel(_kernel, dpex.Range(N), a, slm)
dpex_exp.call_kernel(_kernel1, dpex.Range(N), a, slm)

0 comments on commit 90ade4c

Please sign in to comment.