Skip to content

Commit

Permalink
Accept multidimentional atomic indexes
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Feb 27, 2024
1 parent c2cb985 commit a19cb7d
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,27 @@
)


def _normalize_indices(context, builder, indty, inds, aryty):
"""
Convert integer indices into tuple of intp
"""
if indty in types.integer_domain:
indty = types.UniTuple(dtype=indty, count=1)
indices = [inds]
else:
indices = cgutils.unpack_tuple(builder, inds, count=len(indty))
indices = [
context.cast(builder, i, t, types.intp) for t, i in zip(indty, indices)
]

if aryty.ndim != len(indty):
raise TypeError(
f"indexing {aryty.ndim}-D array with {len(indty)}-D index"
)

return indty, indices


def _parse_enum_or_int_literal_(literal_int) -> int:
"""Parse an instance of an enum class or numba.core.types.Literal to its
actual int value.
Expand Down Expand Up @@ -208,23 +229,22 @@ def _intrinsic_atomic_ref_ctor(
sig = ty_retty(ref, ty_index, ty_retty_ref)

def codegen(context, builder, sig, args):
ref = args[0]
index_pos = args[1]
aryty, indty, _ = sig.args
ary, inds, _ = args

dmm = context.data_model_manager
data_attr_pos = dmm.lookup(sig.args[0]).get_field_position("data")
data_attr = builder.extract_value(ref, data_attr_pos)
indty, indices = _normalize_indices(
context, builder, indty, inds, aryty
)

with builder.goto_entry_block():
ptr_to_data_attr = builder.alloca(data_attr.type)
builder.store(data_attr, ptr_to_data_attr)
ref_ptr_value = builder.gep(builder.load(ptr_to_data_attr), [index_pos])
lary = context.make_array(aryty)(context, builder, ary)
ref_ptr_value = cgutils.get_item_pointer(
context, builder, aryty, lary, indices, wraparound=True
)

atomic_ref_struct = cgutils.create_struct_proxy(ty_retty)(
context, builder
)
ref_attr_pos = dmm.lookup(ty_retty).get_field_position("ref")
atomic_ref_struct[ref_attr_pos] = ref_ptr_value
atomic_ref_struct.ref = ref_ptr_value
# pylint: disable=protected-access
return atomic_ref_struct._getvalue()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import dpnp
import numpy as np
import pytest
from numba.core.errors import TypingError

Expand All @@ -25,6 +26,27 @@ def atomic_ref_kernel(item: Item, a, b):
pytest.fail("Unexpected execution failure")


def test_atomic_ref_3_dim_compilation():
@dpex_exp.kernel
def atomic_ref_kernel(item: Item, a, b):
i = item.get_id(0)
v = AtomicRef(b, index=(1, 1, 1), address_space=AddressSpace.GLOBAL)
v.fetch_add(a[i])

a = dpnp.ones(8)
b = dpnp.zeros((2, 2, 2))

want = np.zeros((2, 2, 2))
want[1, 1, 1] = a.size

try:
dpex_exp.call_kernel(atomic_ref_kernel, Range(a.size), a, b)
except Exception:
pytest.fail("Unexpected execution failure")

assert np.array_equal(b.asnumpy(), want)


def test_atomic_ref_compilation_failure():
"""A negative test that verifies that a TypingError is raised if we try to
create an AtomicRef in the local address space from a global address space
Expand Down

0 comments on commit a19cb7d

Please sign in to comment.