Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accept multidimentional atomic indexes #1367

Merged
merged 3 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,27 @@
)


def _normalize_indices(context, builder, indty, inds, aryty):
"""
Convert integer indices into tuple of intp
"""
if indty in types.integer_domain:
indty = types.UniTuple(dtype=indty, count=1)
indices = [inds]
else:
indices = cgutils.unpack_tuple(builder, inds, count=len(indty))
indices = [
context.cast(builder, i, t, types.intp) for t, i in zip(indty, indices)
]

if aryty.ndim != len(indty):
raise TypeError(
f"indexing {aryty.ndim}-D array with {len(indty)}-D index"
)

return indty, indices


def _parse_enum_or_int_literal_(literal_int) -> int:
"""Parse an instance of an enum class or numba.core.types.Literal to its
actual int value.
Expand Down Expand Up @@ -208,23 +229,22 @@ def _intrinsic_atomic_ref_ctor(
sig = ty_retty(ref, ty_index, ty_retty_ref)

def codegen(context, builder, sig, args):
ref = args[0]
index_pos = args[1]
aryty, indty, _ = sig.args
ary, inds, _ = args

dmm = context.data_model_manager
data_attr_pos = dmm.lookup(sig.args[0]).get_field_position("data")
data_attr = builder.extract_value(ref, data_attr_pos)
indty, indices = _normalize_indices(
context, builder, indty, inds, aryty
)

with builder.goto_entry_block():
ptr_to_data_attr = builder.alloca(data_attr.type)
builder.store(data_attr, ptr_to_data_attr)
ref_ptr_value = builder.gep(builder.load(ptr_to_data_attr), [index_pos])
lary = context.make_array(aryty)(context, builder, ary)
ref_ptr_value = cgutils.get_item_pointer(
context, builder, aryty, lary, indices, wraparound=True
)

atomic_ref_struct = cgutils.create_struct_proxy(ty_retty)(
context, builder
)
ref_attr_pos = dmm.lookup(ty_retty).get_field_position("ref")
atomic_ref_struct[ref_attr_pos] = ref_ptr_value
atomic_ref_struct.ref = ref_ptr_value
# pylint: disable=protected-access
return atomic_ref_struct._getvalue()

Expand Down Expand Up @@ -564,7 +584,7 @@ def _check_if_supported_ref(ref):
)
def ol_atomic_ref(
ref,
index=0,
index,
memory_order=MemoryOrder.RELAXED,
memory_scope=MemoryScope.DEVICE,
address_space=AddressSpace.GLOBAL,
Expand Down Expand Up @@ -635,7 +655,7 @@ def ol_atomic_ref(

def ol_atomic_ref_ctor_impl(
ref,
index=0,
index,
memory_order=MemoryOrder.RELAXED, # pylint: disable=unused-argument
memory_scope=MemoryScope.DEVICE, # pylint: disable=unused-argument
address_space=AddressSpace.GLOBAL, # pylint: disable=unused-argument
Expand Down
2 changes: 1 addition & 1 deletion numba_dpex/kernel_api/atomic_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class AtomicRef:
def __init__( # pylint: disable=too-many-arguments
self,
ref,
index=0,
index,
memory_order=MemoryOrder.RELAXED,
memory_scope=MemoryScope.DEVICE,
address_space=AddressSpace.GLOBAL,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,16 @@
MemoryScope,
atomic_fence,
)
from numba_dpex.tests._helper import skip_windows


# TODO: https://github.com/IntelPython/numba-dpex/issues/1308
@skip_windows
def test_atomic_fence():
"""A test for atomic_fence function."""

@dpex_exp.kernel
def _kernel(item: Item, a, b):
i = item.get_id(0)

bref = AtomicRef(b)
bref = AtomicRef(b, index=0)

if i == 1:
a[i] += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import dpnp
import numpy as np
import pytest
from numba.core.errors import TypingError

Expand All @@ -25,6 +26,27 @@ def atomic_ref_kernel(item: Item, a, b):
pytest.fail("Unexpected execution failure")


def test_atomic_ref_3_dim_compilation():
@dpex_exp.kernel
def atomic_ref_kernel(item: Item, a, b):
i = item.get_id(0)
v = AtomicRef(b, index=(1, 1, 1), address_space=AddressSpace.GLOBAL)
v.fetch_add(a[i])

a = dpnp.ones(8)
b = dpnp.zeros((2, 2, 2))

want = np.zeros((2, 2, 2))
want[1, 1, 1] = a.size

try:
dpex_exp.call_kernel(atomic_ref_kernel, Range(a.size), a, b)
except Exception:
pytest.fail("Unexpected execution failure")

assert np.array_equal(b.asnumpy(), want)


def test_atomic_ref_compilation_failure():
"""A negative test that verifies that a TypingError is raised if we try to
create an AtomicRef in the local address space from a global address space
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,8 @@
import numba_dpex as dpex
import numba_dpex.experimental as dpex_exp
from numba_dpex.kernel_api import MemoryScope, NdItem, group_barrier
from numba_dpex.tests._helper import skip_windows


# TODO: https://github.com/IntelPython/numba-dpex/issues/1308
@skip_windows
def test_group_barrier():
"""A test for group_barrier function."""

Expand Down
Loading