From d6edcd77e7679791ae5ab910d13ccccf9f8ca914 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Wed, 21 Aug 2024 16:57:02 +0100 Subject: [PATCH] fix(py): Invalid node indexing (#1457) Fixes #1454, and other failing cases --- hugr-py/src/hugr/node_port.py | 44 +++++++++++++++++++++++++++++++---- hugr-py/tests/test_nodes.py | 37 +++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 4 deletions(-) create mode 100644 hugr-py/tests/test_nodes.py diff --git a/hugr-py/src/hugr/node_port.py b/hugr-py/src/hugr/node_port.py index 0f268fe88..556692aca 100644 --- a/hugr-py/src/hugr/node_port.py +++ b/hugr-py/src/hugr/node_port.py @@ -164,24 +164,60 @@ def _index( ) -> OutPort | Iterator[OutPort]: match index: case PortOffset(index): - if self._num_out_ports is not None and index >= self._num_out_ports: - msg = "Index out of range" - raise IndexError(msg) + index = self._normalize_index(index) return self.out(index) case slice(): start = index.start or 0 - stop = index.stop or self._num_out_ports + stop = index.stop if index.stop is not None else self._num_out_ports if stop is None: msg = ( f"{self} does not have a fixed number of output ports. " "Iterating over all output ports is not supported." ) raise ValueError(msg) + + start = self._normalize_index(start) + stop = self._normalize_index(stop, allow_eq_len=True) step = index.step or 1 + return (self[i] for i in range(start, stop, step)) case tuple(xs): return (self[i] for i in xs) + def _normalize_index(self, index: int, allow_eq_len: bool = False) -> int: + """Given an index passed to `__getitem__`, normalize it to be within the + range of output ports. + + Args: + index: index to normalize. + allow_eq_len: whether to allow the index to be equal to the number of + output ports. + + Returns: + Normalized index. + + Raises: + IndexError: if the index is out of range. + """ + msg = f"Index {index} out of range" + + if self._num_out_ports is not None: + if index > self._num_out_ports: + raise IndexError(msg) + if index == self._num_out_ports and not allow_eq_len: + raise IndexError(msg) + if index < -self._num_out_ports: + raise IndexError(msg) + else: + if index < 0: + raise IndexError(msg) + + if index >= 0: + return index + else: + assert self._num_out_ports is not None + return self._num_out_ports + index + def to_node(self) -> Node: return self diff --git a/hugr-py/tests/test_nodes.py b/hugr-py/tests/test_nodes.py new file mode 100644 index 000000000..7fc1d30db --- /dev/null +++ b/hugr-py/tests/test_nodes.py @@ -0,0 +1,37 @@ +import pytest + +from hugr.node_port import Node, OutPort + + +def test_index(): + n = Node(0, _num_out_ports=3) + assert n[0] == OutPort(n, 0) + assert n[1] == OutPort(n, 1) + assert n[2] == OutPort(n, 2) + assert n[-1] == OutPort(n, 2) + + with pytest.raises(IndexError, match="Index 3 out of range"): + _ = n[3] + + with pytest.raises(IndexError, match="Index -8 out of range"): + _ = n[-8] + + +def test_slices(): + n = Node(0, _num_out_ports=3) + all_ports = [OutPort(n, i) for i in range(3)] + + assert list(n) == all_ports + assert list(n[:0]) == [] + assert list(n[0:0]) == [] + assert list(n[0:1]) == [OutPort(n, 0)] + assert list(n[1:2]) == [OutPort(n, 1)] + assert list(n[:]) == all_ports + assert list(n[0:]) == all_ports + assert list(n[:3]) == all_ports + assert list(n[0:3]) == all_ports + assert list(n[-1:]) == [OutPort(n, 2)] + assert list(n[-3:]) == all_ports + + with pytest.raises(IndexError, match="Index -4 out of range"): + _ = n[-4:]