From 2774414a0138129fa05dbfcdde0e25d6bb0ee4f3 Mon Sep 17 00:00:00 2001 From: Jake Lishman Date: Fri, 5 Apr 2024 18:50:38 +0100 Subject: [PATCH] Allow `TopologicalSorter` without error checking For representative examples in Qiskit, this reduces the runtime of interacting with the topological sorter by around 15-20%. --- .../toposort-check-args-1378bab51e4172a3.yaml | 8 +++ rustworkx/rustworkx.pyi | 1 + src/toposort.rs | 61 ++++++++++++------- 3 files changed, 47 insertions(+), 23 deletions(-) create mode 100644 releasenotes/notes/toposort-check-args-1378bab51e4172a3.yaml diff --git a/releasenotes/notes/toposort-check-args-1378bab51e4172a3.yaml b/releasenotes/notes/toposort-check-args-1378bab51e4172a3.yaml new file mode 100644 index 0000000000..e491bc5680 --- /dev/null +++ b/releasenotes/notes/toposort-check-args-1378bab51e4172a3.yaml @@ -0,0 +1,8 @@ +--- +features: + - | + :class:`.TopologicalSorter` now has a ``check_args`` keyword argument, which can be set to + ``False`` to disable the runtime detection of invalid arguments to + :meth:`~.TopologicalSorter.done`. This provides a memory and runtime improvement to the online + sorter, at the cost that the results will be undefined and likely meaningless if invalid values + are given. diff --git a/rustworkx/rustworkx.pyi b/rustworkx/rustworkx.pyi index 16cdbab217..32f2e26d32 100644 --- a/rustworkx/rustworkx.pyi +++ b/rustworkx/rustworkx.pyi @@ -292,6 +292,7 @@ class TopologicalSorter: *, reverse: bool = ..., initial: Iterable[int] | None = ..., + check_args: bool = ..., ) -> None: ... def is_active(self) -> bool: ... def get_ready(self) -> list[int]: ... diff --git a/src/toposort.rs b/src/toposort.rs index 9983c90a3a..cc518d814a 100644 --- a/src/toposort.rs +++ b/src/toposort.rs @@ -77,12 +77,17 @@ enum NodeState { /// It is a :exc:`ValueError` to give an `initial` set where the nodes have even a partial /// topological order between themselves, though this might not appear until some call /// to :meth:`done`. +/// :param bool check_args: If ``True`` (the default), then all arguments to :meth:`done` are +/// checked for validity, and a :exc:`ValueError` is raised if any were not ready, already +/// done, or not indices of the circuit. If ``False``, the tracking for this is disabled, +/// which can provide a meaningful performance and memory improvement, but the results will +/// be undefined if invalid values are given. #[pyclass(module = "rustworkx")] pub struct TopologicalSorter { dag: Py, ready_nodes: Vec, predecessor_count: HashMap, - node2state: HashMap, + node2state: Option>, num_passed_out: usize, num_finished: usize, in_dir: petgraph::Direction, @@ -92,13 +97,14 @@ pub struct TopologicalSorter { #[pymethods] impl TopologicalSorter { #[new] - #[pyo3(signature=(dag, /, check_cycle=true, *, reverse=false, initial=None))] + #[pyo3(signature=(dag, /, check_cycle=true, *, reverse=false, initial=None, check_args=true))] fn new( py: Python, dag: Py, check_cycle: bool, reverse: bool, initial: Option<&Bound>, + check_args: bool, ) -> PyResult { { let dag = &dag.borrow(py); @@ -144,7 +150,7 @@ impl TopologicalSorter { dag, ready_nodes, predecessor_count, - node2state: HashMap::new(), + node2state: check_args.then(HashMap::new), num_passed_out: 0, num_finished: 0, in_dir, @@ -172,10 +178,17 @@ impl TopologicalSorter { /// :rtype: List fn get_ready<'py>(&mut self, py: Python<'py>) -> Bound<'py, PyList> { self.num_passed_out += self.ready_nodes.len(); - PyList::new_bound(py, self.ready_nodes.drain(..).map(|nx| { - self.node2state.insert(nx, NodeState::Ready); - nx.index() - })) + if let Some(node2state) = self.node2state.as_mut() { + PyList::new_bound( + py, + self.ready_nodes.drain(..).map(|nx| { + node2state.insert(nx, NodeState::Ready); + nx.index() + }), + ) + } else { + PyList::new_bound(py, self.ready_nodes.drain(..).map(|nx| nx.index())) + } } /// Marks a set of nodes returned by "get_ready" as processed. @@ -214,22 +227,24 @@ impl TopologicalSorter { #[inline(always)] fn done_single(&mut self, py: Python, node: NodeIndex) -> PyResult<()> { let dag = self.dag.borrow(py); - match self.node2state.get_mut(&node) { - None => { - return Err(PyValueError::new_err(format!( - "node {} was not passed out (still not ready).", - node.index() - ))); - } - Some(NodeState::Done) => { - return Err(PyValueError::new_err(format!( - "node {} was already marked done.", - node.index() - ))); - } - Some(state) => { - debug_assert_eq!(*state, NodeState::Ready); - *state = NodeState::Done; + if let Some(node2state) = self.node2state.as_mut() { + match node2state.get_mut(&node) { + None => { + return Err(PyValueError::new_err(format!( + "node {} was not passed out (still not ready).", + node.index() + ))); + } + Some(NodeState::Done) => { + return Err(PyValueError::new_err(format!( + "node {} was already marked done.", + node.index() + ))); + } + Some(state) => { + debug_assert_eq!(*state, NodeState::Ready); + *state = NodeState::Done; + } } }