Skip to content

Commit

Permalink
Add reverse inplace function for digraph (#853)
Browse files Browse the repository at this point in the history
* added a reverse_inplace function in digraph,

the function reverses the direction of the edges in the digraph
implemented by switching the indices of the nodes in an edge.

* added python tests for the reverse_inplace function.

testing a simple case and a case for a large graph.

* ran rust fmt and clippy, also added more detailed documentation

* rename reverse_inplace to reverse

* change excepts to unwraps (If this fails is because of PyO3. It panics and there is not much point in printing a message)

* added tests for empty graph and graph with node removed in the middle

* added interface signature for IDEs

* ran cargo fmt

* Fix doc syntax

---------

Co-authored-by: Matthew Treinish <mtreinish@kortar.org>
  • Loading branch information
matanco64 and mtreinish committed May 10, 2023
1 parent c17eea5 commit a16c18d
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 0 deletions.
1 change: 1 addition & 0 deletions rustworkx/digraph.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ class PyDiGraph(Generic[S, T]):
deliminator: Optional[str] = ...,
weight_fn: Optional[Callable[[T], str]] = ...,
) -> None: ...
def reverse(self) -> None: ...
def __delitem__(self, idx: int, /) -> None: ...
def __getitem__(self, idx: int, /) -> S: ...
def __getstate__(self) -> Any: ...
Expand Down
33 changes: 33 additions & 0 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2819,6 +2819,39 @@ impl PyDiGraph {
self.clone()
}

/// Reverse the direction of all edges in the graph, in place.
///
/// This method modifies the graph instance to reverse the direction of all edges.
/// It does so by iterating over all edges in the graph and removing each edge,
/// then adding a new edge in the opposite direction with the same weight.
///
/// For Example::
///
/// import rustworkx as rx
///
/// graph = rx.PyDiGraph()
///
/// # Generate a path directed path graph with weights
/// graph.extend_from_weighted_edge_list([
/// (0, 1, 3),
/// (1, 2, 5),
/// (2, 3, 2),
/// ])
/// # Reverse edges
/// graph.reverse()
///
/// assert graph.weighted_edge_list() == [(3, 2, 2), (2, 1, 5), (1, 0, 3)];
#[pyo3(text_signature = "(self)")]
pub fn reverse(&mut self, py: Python) {
let indices = self.graph.edge_indices().collect::<Vec<EdgeIndex>>();
for idx in indices {
let (source_node, dest_node) = self.graph.edge_endpoints(idx).unwrap();
let weight = self.graph.edge_weight(idx).unwrap().clone_ref(py);
self.graph.remove_edge(idx);
self.graph.add_edge(dest_node, source_node, weight);
}
}

/// Return the number of nodes in the graph
fn __len__(&self) -> PyResult<usize> {
Ok(self.graph.node_count())
Expand Down
41 changes: 41 additions & 0 deletions tests/rustworkx_tests/digraph/test_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,3 +962,44 @@ def test_extend_from_weighted_edge_list(self):
graph.extend_from_weighted_edge_list(edge_list)
self.assertEqual(len(graph), 4)
self.assertEqual(["a", "b", "c", "d", "e"], graph.edges())

def test_reverse_graph(self):
graph = rustworkx.PyDiGraph()
graph.add_nodes_from([i for i in range(4)])
edge_list = [
(0, 1, "a"),
(1, 2, "b"),
(0, 2, "c"),
(2, 3, "d"),
(0, 3, "e"),
]
graph.add_edges_from(edge_list)
graph.reverse()
self.assertEqual([(1, 0), (2, 1), (2, 0), (3, 2), (3, 0)], graph.edge_list())

def test_reverse_large_graph(self):
LARGE_AMOUNT_OF_NODES = 10000000

graph = rustworkx.PyDiGraph()
graph.add_nodes_from(range(LARGE_AMOUNT_OF_NODES))
edge_list = list(zip(range(LARGE_AMOUNT_OF_NODES), range(1, LARGE_AMOUNT_OF_NODES)))
weighted_edge_list = [(s, d, "a") for s, d in edge_list]
graph.add_edges_from(weighted_edge_list)
graph.reverse()
reversed_edge_list = [(d, s) for s, d in edge_list]
self.assertEqual(reversed_edge_list, graph.edge_list())

def test_reverse_empty_graph(self):
graph = rustworkx.PyDiGraph()
edges_before = graph.edge_list()
graph.reverse()
self.assertEqual(graph.edge_list(), edges_before)

def test_removed_middle_node_reverse(self):
graph = rustworkx.PyDiGraph()
graph.add_nodes_from(list(range(5)))
edge_list = [(0, 1), (2, 1), (1, 3), (3, 4), (4, 0)]
graph.extend_from_edge_list(edge_list)
graph.remove_node(1)
graph.reverse()
self.assertEqual(graph.edge_list(), [(4, 3), (0, 4)])

0 comments on commit a16c18d

Please sign in to comment.