From a0df600511970150d9bde1b27079f410db87ba30 Mon Sep 17 00:00:00 2001 From: Jake Lishman Date: Mon, 5 Feb 2024 15:04:15 +0000 Subject: [PATCH 1/4] Add node-removal methods linear in node degree The existing method, `PyDiGraph.remove_node_retain_edges` is quadratic in node degree, because the ``condition`` function takes in pairs of edges. This poses a problem for extreme-degree nodes (for example massive barriers in Qiskit). This commit adds two methods based on making edge-retention decisions by hashable keys, making it linear in the degree of the node (at least; the MIMO broadcasting can make it quadratic again if all edges have the same key, but that's fundamental to the output, rather than the algorithm). The ideal situation (for performance) is that the edges can be disambiguated by Python object identity, which doesn't require Python-space calls to retrieve or hash, so can be in pure Rust. This is `remove_node_retain_edges_by_id`. The more general situation is that the user wants to supply a Python key function, which naturally returns a Python object that we need to use Python hashing and equality semantics for. This means using Python collections to do the tracking, which impacts the performance (very casual benchmarking using the implicit identity function as the key shows it's about 2x slower than using the identity). This method is `remove_node_retain_edges_by_key`. --- Cargo.lock | 1 + Cargo.toml | 1 + .../remode-node-by-key-9ec75b5cf589319e.yaml | 10 + rustworkx/rustworkx.pyi | 14 +- src/digraph.rs | 222 +++++++++++++++++- tests/digraph/test_nodes.py | 166 +++++++++++++ 6 files changed, 409 insertions(+), 5 deletions(-) create mode 100644 releasenotes/notes/remode-node-by-key-9ec75b5cf589319e.yaml diff --git a/Cargo.lock b/Cargo.lock index 8d0dde77a..b2976448a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -614,6 +614,7 @@ dependencies = [ "rustworkx-core", "serde", "serde_json", + "smallvec", "sprs", ] diff --git a/Cargo.toml b/Cargo.toml index dd86347d5..e0dd0dbc7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,6 +57,7 @@ rand_pcg.workspace = true rayon.workspace = true serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" +smallvec = { version = "1.0", features = ["union"] } rustworkx-core = { path = "rustworkx-core", version = "=0.15.0" } [dependencies.pyo3] diff --git a/releasenotes/notes/remode-node-by-key-9ec75b5cf589319e.yaml b/releasenotes/notes/remode-node-by-key-9ec75b5cf589319e.yaml new file mode 100644 index 000000000..fb228437f --- /dev/null +++ b/releasenotes/notes/remode-node-by-key-9ec75b5cf589319e.yaml @@ -0,0 +1,10 @@ +--- +features: + - | + Added the :meth:`.PyDiGraph.remove_node_retain_edges_by_id` and + :meth:`~.PyDiGraph.remove_node_retain_edges_by_key` methods, which provide a node-removal that + is linear in the degree of the node, as opposed to quadratic like + :meth:`~.PyDiGraph.remove_node_retain_edges`. These methods require, respectively, that the + edge weights are referentially identical if they should be retained (``a is b``, in Python), or + that you can supply a ``key`` function that produces a Python-hashable result that is used to + do the equality matching between input and output edges. diff --git a/rustworkx/rustworkx.pyi b/rustworkx/rustworkx.pyi index 51dceaa96..87152afc0 100644 --- a/rustworkx/rustworkx.pyi +++ b/rustworkx/rustworkx.pyi @@ -23,6 +23,7 @@ from typing import ( ValuesView, Mapping, overload, + Hashable, ) from abc import ABC from rustworkx import generators # noqa @@ -1325,8 +1326,17 @@ class PyDiGraph(Generic[_S, _T]): self, node: int, /, - use_outgoing: bool | None = ..., - condition: Callable[[_S, _S], bool] | None = ..., + use_outgoing: bool = ..., + condition: Callable[[_T, _T], bool] | None = ..., + ) -> None: ... + def remove_node_retain_edges_by_id(self, node: int, /) -> None: ... + def remove_node_retain_edges_by_key( + self, + node: int, + /, + key: Callable[[_T], Hashable] | None = ..., + *, + use_outgoing: bool = ..., ) -> None: ... def remove_nodes_from(self, index_list: Sequence[int], /) -> None: ... def subgraph( diff --git a/src/digraph.rs b/src/digraph.rs index 6e8cb53d5..11f6e4ef2 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -26,6 +26,8 @@ use indexmap::IndexSet; use rustworkx_core::dictmap::*; +use smallvec::SmallVec; + use pyo3::exceptions::PyIndexError; use pyo3::gc::PyVisit; use pyo3::prelude::*; @@ -988,6 +990,13 @@ impl PyDiGraph { /// By default the data/weight on edges into the removed node will be used /// for the retained edges. /// + /// This function has a minimum time complexity of :math:`\mathcal O(e_i e_o)`, where + /// :math:`e_i` and :math:`e_o` are the numbers of incoming and outgoing edges respectively. + /// If your ``condition`` can be cast as an equality between two hashable quantities, consider + /// using :meth:`remove_node_retain_edges_by_key` instead, or if your ``condition`` is + /// referential object identity of the edge weights, consider + /// :meth:`remove_node_retain_edges_by_id`. + /// /// :param int node: The index of the node to remove. If the index is not /// present in the graph it will be ingored and this function willl have /// no effect. @@ -1003,7 +1012,7 @@ impl PyDiGraph { /// /// would only retain edges if the input edge to ``node`` had the same /// data payload as the outgoing edge. - #[pyo3(text_signature = "(self, node, /, use_outgoing=None, condition=None)")] + #[pyo3(text_signature = "(self, node, /, use_outgoing=False, condition=None)")] #[pyo3(signature=(node, use_outgoing=false, condition=None))] pub fn remove_node_retain_edges( &mut self, @@ -1049,8 +1058,183 @@ impl PyDiGraph { for (source, target, weight) in edge_list { self._add_edge(source, target, weight)?; } - self.graph.remove_node(index); - self.node_removed = true; + self.node_removed = self.graph.remove_node(index).is_some(); + Ok(()) + } + + /// Remove a node from the graph and add edges from predecessors to successors in cases where + /// an incoming and outgoing edge have the same weight by Python object identity. + /// + /// This function has a minimum time complexity of :math:`\mathcal O(e_i + e_o)`, where + /// :math:`e_i` is the number of incoming edges and :math:`e_o` the number of outgoing edges + /// (the full complexity depends on the number of new edges to be created). + /// + /// Edges will be added between all pairs of predecessor and successor nodes that have the same + /// weight. As a consequence, any weight which appears only on predecessor edges will not + /// appear in the output, as there are no successors to pair it with. + /// + /// :param int node: The index of the node to remove. If the index is not present in the graph + /// it will be ignored and this function will have no effect. + #[pyo3(signature=(node, /))] + pub fn remove_node_retain_edges_by_id(&mut self, py: Python, node: usize) -> PyResult<()> { + // As many indices as will fit inline within the minimum inline size of a `SmallVec`. Many + // use cases of this likely only have one inbound and outbound edge with each id anyway. + const INLINE_SIZE: usize = + 2 * ::std::mem::size_of::() / ::std::mem::size_of::(); + let new_node_list = || SmallVec::<[NodeIndex; INLINE_SIZE]>::new(); + let node_index = NodeIndex::new(node); + let in_edges = { + let mut in_edges = HashMap::new(); + for edge in self + .graph + .edges_directed(node_index, petgraph::Direction::Incoming) + { + in_edges + .entry(PyAnyId(edge.weight().clone_ref(py))) + .or_insert_with(new_node_list) + .push(edge.source()); + } + in_edges + }; + let mut out_edges = { + let mut out_edges = HashMap::new(); + for edge in self + .graph + .edges_directed(node_index, petgraph::Direction::Outgoing) + { + out_edges + .entry(PyAnyId(edge.weight().clone_ref(py))) + .or_insert_with(new_node_list) + .push(edge.target()); + } + out_edges + }; + + for (weight, in_edges_subset) in in_edges { + let out_edges_subset = match out_edges.remove(&weight) { + Some(out_edges_key) => out_edges_key, + None => continue, + }; + for source in in_edges_subset { + for target in out_edges_subset.iter() { + self._add_edge(source, *target, weight.clone_ref(py))?; + } + } + } + self.node_removed = self.graph.remove_node(node_index).is_some(); + Ok(()) + } + + /// Remove a node from the graph and add edges from predecessors to successors in cases where + /// an incoming and outgoing edge have the same weight by Python object equality. + /// + /// This function has a minimum time complexity of :math:`\mathcal O(e_i + e_o)`, where + /// :math:`e_i` is the number of incoming edges and :math:`e_o` the number of outgoing edges + /// (the full complexity depends on the number of new edges to be created). + /// + /// Edges will be added between all pairs of predecessor and successor nodes that have equal + /// weights. As a consequence, any weight which appears only on predecessor edges will not + /// appear in the output, as there are no successors to pair it with. + /// + /// If there are multiple edges with the same weight, the exact Python object used on the new + /// edges is an implementation detail and may change. The only guarantees are that it will be + /// deterministic for a given graph, and that it will be drawn from the incoming edges if + /// ``use_outgoing=False`` (the default) or from the outgoing edges if ``use_outgoing=True``. + /// + /// :param int node: The index of the node to remove. If the index is not present in the graph + /// it will be ignored and this function will have no effect. + /// :param key: A callable Python object that is called once for each connected edge, to + /// generate the "key" for that weight. It is passed exactly one argument positionally + /// (the weight of the edge), and should return a Python object that is hashable and + /// implements equality checking with all other relevant keys. If not given, the edge + /// weights will be used directly. + /// :param bool use_outgoing: If ``False`` (default), the new edges will use the weight from + /// one of the incoming edges. If ``True``, they will instead use a weight from one of the + /// outgoing edges. + #[pyo3(signature=(node, /, key=None, *, use_outgoing=false))] + pub fn remove_node_retain_edges_by_key( + &mut self, + py: Python, + node: usize, + key: Option>, + use_outgoing: bool, + ) -> PyResult<()> { + let node_index = NodeIndex::new(node); + let in_edges = { + let in_edges = PyDict::new(py); + for edge in self + .graph + .edges_directed(node_index, petgraph::Direction::Incoming) + { + let key_value = if let Some(key_fn) = &key { + key_fn.call1(py, (edge.weight(),))? + } else { + edge.weight().clone_ref(py) + }; + if let Some(edge_data) = in_edges.get_item(key_value.as_ref(py))? { + let edge_data = edge_data.extract::()?; + edge_data.nodes.as_ref(py).append(edge.source().index())? + } else { + in_edges.set_item( + key_value, + RemoveNodeEdgeValue { + weight: edge.weight().clone_ref(py), + nodes: PyList::new(py, [edge.source().index()]).into_py(py), + } + .into_py(py), + )? + } + } + in_edges + }; + let out_edges = { + let out_edges = PyDict::new(py); + for edge in self + .graph + .edges_directed(node_index, petgraph::Direction::Outgoing) + { + let key_value = if let Some(key_fn) = &key { + key_fn.call1(py, (edge.weight(),))? + } else { + edge.weight().clone_ref(py) + }; + if let Some(edge_data) = out_edges.get_item(key_value.as_ref(py))? { + let edge_data = edge_data.extract::()?; + edge_data.nodes.as_ref(py).append(edge.target().index())? + } else { + out_edges.set_item( + key_value, + RemoveNodeEdgeValue { + weight: edge.weight().clone_ref(py), + nodes: PyList::new(py, [edge.target().index()]).into_py(py), + } + .into_py(py), + )? + } + } + out_edges + }; + + for (in_key, in_edge_data) in in_edges { + let in_edge_data = in_edge_data.extract::()?; + let out_edge_data = match out_edges.get_item(in_key)? { + Some(out_edge_data) => out_edge_data.extract::()?, + None => continue, + }; + for source in in_edge_data.nodes.as_ref(py) { + let source = NodeIndex::new(source.extract::()?); + for target in out_edge_data.nodes.as_ref(py) { + let target = NodeIndex::new(target.extract::()?); + let weight = if use_outgoing { + out_edge_data.weight.clone_ref(py) + } else { + in_edge_data.weight.clone_ref(py) + }; + self._add_edge(source, target, weight)?; + } + } + } + self.node_removed = self.graph.remove_node(node_index).is_some(); Ok(()) } @@ -3100,3 +3284,35 @@ where attrs: py.None(), } } + +/// Simple wrapper newtype that lets us use `Py` pointers as hash keys with the equality defined by +/// the pointer address. This is equivalent to using Python's `is` operator for comparisons. +/// Using a newtype rather than casting the pointer to `usize` inline lets us retrieve a copy of +/// the reference from the key entry. +struct PyAnyId(Py); +impl PyAnyId { + fn clone_ref(&self, py: Python) -> Py { + self.0.clone_ref(py) + } +} +impl ::std::hash::Hash for PyAnyId { + fn hash(&self, state: &mut H) { + (self.0.as_ptr() as usize).hash(state) + } +} +impl PartialEq for PyAnyId { + fn eq(&self, other: &Self) -> bool { + self.0.as_ptr() == other.0.as_ptr() + } +} +impl Eq for PyAnyId {} + +/// Internal-only helper class used by `remove_node_retain_edges_by_key` to store its data as a +/// typed object in a Python dictionary. This object should be fairly cheap to construct new +/// instances of; it involves two refcount updates, but otherwise is just two pointers wide. +#[pyclass] +#[derive(Clone)] +struct RemoveNodeEdgeValue { + weight: Py, + nodes: Py, +} diff --git a/tests/digraph/test_nodes.py b/tests/digraph/test_nodes.py index 36dcc0b45..febfa7bb0 100644 --- a/tests/digraph/test_nodes.py +++ b/tests/digraph/test_nodes.py @@ -174,6 +174,172 @@ def test_remove_nodes_retain_edges_with_invalid_index(self): self.assertEqual(["a", "b", "c"], res) self.assertEqual([0, 1, 2], dag.node_indexes()) + def test_remove_nodes_retain_edges_by_id_singles(self): + dag = rustworkx.PyDAG() + weights = [object(), object()] + before_nodes = [dag.add_node(i) for i, _ in enumerate(weights)] + middle = dag.add_node(10) + after_nodes = [dag.add_node(20 + i) for i, _ in enumerate(weights)] + for before, after, weight in zip(before_nodes, after_nodes, weights): + dag.add_edge(before, middle, weight) + dag.add_edge(middle, after, weight) + dag.remove_node_retain_edges_by_id(middle) + self.assertEqual(set(dag.node_indices()), set(before_nodes) | set(after_nodes)) + expected_edges = set(zip(before_nodes, after_nodes, weights)) + self.assertEqual(set(dag.weighted_edge_list()), expected_edges) + + def test_remove_nodes_retain_edges_by_id_parallel(self): + dag = rustworkx.PyDAG() + nodes = [dag.add_node(i) for i in range(3)] + weights = [object(), object(), object()] + for weight in weights: + dag.add_edge(nodes[0], nodes[1], weight) + dag.add_edge(nodes[1], nodes[2], weight) + # The middle node has three precessor edges and three successor edges, where each set has + # one edge each of three weights. Edges should be paired up in bijection during the removal. + dag.remove_node_retain_edges_by_id(nodes[1]) + self.assertEqual(set(dag.node_indices()), {nodes[0], nodes[2]}) + expected_edges = {(nodes[0], nodes[2], weight) for weight in weights} + self.assertEqual(set(dag.weighted_edge_list()), expected_edges) + + def test_remove_nodes_retain_edges_by_id_broadcast(self): + dag = rustworkx.PyDAG() + nodes = {a: dag.add_node(a) for a in "abcdefghijklmn"} + mid = dag.add_node("middle") + weights = [object(), object(), object(), object(), object()] + expected_edges = set() + + # 2:1 broadcast. + dag.add_edge(nodes["a"], mid, weights[0]) + dag.add_edge(nodes["b"], mid, weights[0]) + dag.add_edge(mid, nodes["c"], weights[0]) + expected_edges |= { + (nodes["a"], nodes["c"], weights[0]), + (nodes["b"], nodes["c"], weights[0]), + } + + # 1:2 broadcast + dag.add_edge(nodes["d"], mid, weights[1]) + dag.add_edge(mid, nodes["e"], weights[1]) + dag.add_edge(mid, nodes["f"], weights[1]) + expected_edges |= { + (nodes["d"], nodes["e"], weights[1]), + (nodes["d"], nodes["f"], weights[1]), + } + + # 2:2 broadacst + dag.add_edge(nodes["g"], mid, weights[2]) + dag.add_edge(nodes["h"], mid, weights[2]) + dag.add_edge(mid, nodes["i"], weights[2]) + dag.add_edge(mid, nodes["j"], weights[2]) + expected_edges |= { + (nodes["g"], nodes["i"], weights[2]), + (nodes["g"], nodes["j"], weights[2]), + (nodes["h"], nodes["i"], weights[2]), + (nodes["h"], nodes["j"], weights[2]), + } + + # 0:1 broadcast + dag.add_edge(mid, nodes["k"], weights[3]) + + # 1:0 broadcast + dag.add_edge(nodes["l"], mid, weights[4]) + + # Edge that doesn't go via the middle at all, but shares an id with another edge. This + # shouldn't be touched. + dag.add_edge(nodes["m"], nodes["n"], weights[0]) + expected_edges |= { + (nodes["m"], nodes["n"], weights[0]), + } + + dag.remove_node_retain_edges_by_id(mid) + + self.assertEqual(set(dag.nodes()), set(nodes)) + self.assertEqual(set(dag.weighted_edge_list()), expected_edges) + + def test_remove_nodes_retain_edges_by_key_singles_id_map(self): + dag = rustworkx.PyDAG() + before = [dag.add_node(0), dag.add_node(1)] + middle = dag.add_node(2) + after = [dag.add_node(3), dag.add_node(4)] + + expected_edges = set() + dag.add_edge(before[0], middle, (0,)) + dag.add_edge(middle, after[0], (0,)) + expected_edges.add((before[0], after[0], (0,))) + dag.add_edge(before[1], middle, (1,)) + dag.add_edge(middle, after[1], (1,)) + expected_edges.add((before[1], after[1], (1,))) + + dag.remove_node_retain_edges_by_id(middle) + + self.assertEqual(set(dag.node_indices()), set(before) | set(after)) + self.assertEqual(set(dag.weighted_edge_list()), expected_edges) + + def test_remove_nodes_retain_edges_by_key_broadcast_mod_map(self): + dag = rustworkx.PyDAG() + nodes = {a: dag.add_node(a) for a in "abcdefghijklmn"} + mid = dag.add_node("middle") + expected_edges = {} + + # 2:1 broadcast. + dag.add_edge(nodes["a"], mid, 10) + dag.add_edge(nodes["b"], mid, 20) + dag.add_edge(mid, nodes["c"], 30) + # The edge data here is a list of allowed weight - the function doesn't prescribe which + # exact weight will be used, just that it's from the incoming edges. + allowed_weights = {10, 20} + expected_edges[nodes["a"], nodes["c"]] = allowed_weights + expected_edges[nodes["b"], nodes["c"]] = allowed_weights + + # 1:2 broadcast + dag.add_edge(nodes["d"], mid, 11) + dag.add_edge(mid, nodes["e"], 21) + dag.add_edge(mid, nodes["f"], 31) + allowed_weights = {11} + expected_edges[nodes["d"], nodes["e"]] = allowed_weights + expected_edges[nodes["d"], nodes["f"]] = allowed_weights + + # 2:2 broadacst + dag.add_edge(nodes["g"], mid, 12) + dag.add_edge(nodes["h"], mid, 22) + dag.add_edge(mid, nodes["i"], 32) + dag.add_edge(mid, nodes["j"], 42) + allowed_weights = {12, 22} + expected_edges[nodes["g"], nodes["i"]] = allowed_weights + expected_edges[nodes["g"], nodes["j"]] = allowed_weights + expected_edges[nodes["h"], nodes["i"]] = allowed_weights + expected_edges[nodes["h"], nodes["j"]] = allowed_weights + + # 0:1 broadcast + dag.add_edge(mid, nodes["k"], 13) + + # 1:0 broadcast + dag.add_edge(nodes["l"], mid, 14) + + # Edge that doesn't go via the middle at all, but shares a key with another edge. This + # shouldn't be touched. + dag.add_edge(nodes["m"], nodes["n"], 10) + expected_edges[nodes["m"], nodes["n"]] = {10} + + dag.remove_node_retain_edges_by_key(mid, key=lambda weight: weight % 10) + + self.assertEqual(set(dag.nodes()), set(nodes)) + self.assertEqual(set(dag.edge_list()), set(expected_edges)) + for source, target, weight in dag.weighted_edge_list(): + self.assertIn(weight, expected_edges[source, target]) + + def test_remove_nodes_retain_edges_by_key_use_outgoing(self): + dag = rustworkx.PyDAG() + before = dag.add_node(0) + middle = dag.add_node(1) + after = dag.add_node(2) + dag.add_edge(before, middle, 0) + dag.add_edge(middle, after, 2) + dag.remove_node_retain_edges_by_key(middle, key=lambda weight: weight % 2, use_outgoing=True) + self.assertEqual(set(dag.node_indices()), {before, after}) + self.assertEqual(set(dag.weighted_edge_list()), {(before, after, 2)}) + def test_topo_sort_empty(self): dag = rustworkx.PyDAG() self.assertEqual([], rustworkx.topological_sort(dag)) From f86e8b7d733c00dbde9603e90c032c77c3644f97 Mon Sep 17 00:00:00 2001 From: Jake Lishman Date: Thu, 8 Feb 2024 15:28:12 +0000 Subject: [PATCH 2/4] Format --- tests/digraph/test_nodes.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/digraph/test_nodes.py b/tests/digraph/test_nodes.py index febfa7bb0..b36858c32 100644 --- a/tests/digraph/test_nodes.py +++ b/tests/digraph/test_nodes.py @@ -336,7 +336,9 @@ def test_remove_nodes_retain_edges_by_key_use_outgoing(self): after = dag.add_node(2) dag.add_edge(before, middle, 0) dag.add_edge(middle, after, 2) - dag.remove_node_retain_edges_by_key(middle, key=lambda weight: weight % 2, use_outgoing=True) + dag.remove_node_retain_edges_by_key( + middle, key=lambda weight: weight % 2, use_outgoing=True + ) self.assertEqual(set(dag.node_indices()), {before, after}) self.assertEqual(set(dag.weighted_edge_list()), {(before, after, 2)}) From ab20d3c8d21965f12081a5a6eff8ec3a475af288 Mon Sep 17 00:00:00 2001 From: Jake Lishman Date: Fri, 5 Apr 2024 12:18:04 +0100 Subject: [PATCH 3/4] Update conversion actions to be by `Bound` reference This is partly a PyO3 0.21 upgrade, partly fixing a dodgy `extract` in favour of `downcast`. --- src/digraph.rs | 35 ++++++++++++++++------------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/src/digraph.rs b/src/digraph.rs index b738dc20a..a60814a64 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -1151,7 +1151,7 @@ impl PyDiGraph { ) -> PyResult<()> { let node_index = NodeIndex::new(node); let in_edges = { - let in_edges = PyDict::new(py); + let in_edges = PyDict::new_bound(py); for edge in self .graph .edges_directed(node_index, petgraph::Direction::Incoming) @@ -1161,15 +1161,15 @@ impl PyDiGraph { } else { edge.weight().clone_ref(py) }; - if let Some(edge_data) = in_edges.get_item(key_value.as_ref(py))? { - let edge_data = edge_data.extract::()?; - edge_data.nodes.as_ref(py).append(edge.source().index())? + if let Some(edge_data) = in_edges.get_item(key_value.bind(py))? { + let edge_data = edge_data.downcast::()?; + edge_data.borrow_mut().nodes.push(edge.source()); } else { in_edges.set_item( key_value, RemoveNodeEdgeValue { weight: edge.weight().clone_ref(py), - nodes: PyList::new(py, [edge.source().index()]).into_py(py), + nodes: vec![edge.source()], } .into_py(py), )? @@ -1178,7 +1178,7 @@ impl PyDiGraph { in_edges }; let out_edges = { - let out_edges = PyDict::new(py); + let out_edges = PyDict::new_bound(py); for edge in self .graph .edges_directed(node_index, petgraph::Direction::Outgoing) @@ -1188,15 +1188,15 @@ impl PyDiGraph { } else { edge.weight().clone_ref(py) }; - if let Some(edge_data) = out_edges.get_item(key_value.as_ref(py))? { - let edge_data = edge_data.extract::()?; - edge_data.nodes.as_ref(py).append(edge.target().index())? + if let Some(edge_data) = out_edges.get_item(key_value.bind(py))? { + let edge_data = edge_data.downcast::()?; + edge_data.borrow_mut().nodes.push(edge.target()); } else { out_edges.set_item( key_value, RemoveNodeEdgeValue { weight: edge.weight().clone_ref(py), - nodes: PyList::new(py, [edge.target().index()]).into_py(py), + nodes: vec![edge.target()], } .into_py(py), )? @@ -1206,21 +1206,19 @@ impl PyDiGraph { }; for (in_key, in_edge_data) in in_edges { - let in_edge_data = in_edge_data.extract::()?; + let in_edge_data = in_edge_data.downcast::()?.borrow(); let out_edge_data = match out_edges.get_item(in_key)? { - Some(out_edge_data) => out_edge_data.extract::()?, + Some(out_edge_data) => out_edge_data.downcast::()?.borrow(), None => continue, }; - for source in in_edge_data.nodes.as_ref(py) { - let source = NodeIndex::new(source.extract::()?); - for target in out_edge_data.nodes.as_ref(py) { - let target = NodeIndex::new(target.extract::()?); + for source in in_edge_data.nodes.iter() { + for target in out_edge_data.nodes.iter() { let weight = if use_outgoing { out_edge_data.weight.clone_ref(py) } else { in_edge_data.weight.clone_ref(py) }; - self._add_edge(source, target, weight)?; + self._add_edge(*source, *target, weight)?; } } } @@ -3301,8 +3299,7 @@ impl Eq for PyAnyId {} /// typed object in a Python dictionary. This object should be fairly cheap to construct new /// instances of; it involves two refcount updates, but otherwise is just two pointers wide. #[pyclass] -#[derive(Clone)] struct RemoveNodeEdgeValue { weight: Py, - nodes: Py, + nodes: Vec, } From 9f3a0a7e0de58defd3bc2a005501026b5fdb66bc Mon Sep 17 00:00:00 2001 From: Jake Lishman Date: Fri, 5 Apr 2024 12:22:34 +0100 Subject: [PATCH 4/4] Update out-of-date comment --- src/digraph.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/digraph.rs b/src/digraph.rs index a60814a64..dd91f8848 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -3296,8 +3296,7 @@ impl PartialEq for PyAnyId { impl Eq for PyAnyId {} /// Internal-only helper class used by `remove_node_retain_edges_by_key` to store its data as a -/// typed object in a Python dictionary. This object should be fairly cheap to construct new -/// instances of; it involves two refcount updates, but otherwise is just two pointers wide. +/// typed object in a Python dictionary. #[pyclass] struct RemoveNodeEdgeValue { weight: Py,