diff --git a/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py b/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py index c9762c5653..aab3f238ce 100644 --- a/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py +++ b/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py @@ -41,6 +41,27 @@ ) +def _normalize_indices(context, builder, indty, inds, aryty): + """ + Convert integer indices into tuple of intp + """ + if indty in types.integer_domain: + indty = types.UniTuple(dtype=indty, count=1) + indices = [inds] + else: + indices = cgutils.unpack_tuple(builder, inds, count=len(indty)) + indices = [ + context.cast(builder, i, t, types.intp) for t, i in zip(indty, indices) + ] + + if aryty.ndim != len(indty): + raise TypeError( + f"indexing {aryty.ndim}-D array with {len(indty)}-D index" + ) + + return indty, indices + + def _parse_enum_or_int_literal_(literal_int) -> int: """Parse an instance of an enum class or numba.core.types.Literal to its actual int value. @@ -208,23 +229,22 @@ def _intrinsic_atomic_ref_ctor( sig = ty_retty(ref, ty_index, ty_retty_ref) def codegen(context, builder, sig, args): - ref = args[0] - index_pos = args[1] + aryty, indty, _ = sig.args + ary, inds, _ = args - dmm = context.data_model_manager - data_attr_pos = dmm.lookup(sig.args[0]).get_field_position("data") - data_attr = builder.extract_value(ref, data_attr_pos) + indty, indices = _normalize_indices( + context, builder, indty, inds, aryty + ) - with builder.goto_entry_block(): - ptr_to_data_attr = builder.alloca(data_attr.type) - builder.store(data_attr, ptr_to_data_attr) - ref_ptr_value = builder.gep(builder.load(ptr_to_data_attr), [index_pos]) + lary = context.make_array(aryty)(context, builder, ary) + ref_ptr_value = cgutils.get_item_pointer( + context, builder, aryty, lary, indices, wraparound=True + ) atomic_ref_struct = cgutils.create_struct_proxy(ty_retty)( context, builder ) - ref_attr_pos = dmm.lookup(ty_retty).get_field_position("ref") - atomic_ref_struct[ref_attr_pos] = ref_ptr_value + atomic_ref_struct.ref = ref_ptr_value # pylint: disable=protected-access return atomic_ref_struct._getvalue() @@ -564,7 +584,7 @@ def _check_if_supported_ref(ref): ) def ol_atomic_ref( ref, - index=0, + index, memory_order=MemoryOrder.RELAXED, memory_scope=MemoryScope.DEVICE, address_space=AddressSpace.GLOBAL, @@ -635,7 +655,7 @@ def ol_atomic_ref( def ol_atomic_ref_ctor_impl( ref, - index=0, + index, memory_order=MemoryOrder.RELAXED, # pylint: disable=unused-argument memory_scope=MemoryScope.DEVICE, # pylint: disable=unused-argument address_space=AddressSpace.GLOBAL, # pylint: disable=unused-argument diff --git a/numba_dpex/kernel_api/atomic_ref.py b/numba_dpex/kernel_api/atomic_ref.py index 7d9e426aa1..9759cecf95 100644 --- a/numba_dpex/kernel_api/atomic_ref.py +++ b/numba_dpex/kernel_api/atomic_ref.py @@ -18,7 +18,7 @@ class AtomicRef: def __init__( # pylint: disable=too-many-arguments self, ref, - index=0, + index, memory_order=MemoryOrder.RELAXED, memory_scope=MemoryScope.DEVICE, address_space=AddressSpace.GLOBAL, diff --git a/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_atomic_fence.py b/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_atomic_fence.py index 111c64f6d1..1502d5dc0d 100644 --- a/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_atomic_fence.py +++ b/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_atomic_fence.py @@ -9,11 +9,8 @@ MemoryScope, atomic_fence, ) -from numba_dpex.tests._helper import skip_windows -# TODO: https://github.com/IntelPython/numba-dpex/issues/1308 -@skip_windows def test_atomic_fence(): """A test for atomic_fence function.""" @@ -21,7 +18,7 @@ def test_atomic_fence(): def _kernel(item: Item, a, b): i = item.get_id(0) - bref = AtomicRef(b) + bref = AtomicRef(b, index=0) if i == 1: a[i] += 1 diff --git a/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_atomic_ref.py b/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_atomic_ref.py index d70f1aa03a..ff8222c5c0 100644 --- a/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_atomic_ref.py +++ b/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_atomic_ref.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import dpnp +import numpy as np import pytest from numba.core.errors import TypingError @@ -25,6 +26,27 @@ def atomic_ref_kernel(item: Item, a, b): pytest.fail("Unexpected execution failure") +def test_atomic_ref_3_dim_compilation(): + @dpex_exp.kernel + def atomic_ref_kernel(item: Item, a, b): + i = item.get_id(0) + v = AtomicRef(b, index=(1, 1, 1), address_space=AddressSpace.GLOBAL) + v.fetch_add(a[i]) + + a = dpnp.ones(8) + b = dpnp.zeros((2, 2, 2)) + + want = np.zeros((2, 2, 2)) + want[1, 1, 1] = a.size + + try: + dpex_exp.call_kernel(atomic_ref_kernel, Range(a.size), a, b) + except Exception: + pytest.fail("Unexpected execution failure") + + assert np.array_equal(b.asnumpy(), want) + + def test_atomic_ref_compilation_failure(): """A negative test that verifies that a TypingError is raised if we try to create an AtomicRef in the local address space from a global address space diff --git a/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_barriers.py b/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_barriers.py index 61be6cb320..a5bae8546b 100644 --- a/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_barriers.py +++ b/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_barriers.py @@ -3,11 +3,8 @@ import numba_dpex as dpex import numba_dpex.experimental as dpex_exp from numba_dpex.kernel_api import MemoryScope, NdItem, group_barrier -from numba_dpex.tests._helper import skip_windows -# TODO: https://github.com/IntelPython/numba-dpex/issues/1308 -@skip_windows def test_group_barrier(): """A test for group_barrier function."""