diff --git a/releasenotes/notes/handle-non-existent-edge-15d70cfe60c89ac2.yaml b/releasenotes/notes/handle-non-existent-edge-15d70cfe60c89ac2.yaml new file mode 100644 index 0000000000..8f8e2acd22 --- /dev/null +++ b/releasenotes/notes/handle-non-existent-edge-15d70cfe60c89ac2.yaml @@ -0,0 +1,6 @@ +--- +fixes: + - | + :meth:`rustworkx.PyGraph.add_edge` and :meth:`rustworkx.PyDiGraph.add_edge` and now raises an + ``IndexError`` when one of the nodes does not exist in the graph. Previously, it caused the Python + interpreter to exit with a ``PanicException`` diff --git a/src/connectivity/mod.rs b/src/connectivity/mod.rs index 005ad08afa..e17d30733d 100644 --- a/src/connectivity/mod.rs +++ b/src/connectivity/mod.rs @@ -398,7 +398,7 @@ pub fn graph_complement(py: Python, graph: &graph::PyGraph) -> PyResult PyResult { let p_index = NodeIndex::new(parent); let c_index = NodeIndex::new(child); + if !self.graph.contains_node(p_index) || !self.graph.contains_node(c_index) { + return Err(PyIndexError::new_err( + "One of the endpoints of the edge does not exist in graph", + )); + } let out_index = self._add_edge(p_index, c_index, edge)?; Ok(out_index) } @@ -1103,9 +1108,7 @@ impl PyDiGraph { ) -> PyResult> { let mut out_list: Vec = Vec::with_capacity(obj_list.len()); for obj in obj_list { - let p_index = NodeIndex::new(obj.0); - let c_index = NodeIndex::new(obj.1); - let edge = self._add_edge(p_index, c_index, obj.2)?; + let edge = self.add_edge(obj.0, obj.1, obj.2)?; out_list.push(edge); } Ok(out_list) @@ -1129,9 +1132,7 @@ impl PyDiGraph { ) -> PyResult> { let mut out_list: Vec = Vec::with_capacity(obj_list.len()); for obj in obj_list { - let p_index = NodeIndex::new(obj.0); - let c_index = NodeIndex::new(obj.1); - let edge = self._add_edge(p_index, c_index, py.None())?; + let edge = self.add_edge(obj.0, obj.1, py.None())?; out_list.push(edge); } Ok(out_list) diff --git a/src/graph.rs b/src/graph.rs index 7857cef9c8..d57fbb3215 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -847,10 +847,15 @@ impl PyGraph { /// of an existing edge with ``multigraph=False``) edge. /// :rtype: int #[pyo3(text_signature = "(self, node_a, node_b, edge, /)")] - pub fn add_edge(&mut self, node_a: usize, node_b: usize, edge: PyObject) -> usize { + pub fn add_edge(&mut self, node_a: usize, node_b: usize, edge: PyObject) -> PyResult { let p_index = NodeIndex::new(node_a); let c_index = NodeIndex::new(node_b); - self._add_edge(p_index, c_index, edge) + if !self.graph.contains_node(p_index) || !self.graph.contains_node(c_index) { + return Err(PyIndexError::new_err( + "One of the endpoints of the edge does not exist in graph", + )); + } + Ok(self._add_edge(p_index, c_index, edge)) } /// Add new edges to the graph. @@ -869,14 +874,15 @@ impl PyGraph { /// :returns: A list of int indices of the newly created edges /// :rtype: list #[pyo3(text_signature = "(self, obj_list, /)")] - pub fn add_edges_from(&mut self, obj_list: Vec<(usize, usize, PyObject)>) -> EdgeIndices { + pub fn add_edges_from( + &mut self, + obj_list: Vec<(usize, usize, PyObject)>, + ) -> PyResult { let mut out_list: Vec = Vec::with_capacity(obj_list.len()); for obj in obj_list { - let p_index = NodeIndex::new(obj.0); - let c_index = NodeIndex::new(obj.1); - out_list.push(self._add_edge(p_index, c_index, obj.2)); + out_list.push(self.add_edge(obj.0, obj.1, obj.2)?); } - EdgeIndices { edges: out_list } + Ok(EdgeIndices { edges: out_list }) } /// Add new edges to the graph without python data. @@ -898,14 +904,12 @@ impl PyGraph { &mut self, py: Python, obj_list: Vec<(usize, usize)>, - ) -> EdgeIndices { + ) -> PyResult { let mut out_list: Vec = Vec::with_capacity(obj_list.len()); for obj in obj_list { - let p_index = NodeIndex::new(obj.0); - let c_index = NodeIndex::new(obj.1); - out_list.push(self._add_edge(p_index, c_index, py.None())); + out_list.push(self.add_edge(obj.0, obj.1, py.None())?); } - EdgeIndices { edges: out_list } + Ok(EdgeIndices { edges: out_list }) } /// Extend graph from an edge list @@ -1703,7 +1707,7 @@ impl PyGraph { } for (source, weight) in edges { - self.add_edge(source.index(), node_index.index(), weight); + self.add_edge(source.index(), node_index.index(), weight)?; } Ok(node_index.index()) diff --git a/src/tree.rs b/src/tree.rs index 11e2ba5b66..b9426346c2 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -130,7 +130,7 @@ pub fn minimum_spanning_tree( .edges .iter() { - spanning_tree.add_edge(edge.0, edge.1, edge.2.clone_ref(py)); + spanning_tree.add_edge(edge.0, edge.1, edge.2.clone_ref(py))?; } Ok(spanning_tree) diff --git a/tests/rustworkx_tests/digraph/test_edges.py b/tests/rustworkx_tests/digraph/test_edges.py index 2d9f56ae52..ef0af66dd6 100644 --- a/tests/rustworkx_tests/digraph/test_edges.py +++ b/tests/rustworkx_tests/digraph/test_edges.py @@ -963,6 +963,21 @@ def test_extend_from_weighted_edge_list(self): self.assertEqual(len(graph), 4) self.assertEqual(["a", "b", "c", "d", "e"], graph.edges()) + def test_add_edge_non_existent(self): + g = rustworkx.PyDiGraph() + with self.assertRaises(IndexError): + g.add_edge(2, 3, None) + + def test_add_edges_from_non_existent(self): + g = rustworkx.PyDiGraph() + with self.assertRaises(IndexError): + g.add_edges_from([(2, 3, 5)]) + + def test_add_edges_from_no_data_non_existent(self): + g = rustworkx.PyDiGraph() + with self.assertRaises(IndexError): + g.add_edges_from_no_data([(2, 3)]) + def test_reverse_graph(self): graph = rustworkx.PyDiGraph() graph.add_nodes_from([i for i in range(4)]) @@ -978,7 +993,7 @@ def test_reverse_graph(self): self.assertEqual([(1, 0), (2, 1), (2, 0), (3, 2), (3, 0)], graph.edge_list()) def test_reverse_large_graph(self): - LARGE_AMOUNT_OF_NODES = 10000000 + LARGE_AMOUNT_OF_NODES = 1000000 graph = rustworkx.PyDiGraph() graph.add_nodes_from(range(LARGE_AMOUNT_OF_NODES)) diff --git a/tests/rustworkx_tests/graph/test_edges.py b/tests/rustworkx_tests/graph/test_edges.py index 4981225a50..04f24af1a3 100644 --- a/tests/rustworkx_tests/graph/test_edges.py +++ b/tests/rustworkx_tests/graph/test_edges.py @@ -817,3 +817,18 @@ def test_extend_from_weighted_edge_list(self): graph.extend_from_weighted_edge_list(edge_list) self.assertEqual(len(graph), 4) self.assertEqual(["a", "b", "c", "d", "e"], graph.edges()) + + def test_add_edge_non_existent(self): + g = rustworkx.PyGraph() + with self.assertRaises(IndexError): + g.add_edge(2, 3, None) + + def test_add_edges_from_non_existent(self): + g = rustworkx.PyGraph() + with self.assertRaises(IndexError): + g.add_edges_from([(2, 3, 5)]) + + def test_add_edges_from_no_data_non_existent(self): + g = rustworkx.PyGraph() + with self.assertRaises(IndexError): + g.add_edges_from_no_data([(2, 3)])