Skip to content

Commit

Permalink
fix: IndexError on node slicing (#1500)
Browse files Browse the repository at this point in the history
Fixes #1498 

Python lets you slice out of range; `[0,1][90:99]` returns an empty
list. This PR makes node slicing behave as such.
  • Loading branch information
aborgna-q committed Sep 3, 2024
1 parent 199c18b commit a32bd84
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
18 changes: 9 additions & 9 deletions hugr-py/src/hugr/hugr/node_port.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,22 +176,22 @@ def _index(
)
raise ValueError(msg)

start = self._normalize_index(start)
stop = self._normalize_index(stop, allow_eq_len=True)
start = self._normalize_index(start, allow_overflow=True)
stop = self._normalize_index(stop, allow_overflow=True)
step = index.step or 1

return (self[i] for i in range(start, stop, step))
case tuple(xs):
return (self[i] for i in xs)

def _normalize_index(self, index: int, allow_eq_len: bool = False) -> int:
def _normalize_index(self, index: int, allow_overflow: bool = False) -> int:
"""Given an index passed to `__getitem__`, normalize it to be within the
range of output ports.
Args:
index: index to normalize.
allow_eq_len: whether to allow the index to be equal to the number of
output ports.
allow_overflow: whether to allow indices beyond the number of outputs.
If True, indices over `self._num_out_ports` will be truncated.
Returns:
Normalized index.
Expand All @@ -202,17 +202,17 @@ def _normalize_index(self, index: int, allow_eq_len: bool = False) -> int:
msg = f"Index {index} out of range"

if self._num_out_ports is not None:
if index > self._num_out_ports:
raise IndexError(msg)
if index == self._num_out_ports and not allow_eq_len:
if index >= self._num_out_ports and not allow_overflow:
raise IndexError(msg)
if index < -self._num_out_ports:
raise IndexError(msg)
else:
if index < 0:
raise IndexError(msg)

if index >= 0:
if index >= 0 and self._num_out_ports is not None:
return min(index, self._num_out_ports)
elif index >= 0:
return index
else:
assert self._num_out_ports is not None
Expand Down
11 changes: 11 additions & 0 deletions hugr-py/tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,19 @@ def test_slices():
assert list(n[0:]) == all_ports
assert list(n[:3]) == all_ports
assert list(n[0:3]) == all_ports
assert list(n[0:999]) == all_ports
assert list(n[999:1000]) == []
assert list(n[-1:]) == [OutPort(n, 2)]
assert list(n[-3:]) == all_ports

with pytest.raises(IndexError, match="Index -4 out of range"):
_ = n[-4:]

n0 = Node(0, _num_out_ports=0)
assert list(n0) == []
assert list(n0[:0]) == []
assert list(n0[:10]) == []
assert list(n0[0:0]) == []
assert list(n0[0:]) == []
assert list(n0[10:]) == []
assert list(n0[:]) == []

0 comments on commit a32bd84

Please sign in to comment.