Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/overload dimensions attribute for indexers #1359

Merged
merged 3 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import llvmlite.ir as llvmir
from numba.core import cgutils, types
from numba.core.errors import TypingError
from numba.extending import intrinsic, overload_method
from numba.extending import intrinsic, overload_attribute, overload_method

from numba_dpex.core.types.kernel_api.index_space_ids import (
GroupType,
Expand Down Expand Up @@ -248,3 +248,24 @@ def ol_nd_item_get_group_impl(nd_item):
return _intrinsic_get_group(nd_item)

return ol_nd_item_get_group_impl


@overload_attribute(GroupType, "dimensions", target=DPEX_KERNEL_EXP_TARGET_NAME)
@overload_attribute(ItemType, "dimensions", target=DPEX_KERNEL_EXP_TARGET_NAME)
@overload_attribute(
NdItemType, "dimensions", target=DPEX_KERNEL_EXP_TARGET_NAME
)
def ol_nd_item_dimensions(item):
"""
SPIR-V overload for :meth:`numba_dpex.kernel_api.<generic_item>.dimensions`.

Generates the same LLVM IR instruction as dpcpp for the
`sycl::<generic_item>::dimensions` attribute.
"""
dimensions = item.ndim

# pylint: disable=unused-argument
def ol_nd_item_get_group_impl(item):
return dimensions

return ol_nd_item_get_group_impl
6 changes: 3 additions & 3 deletions numba_dpex/experimental/typeof.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ def typeof_item(val: Item, c):
Returns: A numba_dpex.experimental.core.types.kernel_api.items.ItemType
instance.
"""
return ItemType(val.ndim)
return ItemType(val.dimensions)


@typeof_impl.register(NdItem)
def typeof_nditem(val, c):
def typeof_nditem(val: NdItem, c):
"""Registers the type inference implementation function for a
numba_dpex.kernel_api.NdItem PyObject.

Expand All @@ -83,4 +83,4 @@ def typeof_nditem(val, c):
Returns: A numba_dpex.experimental.core.types.kernel_api.items.NdItemType
instance.
"""
return NdItemType(val.ndim)
return NdItemType(val.dimensions)
14 changes: 11 additions & 3 deletions numba_dpex/kernel_api/index_space_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ def leader(self):
"""
return self._leader

@property
def dimensions(self) -> int:
"""Returns the rank of a Group object.
Returns:
int: Number of dimensions in the Group object
"""
return self._global_range.ndim

@leader.setter
def leader(self, work_item_id):
"""Sets the leader attribute for the group."""
Expand Down Expand Up @@ -147,7 +155,7 @@ def get_range(self, idx):
return self._extent[idx]

@property
def ndim(self) -> int:
def dimensions(self) -> int:
"""Returns the rank of a Item object.

Returns:
Expand Down Expand Up @@ -228,10 +236,10 @@ def get_group(self):
return self._group

@property
def ndim(self) -> int:
def dimensions(self) -> int:
"""Returns the rank of a NdItem object.

Returns:
int: Number of dimensions in the NdItem object
"""
return self._global_item.ndim
return self._global_item.dimensions
41 changes: 41 additions & 0 deletions numba_dpex/tests/experimental/test_index_space_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,24 @@ def set_local_ones_nd_item(nd_item: NdItem, a):
a[i] = 1


@dpex_exp.kernel
def set_dimensions_item(item: Item, a):
i = item.get_id(0)
a[i] = item.dimensions


@dpex_exp.kernel
def set_dimensions_nd_item(nd_item: NdItem, a):
i = nd_item.get_global_id(0)
a[i] = nd_item.dimensions


@dpex_exp.kernel
def set_dimensions_group(nd_item: NdItem, a):
i = nd_item.get_global_id(0)
a[i] = nd_item.get_group().dimensions


def _get_group_id_driver(nditem: NdItem, a):
i = nditem.get_global_id(0)
g = nditem.get_group()
Expand Down Expand Up @@ -149,6 +167,29 @@ def test_nd_item_get_local_id():
)


@pytest.mark.parametrize("dims", [1, 2, 3])
def test_item_dimensions(dims):
a = dpnp.zeros(_SIZE, dtype=dpnp.float32)
rng = [1] * dims
rng[0] = a.size
dpex_exp.call_kernel(set_dimensions_item, dpex.Range(*rng), a)

assert np.array_equal(a.asnumpy(), dims * np.ones(a.size, dtype=np.float32))


@pytest.mark.parametrize("dims", [1, 2, 3])
@pytest.mark.parametrize(
"kernel", [set_dimensions_nd_item, set_dimensions_group]
)
def test_nd_item_dimensions(dims, kernel):
a = dpnp.zeros(_SIZE, dtype=dpnp.float32)
rng, grp = [1] * dims, [1] * dims
rng[0], grp[0] = a.size, _GROUP_SIZE
dpex_exp.call_kernel(kernel, dpex.NdRange(rng, grp), a)

assert np.array_equal(a.asnumpy(), dims * np.ones(a.size, dtype=np.float32))


def test_error_item_get_global_id():
a = dpnp.zeros(_SIZE, dtype=dpnp.float32)

Expand Down
Loading