Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PyDiGraph method to make edges symmetric #814

Merged
merged 14 commits into from
May 15, 2023
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
---
features:
- |
Added a new method, :meth:`~.PyDiGraph.make_symmetric`, to the
:class:`~.PyDiGraph` class. This method is used to make all the edges
in the graph symmetric (there is a reverse edge in the graph for each edge).
For example:

.. jupyter-execute::

import rustworkx as rx
from rustworkx.visualization import graphviz_draw

graph = rx.generators.directed_path_graph(5, bidirectional=False)
graph.make_symmetric()
graphviz_draw(graph)
37 changes: 37 additions & 0 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2712,6 +2712,43 @@ impl PyDiGraph {
edges.is_empty()
}

/// Make edges in graph symmetric
///
/// This function iterates over all the edges in the graph, adding for each
/// edge the reversed edge, unless one is already present. Note the edge insertion
/// is not fixed and the edge indices are not guaranteed to be consistent
/// between executions of this method on identical graphs.
///
/// :param callable edge_payload: This optional argument takes in a callable which will
/// be passed a single positional argument the data payload for an edge that will
/// have a reverse copied in the graph. The returned value from this callable will
/// be used as the data payload for the new edge created. If this is not specified
/// then by default the data payload will be copied when the reverse edge is added.
/// If there are parallel edges, then one of the edges (typically the one with the lower
/// index, but this is not a guarantee) will be copied.
pub fn make_symmetric(
&mut self,
py: Python,
edge_payload_fn: Option<PyObject>,
) -> PyResult<()> {
let edges: HashMap<[NodeIndex; 2], EdgeIndex> = self
.graph
.edge_references()
.map(|edge| ([edge.source(), edge.target()], edge.id()))
.collect();
for ([edge_source, edge_target], edge_index) in edges.iter() {
if !edges.contains_key(&[*edge_target, *edge_source]) {
let forward_weight = self.graph.edge_weight(*edge_index).unwrap();
let weight: PyObject = match edge_payload_fn.as_ref() {
Some(callback) => callback.call1(py, (forward_weight,))?,
None => forward_weight.clone_ref(py),
};
self._add_edge(*edge_target, *edge_source, weight)?;
}
}
Ok(())
}

/// Generate a new PyGraph object from this graph
///
/// This will create a new :class:`~rustworkx.PyGraph` object from this
Expand Down
83 changes: 83 additions & 0 deletions tests/rustworkx_tests/digraph/test_symmetric.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
import rustworkx


def default_weight_function(edge):
return "Reversi"


class TestSymmetric(unittest.TestCase):
def test_single_neighbor(self):
digraph = rustworkx.PyDiGraph()
Expand All @@ -37,3 +41,82 @@ def test_bidirectional_ring(self):
]
digraph.extend_from_edge_list(edge_list)
self.assertTrue(digraph.is_symmetric())

def test_empty_graph_make_symmetric(self):
digraph = rustworkx.PyDiGraph()
digraph.make_symmetric()
self.assertEqual(0, digraph.num_edges())
self.assertEqual(0, digraph.num_nodes())

def test_path_graph_make_symmetric(self):
digraph = rustworkx.generators.directed_path_graph(4, bidirectional=False)
digraph.make_symmetric()
expected_edge_list = {
(0, 1),
(1, 2),
(2, 3),
(1, 0),
(2, 1),
(3, 2),
}
self.assertEqual(set(digraph.edge_list()), expected_edge_list)

def test_path_graph_make_symmetric_existing_reverse_edges(self):
digraph = rustworkx.generators.directed_path_graph(4, bidirectional=False)
digraph.add_edge(3, 2, None)
digraph.add_edge(1, 0, None)
digraph.make_symmetric()
expected_edge_list = {
(0, 1),
(1, 2),
(2, 3),
(3, 2),
(1, 0),
(2, 1),
}
self.assertEqual(set(digraph.edge_list()), expected_edge_list)

def test_empty_graph_make_symmetric_with_function_arg(self):
digraph = rustworkx.PyDiGraph()
digraph.make_symmetric(default_weight_function)
self.assertEqual(0, digraph.num_edges())
self.assertEqual(0, digraph.num_nodes())

def test_path_graph_make_symmetric_with_function_arg(self):
digraph = rustworkx.generators.directed_path_graph(4, bidirectional=False)
digraph.make_symmetric(default_weight_function)
expected_edge_list = {
(0, 1, None),
(1, 2, None),
(2, 3, None),
(1, 0, "Reversi"),
(2, 1, "Reversi"),
(3, 2, "Reversi"),
}
result = set(digraph.weighted_edge_list())
self.assertEqual(result, expected_edge_list)

def test_path_graph_make_symmetric_existing_reverse_edges_function_arg(self):
digraph = rustworkx.generators.directed_path_graph(4, bidirectional=False)
digraph.add_edge(3, 2, None)
digraph.add_edge(1, 0, None)
digraph.make_symmetric(default_weight_function)
expected_edge_list = {
(0, 1, None),
(1, 2, None),
(2, 3, None),
(3, 2, None),
(1, 0, None),
(2, 1, "Reversi"),
}
self.assertEqual(set(digraph.weighted_edge_list()), expected_edge_list)

def test_path_graph_make_symmetric_function_arg_raises(self):
digraph = rustworkx.generators.directed_path_graph(4)

def weight_function(edge):
if edge is None:
raise TypeError("I'm expected")

with self.assertRaises(TypeError):
digraph.make_symmetric(weight_function)