Skip to content

Commit

Permalink
fix: normalise slice
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 committed Dec 19, 2023
1 parent d6ca283 commit 4803427
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 32 deletions.
39 changes: 8 additions & 31 deletions src/awkward/_slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from __future__ import annotations

import operator

import awkward as ak
from awkward._backends.backend import Backend
from awkward._nplikes import to_nplike
Expand All @@ -15,7 +13,7 @@
from awkward._typing import TYPE_CHECKING, Sequence, TypeAlias, TypeVar

if TYPE_CHECKING:
from awkward._nplikes.numpy_like import ArrayLike
from awkward._nplikes.numpy_like import ArrayLike, NumpyLike
from awkward.contents.content import Content

np = NumpyMetadata.instance()
Expand All @@ -24,48 +22,27 @@
SliceItem: TypeAlias = "int | slice | str | None | Ellipsis | ArrayLike | Content"


def normalize_slice_item(item, *, backend: Backend):
if backend.index_nplike.is_own_array(item):
if item.ndim != 0:
raise ValueError(
f"slice items must be 0D arrays or Python integers, not {item!r}"
)
else:
return item
else:
return operator.index(item)


def normalize_slice(slice_: slice, *, backend: Backend) -> slice:
def normalize_slice(slice_: slice, *, nplike: NumpyLike) -> slice:
"""
Args:
slice_: slice object
backend: backend of layout
nplike: NumpyLike of array
Return a slice of (start, stop, step) for which the slice items have been
normalized into index types.
"""
index_nplike = backend.index_nplike

start = slice_.start
stop = slice_.stop
step = slice_.step

if index_nplike.known_data:
if nplike.known_data:
return slice_
# Unknown lengths mean that the slice index is unknown
else:
start = (
index_nplike.shape_item_as_index(start)
if start is unknown_length
else start
)
stop = (
index_nplike.shape_item_as_index(stop) if stop is unknown_length else stop
)
step = (
index_nplike.shape_item_as_index(step) if step is unknown_length else step
)
start = nplike.shape_item_as_index(start) if start is unknown_length else start
stop = nplike.shape_item_as_index(stop) if stop is unknown_length else stop
step = nplike.shape_item_as_index(step) if step is unknown_length else step

return slice(start, stop, step)

Expand Down Expand Up @@ -226,7 +203,7 @@ def normalise_item(item, backend: Backend) -> SliceItem:
return normalize_integer_like(item)

elif isinstance(item, slice):
return normalize_slice(item, backend=backend)
return normalize_slice(item, nplike=backend.index_nplike)

elif isinstance(item, str):
return item
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/contents/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ def _getitem(self, where):
elif isinstance(where, slice) and where.step is None:
# Ensure that start, stop are non-negative!
start, stop, _, _ = self._backend.index_nplike.derive_slice_for_length(
normalize_slice(where, backend=self._backend), self.length
normalize_slice(where, nplike=self._backend.index_nplike), self.length
)
return self._getitem_range(start, stop)

Expand Down
4 changes: 4 additions & 0 deletions src/awkward/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from awkward._nplikes.numpy_like import NumpyLike, NumpyMetadata
from awkward._nplikes.shape import ShapeItem
from awkward._nplikes.typetracer import TypeTracer
from awkward._slicing import normalize_slice
from awkward._typing import Any, DType, Final, Self, cast

np: Final = NumpyMetadata.instance()
Expand Down Expand Up @@ -214,6 +215,9 @@ def form(self) -> str:
return _dtype_to_form[self._data.dtype]

def __getitem__(self, where):
if isinstance(where, slice):
where = normalize_slice(where, nplike=self.nplike)

out = self._data[where]

if hasattr(out, "shape") and len(out.shape) != 0:
Expand Down

0 comments on commit 4803427

Please sign in to comment.