Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PyDiGraph method to make edges symmetric #814

Merged
merged 14 commits into from
May 15, 2023
37 changes: 37 additions & 0 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2606,6 +2606,43 @@ impl PyDiGraph {
edges.is_empty()
}

/// Make edges in graph symmetric
///
/// This function iterates over all the edges in the graph and for each edge if a reverse
/// edge is not present in the graph one will be added.
mtreinish marked this conversation as resolved.
Show resolved Hide resolved
///
/// :param callable edge_payload: This optional argument takes in a callable which will
/// be passed a single positional argument the data payload for an edge that will
/// have a reverse copied in the graph. The returned value from this callable will
/// be used as the data payload for the new edge created. If this is not specified
/// then by default the data payload will be copied when the reverse edge is added.
/// If there are parallel edges, then one of the edges (typically the one with the lower
/// index, but this is not a guarantee) will be copied.
pub fn make_symmetric(
&mut self,
py: Python,
edge_payload_fn: Option<PyObject>,
) -> PyResult<()> {
let edges: HashMap<[NodeIndex; 2], EdgeIndex> = self
.graph
.edge_references()
.map(|edge| ([edge.source(), edge.target()], edge.id()))
.collect();
for (edge_endpoints, edge_index) in edges.iter() {
let reverse_edge = [edge_endpoints[1], edge_endpoints[0]];
if !edges.contains_key(&reverse_edge) {
let forward_weight = self.graph.edge_weight(*edge_index).unwrap();
let weight: PyObject = match edge_payload_fn.as_ref() {
Some(callback) => callback.call1(py, (forward_weight,))?,
None => forward_weight.clone_ref(py),
};
self.graph
.add_edge(reverse_edge[0], reverse_edge[1], weight);
}
}
Ok(())
}

/// Generate a new PyGraph object from this graph
///
/// This will create a new :class:`~rustworkx.PyGraph` object from this
Expand Down
34 changes: 34 additions & 0 deletions tests/rustworkx_tests/digraph/test_symmetric.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,37 @@ def test_bidirectional_ring(self):
]
digraph.extend_from_edge_list(edge_list)
self.assertTrue(digraph.is_symmetric())

def test_empty_graph_make_symmetric(self):
digraph = rustworkx.PyDiGraph()
digraph.make_symmetric()
self.assertEqual(0, digraph.num_edges())
self.assertEqual(0, digraph.num_nodes())

def test_path_graph_make_symmetric(self):
digraph = rustworkx.generators.directed_path_graph(4)
digraph.make_symmetric()
expected_edge_list = [
(0, 1),
(1, 2),
(2, 3),
(1, 0),
(2, 1),
(3, 2),
]
self.assertEqual(digraph.edge_list(), expected_edge_list)

def test_path_graph_make_symmetric_existing_reverse_edges(self):
digraph = rustworkx.generators.directed_path_graph(4)
digraph.add_edge(3, 2, None)
digraph.add_edge(1, 0, None)
digraph.make_symmetric()
expected_edge_list = [
(0, 1),
(1, 2),
(2, 3),
(3, 2),
(1, 0),
(2, 1),
]
self.assertEqual(digraph.edge_list(), expected_edge_list)