Skip to content

Commit

Permalink
Allow TopologicalSorter without error checking
Browse files Browse the repository at this point in the history
For representative examples in Qiskit, this reduces the runtime of
interacting with the topological sorter by around 15-20%.
  • Loading branch information
jakelishman committed Apr 5, 2024
1 parent 2979762 commit 2774414
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 23 deletions.
8 changes: 8 additions & 0 deletions releasenotes/notes/toposort-check-args-1378bab51e4172a3.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
features:
- |
:class:`.TopologicalSorter` now has a ``check_args`` keyword argument, which can be set to
``False`` to disable the runtime detection of invalid arguments to
:meth:`~.TopologicalSorter.done`. This provides a memory and runtime improvement to the online
sorter, at the cost that the results will be undefined and likely meaningless if invalid values
are given.
1 change: 1 addition & 0 deletions rustworkx/rustworkx.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ class TopologicalSorter:
*,
reverse: bool = ...,
initial: Iterable[int] | None = ...,
check_args: bool = ...,
) -> None: ...
def is_active(self) -> bool: ...
def get_ready(self) -> list[int]: ...
Expand Down
61 changes: 38 additions & 23 deletions src/toposort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,17 @@ enum NodeState {
/// It is a :exc:`ValueError` to give an `initial` set where the nodes have even a partial
/// topological order between themselves, though this might not appear until some call
/// to :meth:`done`.
/// :param bool check_args: If ``True`` (the default), then all arguments to :meth:`done` are
/// checked for validity, and a :exc:`ValueError` is raised if any were not ready, already
/// done, or not indices of the circuit. If ``False``, the tracking for this is disabled,
/// which can provide a meaningful performance and memory improvement, but the results will
/// be undefined if invalid values are given.
#[pyclass(module = "rustworkx")]
pub struct TopologicalSorter {
dag: Py<PyDiGraph>,
ready_nodes: Vec<NodeIndex>,
predecessor_count: HashMap<NodeIndex, usize>,
node2state: HashMap<NodeIndex, NodeState>,
node2state: Option<HashMap<NodeIndex, NodeState>>,
num_passed_out: usize,
num_finished: usize,
in_dir: petgraph::Direction,
Expand All @@ -92,13 +97,14 @@ pub struct TopologicalSorter {
#[pymethods]
impl TopologicalSorter {
#[new]
#[pyo3(signature=(dag, /, check_cycle=true, *, reverse=false, initial=None))]
#[pyo3(signature=(dag, /, check_cycle=true, *, reverse=false, initial=None, check_args=true))]
fn new(
py: Python,
dag: Py<PyDiGraph>,
check_cycle: bool,
reverse: bool,
initial: Option<&Bound<PyAny>>,
check_args: bool,
) -> PyResult<Self> {
{
let dag = &dag.borrow(py);
Expand Down Expand Up @@ -144,7 +150,7 @@ impl TopologicalSorter {
dag,
ready_nodes,
predecessor_count,
node2state: HashMap::new(),
node2state: check_args.then(HashMap::new),
num_passed_out: 0,
num_finished: 0,
in_dir,
Expand Down Expand Up @@ -172,10 +178,17 @@ impl TopologicalSorter {
/// :rtype: List
fn get_ready<'py>(&mut self, py: Python<'py>) -> Bound<'py, PyList> {
self.num_passed_out += self.ready_nodes.len();
PyList::new_bound(py, self.ready_nodes.drain(..).map(|nx| {
self.node2state.insert(nx, NodeState::Ready);
nx.index()
}))
if let Some(node2state) = self.node2state.as_mut() {
PyList::new_bound(
py,
self.ready_nodes.drain(..).map(|nx| {
node2state.insert(nx, NodeState::Ready);
nx.index()
}),
)
} else {
PyList::new_bound(py, self.ready_nodes.drain(..).map(|nx| nx.index()))
}
}

/// Marks a set of nodes returned by "get_ready" as processed.
Expand Down Expand Up @@ -214,22 +227,24 @@ impl TopologicalSorter {
#[inline(always)]
fn done_single(&mut self, py: Python, node: NodeIndex) -> PyResult<()> {
let dag = self.dag.borrow(py);
match self.node2state.get_mut(&node) {
None => {
return Err(PyValueError::new_err(format!(
"node {} was not passed out (still not ready).",
node.index()
)));
}
Some(NodeState::Done) => {
return Err(PyValueError::new_err(format!(
"node {} was already marked done.",
node.index()
)));
}
Some(state) => {
debug_assert_eq!(*state, NodeState::Ready);
*state = NodeState::Done;
if let Some(node2state) = self.node2state.as_mut() {
match node2state.get_mut(&node) {
None => {
return Err(PyValueError::new_err(format!(
"node {} was not passed out (still not ready).",
node.index()
)));
}
Some(NodeState::Done) => {
return Err(PyValueError::new_err(format!(
"node {} was already marked done.",
node.index()
)));
}
Some(state) => {
debug_assert_eq!(*state, NodeState::Ready);
*state = NodeState::Done;
}
}
}

Expand Down

0 comments on commit 2774414

Please sign in to comment.