Skip to content

Commit

Permalink
Attempt to add pyo3 wrapper for collect_bicolor_runs
Browse files Browse the repository at this point in the history
  • Loading branch information
ElePT committed May 28, 2024
1 parent 5bdc1c6 commit 87cd111
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 278 deletions.
323 changes: 157 additions & 166 deletions rustworkx-core/src/dag_algo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,116 +316,7 @@ where
Ok(Some((path, path_weight)))
}

#[cfg(test)]
mod test_longest_path {
use super::*;
use petgraph::graph::DiGraph;
use petgraph::stable_graph::StableDiGraph;

#[test]
fn test_empty_graph() {
let graph: DiGraph<(), ()> = DiGraph::new();
let weight_fn = |_: petgraph::graph::EdgeReference<()>| Ok::<i32, &str>(0);
let result = longest_path(&graph, weight_fn);
assert_eq!(result, Ok(Some((vec![], 0))));
}

#[test]
fn test_single_node_graph() {
let mut graph: DiGraph<(), ()> = DiGraph::new();
let n0 = graph.add_node(());
let weight_fn = |_: petgraph::graph::EdgeReference<()>| Ok::<i32, &str>(0);
let result = longest_path(&graph, weight_fn);
assert_eq!(result, Ok(Some((vec![n0], 0))));
}

#[test]
fn test_dag_with_multiple_paths() {
let mut graph: DiGraph<(), i32> = DiGraph::new();
let n0 = graph.add_node(());
let n1 = graph.add_node(());
let n2 = graph.add_node(());
let n3 = graph.add_node(());
let n4 = graph.add_node(());
let n5 = graph.add_node(());
graph.add_edge(n0, n1, 3);
graph.add_edge(n0, n2, 2);
graph.add_edge(n1, n2, 1);
graph.add_edge(n1, n3, 4);
graph.add_edge(n2, n3, 2);
graph.add_edge(n3, n4, 2);
graph.add_edge(n2, n5, 1);
graph.add_edge(n4, n5, 3);
let weight_fn = |edge: petgraph::graph::EdgeReference<i32>| Ok::<i32, &str>(*edge.weight());
let result = longest_path(&graph, weight_fn);
assert_eq!(result, Ok(Some((vec![n0, n1, n3, n4, n5], 12))));
}

#[test]
fn test_graph_with_cycle() {
let mut graph: DiGraph<(), i32> = DiGraph::new();
let n0 = graph.add_node(());
let n1 = graph.add_node(());
graph.add_edge(n0, n1, 1);
graph.add_edge(n1, n0, 1); // Creates a cycle

let weight_fn = |edge: petgraph::graph::EdgeReference<i32>| Ok::<i32, &str>(*edge.weight());
let result = longest_path(&graph, weight_fn);
assert_eq!(result, Ok(None));
}

#[test]
fn test_negative_weights() {
let mut graph: DiGraph<(), i32> = DiGraph::new();
let n0 = graph.add_node(());
let n1 = graph.add_node(());
let n2 = graph.add_node(());
graph.add_edge(n0, n1, -1);
graph.add_edge(n0, n2, 2);
graph.add_edge(n1, n2, -2);
let weight_fn = |edge: petgraph::graph::EdgeReference<i32>| Ok::<i32, &str>(*edge.weight());
let result = longest_path(&graph, weight_fn);
assert_eq!(result, Ok(Some((vec![n0, n2], 2))));
}

#[test]
fn test_longest_path_in_stable_digraph() {
let mut graph: StableDiGraph<(), i32> = StableDiGraph::new();
let n0 = graph.add_node(());
let n1 = graph.add_node(());
let n2 = graph.add_node(());
graph.add_edge(n0, n1, 1);
graph.add_edge(n0, n2, 3);
graph.add_edge(n1, n2, 1);
let weight_fn =
|edge: petgraph::stable_graph::EdgeReference<'_, i32>| Ok::<i32, &str>(*edge.weight());
let result = longest_path(&graph, weight_fn);
assert_eq!(result, Ok(Some((vec![n0, n2], 3))));
}

#[test]
fn test_error_handling() {
let mut graph: DiGraph<(), i32> = DiGraph::new();
let n0 = graph.add_node(());
let n1 = graph.add_node(());
let n2 = graph.add_node(());
graph.add_edge(n0, n1, 1);
graph.add_edge(n0, n2, 2);
graph.add_edge(n1, n2, 1);
let weight_fn = |edge: petgraph::graph::EdgeReference<i32>| {
if *edge.weight() == 2 {
Err("Error: edge weight is 2")
} else {
Ok::<i32, &str>(*edge.weight())
}
};
let result = longest_path(&graph, weight_fn);
assert_eq!(result, Err("Error: edge weight is 2"));
}
}

/// Define custom error classes for collect_bicolor_runs

#[derive(Debug)]
pub enum CollectBicolorError<E: Error> {
DAGWouldCycle,
Expand Down Expand Up @@ -477,12 +368,10 @@ pub fn collect_bicolor_runs<G, F, C, B, E>(
filter_fn: F,
color_fn: C,
) -> Result<Vec<Vec<G::NodeId>>, CollectBicolorError<E>>
//OG type: PyResult<Vec<Vec<PyObject>>>
where
E: Error,
// add Option to input type because of line 135
F: Fn(&Option<&<G as Data>::NodeWeight>) -> Result<Option<bool>, CollectBicolorError<E>>, //OG input: &PyObject, OG return: PyResult<Option<bool>>
C: Fn(&<G as Data>::EdgeWeight) -> Result<Option<usize>, CollectBicolorError<E>>, //OG input: &PyObject, OG return: PyResult<Option<usize>>
F: Fn(&Option<&<G as Data>::NodeWeight>) -> Result<Option<bool>, CollectBicolorError<E>>,
C: Fn(&<G as Data>::EdgeWeight) -> Result<Option<usize>, CollectBicolorError<E>>,
G: NodeIndexable
// can take node index type and convert to usize. It restricts node index type.
+ IntoNodeIdentifiers // used in toposort. Turns graph into list of nodes
Expand All @@ -492,9 +381,9 @@ where
+ DataMap, // used to access node weights
<G as GraphBase>::NodeId: Eq + Hash,
{
let mut pending_list: Vec<Vec<G::NodeId>> = Vec::new(); //OG type: Vec<Vec<PyObject>>
let mut block_id: Vec<Option<usize>> = Vec::new(); //OG type: Vec<Option<usize>>
let mut block_list: Vec<Vec<G::NodeId>> = Vec::new(); //OG type: Vec<Vec<PyObject>> -> return
let mut pending_list: Vec<Vec<G::NodeId>> = Vec::new();
let mut block_id: Vec<Option<usize>> = Vec::new();
let mut block_list: Vec<Vec<G::NodeId>> = Vec::new();

let filter_node =
|node: &Option<&<G as Data>::NodeWeight>| -> Result<Option<bool>, CollectBicolorError<E>> {
Expand Down Expand Up @@ -536,7 +425,6 @@ where
// Remove null edges from color_fn
let colors = colors.into_iter().flatten().collect::<Vec<usize>>();

// &NodeIndexable::from_index(&graph, node)
if colors.len() <= 2 && is_match {
if colors.len() == 1 {
let c0 = colors[0];
Expand Down Expand Up @@ -588,70 +476,112 @@ where

Ok(block_list)
}

#[cfg(test)]
mod test_collect_bicolor_runs {
mod test_longest_path {
use super::*;
use petgraph::graph::{DiGraph, NodeIndex};

#[derive(Debug)]
struct TestError;
use petgraph::graph::DiGraph;
use petgraph::stable_graph::StableDiGraph;

impl Display for TestError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "Test error")
}
#[test]
fn test_empty_graph() {
let graph: DiGraph<(), ()> = DiGraph::new();
let weight_fn = |_: petgraph::graph::EdgeReference<()>| Ok::<i32, &str>(0);
let result = longest_path(&graph, weight_fn);
assert_eq!(result, Ok(Some((vec![], 0))));
}

impl Error for TestError {}

fn test_filter_fn(node: &Option<&i32>) -> Result<Option<bool>, CollectBicolorError<TestError>> {
match node {
Some(data) => Ok(Some(*data % 2 == 0)),
None => Ok(None),
}
#[test]
fn test_single_node_graph() {
let mut graph: DiGraph<(), ()> = DiGraph::new();
let n0 = graph.add_node(());
let weight_fn = |_: petgraph::graph::EdgeReference<()>| Ok::<i32, &str>(0);
let result = longest_path(&graph, weight_fn);
assert_eq!(result, Ok(Some((vec![n0], 0))));
}

fn test_color_fn(edge: &i32) -> Result<Option<usize>, CollectBicolorError<TestError>> {
Ok(Some(*edge as usize))
#[test]
fn test_dag_with_multiple_paths() {
let mut graph: DiGraph<(), i32> = DiGraph::new();
let n0 = graph.add_node(());
let n1 = graph.add_node(());
let n2 = graph.add_node(());
let n3 = graph.add_node(());
let n4 = graph.add_node(());
let n5 = graph.add_node(());
graph.add_edge(n0, n1, 3);
graph.add_edge(n0, n2, 2);
graph.add_edge(n1, n2, 1);
graph.add_edge(n1, n3, 4);
graph.add_edge(n2, n3, 2);
graph.add_edge(n3, n4, 2);
graph.add_edge(n2, n5, 1);
graph.add_edge(n4, n5, 3);
let weight_fn = |edge: petgraph::graph::EdgeReference<i32>| Ok::<i32, &str>(*edge.weight());
let result = longest_path(&graph, weight_fn);
assert_eq!(result, Ok(Some((vec![n0, n1, n3, n4, n5], 12))));
}

#[test]
fn test_collect_bicolor_runs() {
let mut graph = DiGraph::new();

let n0 = graph.add_node(1);
let n1 = graph.add_node(2);
let n2 = graph.add_node(3);
let n3 = graph.add_node(4);
let n4 = graph.add_node(5);

fn test_graph_with_cycle() {
let mut graph: DiGraph<(), i32> = DiGraph::new();
let n0 = graph.add_node(());
let n1 = graph.add_node(());
graph.add_edge(n0, n1, 1);
graph.add_edge(n1, n2, 2);
graph.add_edge(n2, n3, 1);
graph.add_edge(n3, n4, 2);
graph.add_edge(n1, n0, 1); // Creates a cycle

let result = collect_bicolor_runs::<&DiGraph<i32, i32>, _, _, (), TestError>(
&graph,
|node: &Option<&i32>| test_filter_fn(node), // Wrap in closure to match expected signature (&Option<&i32>)
test_color_fn,
);
let weight_fn = |edge: petgraph::graph::EdgeReference<i32>| Ok::<i32, &str>(*edge.weight());
let result = longest_path(&graph, weight_fn);
assert_eq!(result, Ok(None));
}

// let expected: Vec<Vec<NodeIndex>> = vec![
// vec![n1, n2, n3], // this is the expected bicolor run with colors 1 and 2
// ];
//
// assert_eq!(result, expected);
#[test]
fn test_negative_weights() {
let mut graph: DiGraph<(), i32> = DiGraph::new();
let n0 = graph.add_node(());
let n1 = graph.add_node(());
let n2 = graph.add_node(());
graph.add_edge(n0, n1, -1);
graph.add_edge(n0, n2, 2);
graph.add_edge(n1, n2, -2);
let weight_fn = |edge: petgraph::graph::EdgeReference<i32>| Ok::<i32, &str>(*edge.weight());
let result = longest_path(&graph, weight_fn);
assert_eq!(result, Ok(Some((vec![n0, n2], 2))));
}

match result {
Ok(runs) => {
// Check if the runs match the expected output
let expected: Vec<Vec<NodeIndex>> = vec![
vec![n1, n2, n3], // this is the expected bicolor run with colors 1 and 2
];
assert_eq!(runs, expected);
#[test]
fn test_longest_path_in_stable_digraph() {
let mut graph: StableDiGraph<(), i32> = StableDiGraph::new();
let n0 = graph.add_node(());
let n1 = graph.add_node(());
let n2 = graph.add_node(());
graph.add_edge(n0, n1, 1);
graph.add_edge(n0, n2, 3);
graph.add_edge(n1, n2, 1);
let weight_fn =
|edge: petgraph::stable_graph::EdgeReference<'_, i32>| Ok::<i32, &str>(*edge.weight());
let result = longest_path(&graph, weight_fn);
assert_eq!(result, Ok(Some((vec![n0, n2], 3))));
}

#[test]
fn test_error_handling() {
let mut graph: DiGraph<(), i32> = DiGraph::new();
let n0 = graph.add_node(());
let n1 = graph.add_node(());
let n2 = graph.add_node(());
graph.add_edge(n0, n1, 1);
graph.add_edge(n0, n2, 2);
graph.add_edge(n1, n2, 1);
let weight_fn = |edge: petgraph::graph::EdgeReference<i32>| {
if *edge.weight() == 2 {
Err("Error: edge weight is 2")
} else {
Ok::<i32, &str>(*edge.weight())
}
Err(e) => panic!("Test failed with error: {:?}", e),
}
};
let result = longest_path(&graph, weight_fn);
assert_eq!(result, Err("Error: edge weight is 2"));
}
}

Expand Down Expand Up @@ -833,3 +763,64 @@ mod test_lexicographical_topological_sort {
assert_eq!(result, Ok(Some(vec![nodes[7]])));
}
}

#[cfg(test)]
mod test_collect_bicolor_runs {
use super::*;
use petgraph::graph::{DiGraph, NodeIndex};

#[derive(Debug)]
struct TestError;

impl Display for TestError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "Test error")
}
}

impl Error for TestError {}

fn test_filter_fn(node: &Option<&i32>) -> Result<Option<bool>, CollectBicolorError<TestError>> {
match node {
Some(data) => Ok(Some(*data % 2 == 0)),
None => Ok(None),
}
}

fn test_color_fn(edge: &i32) -> Result<Option<usize>, CollectBicolorError<TestError>> {
Ok(Some(*edge as usize))
}

#[test]
fn test_collect_bicolor_runs() {
let mut graph = DiGraph::new();

let n0 = graph.add_node(1);
let n1 = graph.add_node(2);
let n2 = graph.add_node(3);
let n3 = graph.add_node(4);
let n4 = graph.add_node(5);

graph.add_edge(n0, n1, 1);
graph.add_edge(n1, n2, 2);
graph.add_edge(n2, n3, 1);
graph.add_edge(n3, n4, 2);

let result = collect_bicolor_runs::<&DiGraph<i32, i32>, _, _, (), TestError>(
&graph,
|node: &Option<&i32>| test_filter_fn(node), // Wrap in closure to match expected signature (&Option<&i32>)
test_color_fn,
);

match result {
Ok(results) => {
// Check if the results match the expected output
let expected: Vec<Vec<NodeIndex>> = vec![
vec![n1, n2, n3], // this is the expected bicolor run with colors 1 and 2
];
assert_eq!(results, expected);
}
Err(e) => panic!("Test failed with error: {:?}", e),
}
}
}
Loading

0 comments on commit 87cd111

Please sign in to comment.