Skip to content

Commit

Permalink
Merge pull request #1391 from IntelPython/device_func_unit_tests
Browse files Browse the repository at this point in the history
Verify global_barrier, indexing, private array inside device_func
  • Loading branch information
Diptorup Deb authored Mar 19, 2024
2 parents 97f9069 + d89104d commit 202f460
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,29 @@ def _kernel(nd_item: NdItem, a):
dpex_exp.call_kernel(_kernel, dpex.NdRange((N,), (N,)), a)

assert a[0] == N * 2


def test_group_barrier_device_func():
"""A test for group_barrier function."""

@dpex_exp.device_func
def _increment_value(nd_item: NdItem, a):
i = nd_item.get_global_id(0)

a[i] += 1
group_barrier(nd_item.get_group(), MemoryScope.DEVICE)

if i == 0:
for idx in range(1, a.size):
a[0] += a[idx]

@dpex_exp.kernel
def _kernel(nd_item: NdItem, a):
_increment_value(nd_item, a)

N = 16
a = dpnp.ones(N, dtype=dpnp.int32)

dpex_exp.call_kernel(_kernel, dpex.NdRange((N,), (N,)), a)

assert a[0] == N * 2
26 changes: 26 additions & 0 deletions numba_dpex/tests/experimental/test_private_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,29 @@ def test_private_array(call_kernel, decorator, kernel):
want = np.full(a.size, (9) * (9 + 1) * (2 * 9 + 1) / 6, dtype=np.float32)

assert np.array_equal(want, a.asnumpy())


@pytest.mark.parametrize(
"func",
[
private_array_kernel,
private_array_kernel_fill_true,
private_array_kernel_fill_false,
private_2d_array_kernel,
],
)
def test_private_array_in_device_func(func):

_df = dpex_exp.device_func(func)

@dpex_exp.kernel
def _kernel(item: Item, a):
_df(item, a)

a = dpnp.empty(10, dtype=dpnp.float32)
dpex_exp.call_kernel(_kernel, Range(a.size), a)

# sum of squares from 1 to n: n*(n+1)*(2*n+1)/6
want = np.full(a.size, (9) * (9 + 1) * (2 * 9 + 1) / 6, dtype=np.float32)

assert np.array_equal(want, a.asnumpy())

0 comments on commit 202f460

Please sign in to comment.