diff --git a/rustworkx-core/src/dag_algo.rs b/rustworkx-core/src/dag_algo.rs index edad9d8dc7..5514e511e4 100644 --- a/rustworkx-core/src/dag_algo.rs +++ b/rustworkx-core/src/dag_algo.rs @@ -316,116 +316,7 @@ where Ok(Some((path, path_weight))) } -#[cfg(test)] -mod test_longest_path { - use super::*; - use petgraph::graph::DiGraph; - use petgraph::stable_graph::StableDiGraph; - - #[test] - fn test_empty_graph() { - let graph: DiGraph<(), ()> = DiGraph::new(); - let weight_fn = |_: petgraph::graph::EdgeReference<()>| Ok::(0); - let result = longest_path(&graph, weight_fn); - assert_eq!(result, Ok(Some((vec![], 0)))); - } - - #[test] - fn test_single_node_graph() { - let mut graph: DiGraph<(), ()> = DiGraph::new(); - let n0 = graph.add_node(()); - let weight_fn = |_: petgraph::graph::EdgeReference<()>| Ok::(0); - let result = longest_path(&graph, weight_fn); - assert_eq!(result, Ok(Some((vec![n0], 0)))); - } - - #[test] - fn test_dag_with_multiple_paths() { - let mut graph: DiGraph<(), i32> = DiGraph::new(); - let n0 = graph.add_node(()); - let n1 = graph.add_node(()); - let n2 = graph.add_node(()); - let n3 = graph.add_node(()); - let n4 = graph.add_node(()); - let n5 = graph.add_node(()); - graph.add_edge(n0, n1, 3); - graph.add_edge(n0, n2, 2); - graph.add_edge(n1, n2, 1); - graph.add_edge(n1, n3, 4); - graph.add_edge(n2, n3, 2); - graph.add_edge(n3, n4, 2); - graph.add_edge(n2, n5, 1); - graph.add_edge(n4, n5, 3); - let weight_fn = |edge: petgraph::graph::EdgeReference| Ok::(*edge.weight()); - let result = longest_path(&graph, weight_fn); - assert_eq!(result, Ok(Some((vec![n0, n1, n3, n4, n5], 12)))); - } - - #[test] - fn test_graph_with_cycle() { - let mut graph: DiGraph<(), i32> = DiGraph::new(); - let n0 = graph.add_node(()); - let n1 = graph.add_node(()); - graph.add_edge(n0, n1, 1); - graph.add_edge(n1, n0, 1); // Creates a cycle - - let weight_fn = |edge: petgraph::graph::EdgeReference| Ok::(*edge.weight()); - let result = longest_path(&graph, weight_fn); - assert_eq!(result, Ok(None)); - } - - #[test] - fn test_negative_weights() { - let mut graph: DiGraph<(), i32> = DiGraph::new(); - let n0 = graph.add_node(()); - let n1 = graph.add_node(()); - let n2 = graph.add_node(()); - graph.add_edge(n0, n1, -1); - graph.add_edge(n0, n2, 2); - graph.add_edge(n1, n2, -2); - let weight_fn = |edge: petgraph::graph::EdgeReference| Ok::(*edge.weight()); - let result = longest_path(&graph, weight_fn); - assert_eq!(result, Ok(Some((vec![n0, n2], 2)))); - } - - #[test] - fn test_longest_path_in_stable_digraph() { - let mut graph: StableDiGraph<(), i32> = StableDiGraph::new(); - let n0 = graph.add_node(()); - let n1 = graph.add_node(()); - let n2 = graph.add_node(()); - graph.add_edge(n0, n1, 1); - graph.add_edge(n0, n2, 3); - graph.add_edge(n1, n2, 1); - let weight_fn = - |edge: petgraph::stable_graph::EdgeReference<'_, i32>| Ok::(*edge.weight()); - let result = longest_path(&graph, weight_fn); - assert_eq!(result, Ok(Some((vec![n0, n2], 3)))); - } - - #[test] - fn test_error_handling() { - let mut graph: DiGraph<(), i32> = DiGraph::new(); - let n0 = graph.add_node(()); - let n1 = graph.add_node(()); - let n2 = graph.add_node(()); - graph.add_edge(n0, n1, 1); - graph.add_edge(n0, n2, 2); - graph.add_edge(n1, n2, 1); - let weight_fn = |edge: petgraph::graph::EdgeReference| { - if *edge.weight() == 2 { - Err("Error: edge weight is 2") - } else { - Ok::(*edge.weight()) - } - }; - let result = longest_path(&graph, weight_fn); - assert_eq!(result, Err("Error: edge weight is 2")); - } -} - /// Define custom error classes for collect_bicolor_runs - #[derive(Debug)] pub enum CollectBicolorError { DAGWouldCycle, @@ -477,12 +368,10 @@ pub fn collect_bicolor_runs( filter_fn: F, color_fn: C, ) -> Result>, CollectBicolorError> -//OG type: PyResult>> where E: Error, - // add Option to input type because of line 135 - F: Fn(&Option<&::NodeWeight>) -> Result, CollectBicolorError>, //OG input: &PyObject, OG return: PyResult> - C: Fn(&::EdgeWeight) -> Result, CollectBicolorError>, //OG input: &PyObject, OG return: PyResult> + F: Fn(&Option<&::NodeWeight>) -> Result, CollectBicolorError>, + C: Fn(&::EdgeWeight) -> Result, CollectBicolorError>, G: NodeIndexable // can take node index type and convert to usize. It restricts node index type. + IntoNodeIdentifiers // used in toposort. Turns graph into list of nodes @@ -492,9 +381,9 @@ where + DataMap, // used to access node weights ::NodeId: Eq + Hash, { - let mut pending_list: Vec> = Vec::new(); //OG type: Vec> - let mut block_id: Vec> = Vec::new(); //OG type: Vec> - let mut block_list: Vec> = Vec::new(); //OG type: Vec> -> return + let mut pending_list: Vec> = Vec::new(); + let mut block_id: Vec> = Vec::new(); + let mut block_list: Vec> = Vec::new(); let filter_node = |node: &Option<&::NodeWeight>| -> Result, CollectBicolorError> { @@ -536,7 +425,6 @@ where // Remove null edges from color_fn let colors = colors.into_iter().flatten().collect::>(); - // &NodeIndexable::from_index(&graph, node) if colors.len() <= 2 && is_match { if colors.len() == 1 { let c0 = colors[0]; @@ -588,70 +476,112 @@ where Ok(block_list) } + #[cfg(test)] -mod test_collect_bicolor_runs { +mod test_longest_path { use super::*; - use petgraph::graph::{DiGraph, NodeIndex}; - - #[derive(Debug)] - struct TestError; + use petgraph::graph::DiGraph; + use petgraph::stable_graph::StableDiGraph; - impl Display for TestError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "Test error") - } + #[test] + fn test_empty_graph() { + let graph: DiGraph<(), ()> = DiGraph::new(); + let weight_fn = |_: petgraph::graph::EdgeReference<()>| Ok::(0); + let result = longest_path(&graph, weight_fn); + assert_eq!(result, Ok(Some((vec![], 0)))); } - impl Error for TestError {} - - fn test_filter_fn(node: &Option<&i32>) -> Result, CollectBicolorError> { - match node { - Some(data) => Ok(Some(*data % 2 == 0)), - None => Ok(None), - } + #[test] + fn test_single_node_graph() { + let mut graph: DiGraph<(), ()> = DiGraph::new(); + let n0 = graph.add_node(()); + let weight_fn = |_: petgraph::graph::EdgeReference<()>| Ok::(0); + let result = longest_path(&graph, weight_fn); + assert_eq!(result, Ok(Some((vec![n0], 0)))); } - fn test_color_fn(edge: &i32) -> Result, CollectBicolorError> { - Ok(Some(*edge as usize)) + #[test] + fn test_dag_with_multiple_paths() { + let mut graph: DiGraph<(), i32> = DiGraph::new(); + let n0 = graph.add_node(()); + let n1 = graph.add_node(()); + let n2 = graph.add_node(()); + let n3 = graph.add_node(()); + let n4 = graph.add_node(()); + let n5 = graph.add_node(()); + graph.add_edge(n0, n1, 3); + graph.add_edge(n0, n2, 2); + graph.add_edge(n1, n2, 1); + graph.add_edge(n1, n3, 4); + graph.add_edge(n2, n3, 2); + graph.add_edge(n3, n4, 2); + graph.add_edge(n2, n5, 1); + graph.add_edge(n4, n5, 3); + let weight_fn = |edge: petgraph::graph::EdgeReference| Ok::(*edge.weight()); + let result = longest_path(&graph, weight_fn); + assert_eq!(result, Ok(Some((vec![n0, n1, n3, n4, n5], 12)))); } #[test] - fn test_collect_bicolor_runs() { - let mut graph = DiGraph::new(); - - let n0 = graph.add_node(1); - let n1 = graph.add_node(2); - let n2 = graph.add_node(3); - let n3 = graph.add_node(4); - let n4 = graph.add_node(5); - + fn test_graph_with_cycle() { + let mut graph: DiGraph<(), i32> = DiGraph::new(); + let n0 = graph.add_node(()); + let n1 = graph.add_node(()); graph.add_edge(n0, n1, 1); - graph.add_edge(n1, n2, 2); - graph.add_edge(n2, n3, 1); - graph.add_edge(n3, n4, 2); + graph.add_edge(n1, n0, 1); // Creates a cycle - let result = collect_bicolor_runs::<&DiGraph, _, _, (), TestError>( - &graph, - |node: &Option<&i32>| test_filter_fn(node), // Wrap in closure to match expected signature (&Option<&i32>) - test_color_fn, - ); + let weight_fn = |edge: petgraph::graph::EdgeReference| Ok::(*edge.weight()); + let result = longest_path(&graph, weight_fn); + assert_eq!(result, Ok(None)); + } - // let expected: Vec> = vec![ - // vec![n1, n2, n3], // this is the expected bicolor run with colors 1 and 2 - // ]; - // - // assert_eq!(result, expected); + #[test] + fn test_negative_weights() { + let mut graph: DiGraph<(), i32> = DiGraph::new(); + let n0 = graph.add_node(()); + let n1 = graph.add_node(()); + let n2 = graph.add_node(()); + graph.add_edge(n0, n1, -1); + graph.add_edge(n0, n2, 2); + graph.add_edge(n1, n2, -2); + let weight_fn = |edge: petgraph::graph::EdgeReference| Ok::(*edge.weight()); + let result = longest_path(&graph, weight_fn); + assert_eq!(result, Ok(Some((vec![n0, n2], 2)))); + } - match result { - Ok(runs) => { - // Check if the runs match the expected output - let expected: Vec> = vec![ - vec![n1, n2, n3], // this is the expected bicolor run with colors 1 and 2 - ]; - assert_eq!(runs, expected); + #[test] + fn test_longest_path_in_stable_digraph() { + let mut graph: StableDiGraph<(), i32> = StableDiGraph::new(); + let n0 = graph.add_node(()); + let n1 = graph.add_node(()); + let n2 = graph.add_node(()); + graph.add_edge(n0, n1, 1); + graph.add_edge(n0, n2, 3); + graph.add_edge(n1, n2, 1); + let weight_fn = + |edge: petgraph::stable_graph::EdgeReference<'_, i32>| Ok::(*edge.weight()); + let result = longest_path(&graph, weight_fn); + assert_eq!(result, Ok(Some((vec![n0, n2], 3)))); + } + + #[test] + fn test_error_handling() { + let mut graph: DiGraph<(), i32> = DiGraph::new(); + let n0 = graph.add_node(()); + let n1 = graph.add_node(()); + let n2 = graph.add_node(()); + graph.add_edge(n0, n1, 1); + graph.add_edge(n0, n2, 2); + graph.add_edge(n1, n2, 1); + let weight_fn = |edge: petgraph::graph::EdgeReference| { + if *edge.weight() == 2 { + Err("Error: edge weight is 2") + } else { + Ok::(*edge.weight()) } - Err(e) => panic!("Test failed with error: {:?}", e), - } + }; + let result = longest_path(&graph, weight_fn); + assert_eq!(result, Err("Error: edge weight is 2")); } } @@ -833,3 +763,64 @@ mod test_lexicographical_topological_sort { assert_eq!(result, Ok(Some(vec![nodes[7]]))); } } + +#[cfg(test)] +mod test_collect_bicolor_runs { + use super::*; + use petgraph::graph::{DiGraph, NodeIndex}; + + #[derive(Debug)] + struct TestError; + + impl Display for TestError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "Test error") + } + } + + impl Error for TestError {} + + fn test_filter_fn(node: &Option<&i32>) -> Result, CollectBicolorError> { + match node { + Some(data) => Ok(Some(*data % 2 == 0)), + None => Ok(None), + } + } + + fn test_color_fn(edge: &i32) -> Result, CollectBicolorError> { + Ok(Some(*edge as usize)) + } + + #[test] + fn test_collect_bicolor_runs() { + let mut graph = DiGraph::new(); + + let n0 = graph.add_node(1); + let n1 = graph.add_node(2); + let n2 = graph.add_node(3); + let n3 = graph.add_node(4); + let n4 = graph.add_node(5); + + graph.add_edge(n0, n1, 1); + graph.add_edge(n1, n2, 2); + graph.add_edge(n2, n3, 1); + graph.add_edge(n3, n4, 2); + + let result = collect_bicolor_runs::<&DiGraph, _, _, (), TestError>( + &graph, + |node: &Option<&i32>| test_filter_fn(node), // Wrap in closure to match expected signature (&Option<&i32>) + test_color_fn, + ); + + match result { + Ok(results) => { + // Check if the results match the expected output + let expected: Vec> = vec![ + vec![n1, n2, n3], // this is the expected bicolor run with colors 1 and 2 + ]; + assert_eq!(results, expected); + } + Err(e) => panic!("Test failed with error: {:?}", e), + } + } +} \ No newline at end of file diff --git a/src/dag_algo/mod.rs b/src/dag_algo/mod.rs index 2478a82dff..72a9eba712 100644 --- a/src/dag_algo/mod.rs +++ b/src/dag_algo/mod.rs @@ -15,6 +15,8 @@ use hashbrown::{HashMap, HashSet}; use indexmap::IndexSet; use rustworkx_core::dictmap::InitWithHasher; +use std::error::Error; +use std::fmt::{Debug, Display, Formatter}; use super::iterators::NodeIndices; use crate::{digraph, DAGHasCycle, InvalidNode, StablePyGraph}; @@ -604,6 +606,7 @@ pub fn collect_runs( Ok(out_list) } +/// Define custom error classes for collect_bicolor_runs /// Collect runs that match a filter function given edge colors. /// /// A bicolor run is a list of group of nodes connected by edges of exactly @@ -635,125 +638,26 @@ pub fn collect_bicolor_runs( color_fn: PyObject, ) -> PyResult>> { - // let dag = &graph.graph; - // // Create a new filter function that matches the required signature - // - // let filter_node_fn = |node: &Option<&Py>| -> PyResult> { - // let res = filter_fn.call1(py, (node,))?; - // res.extract(py) - // }; - // - // let filter_node_fn_wrapper = |arg: &Option<&Py>| { - // let node_weight = arg.map(|py_node| {/* Convert py_node to NodeWeight */}); - // filter_node_fn(arg) - // }; - // - // - // // Create a new color function that matches the required signature - // let color_edge_fn = |edge: &PyObject| -> PyResult> { - // let res = color_fn.call1(py, (edge,))?; - // res.extract(py) - // }; - // // Use collect_bicolor_runs function from rustworkx-core and wrap result to match signature - // // let block_list = core_collect_bicolor_runs(dag, filter_node_fn, color_edge_fn); - // let block_list = core_collect_bicolor_runs(dag, filter_node_fn_wrapper, color_edge_fn); - // Ok(block_list) - let mut pending_list: Vec> = Vec::new(); - let mut block_id: Vec> = Vec::new(); - let mut block_list: Vec> = Vec::new(); - - let filter_node = |node: &PyObject| -> PyResult> { - let res = filter_fn.call1(py, (node,))?; + let dag = &graph.graph; + + // Extract the filter function closure from the PyObject + let filter_fn_closure = |node: &Option<&PyObject>| { + let node_value = match node { + Some(value) => *value, + None => return Ok(None) + }; + let res = filter_fn.call1(py, (node_value,))?; res.extract(py) }; - let color_edge = |edge: &PyObject| -> PyResult> { + // Extract the color function closure from the PyObject + let color_fn_closure = |edge: &PyObject|{ let res = color_fn.call1(py, (edge,))?; res.extract(py) }; - let nodes = match algo::toposort(&graph.graph, None) { - Ok(nodes) => nodes, - Err(_err) => return Err(DAGHasCycle::new_err("Sort encountered a cycle")), - }; - - // Utility for ensuring pending_list has the color index - macro_rules! ensure_vector_has_index { - ($pending_list: expr, $block_id: expr, $color: expr) => { - if $color >= $pending_list.len() { - $pending_list.resize($color + 1, Vec::new()); - $block_id.resize($color + 1, None); - } - }; - } - - for node in nodes { - if let Some(is_match) = filter_node(&graph.graph[node])? { - let raw_edges = graph - .graph - .edges_directed(node, petgraph::Direction::Outgoing); - - // Remove all edges that do not yield errors from color_fn - let colors = raw_edges - .map(|edge| { - let edge_weight = edge.weight(); - color_edge(edge_weight) - }) - .collect::>>>()?; - - // Remove null edges from color_fn - let colors = colors.into_iter().flatten().collect::>(); - - if colors.len() <= 2 && is_match { - if colors.len() == 1 { - let c0 = colors[0]; - ensure_vector_has_index!(pending_list, block_id, c0); - if let Some(c0_block_id) = block_id[c0] { - block_list[c0_block_id].push(graph.graph[node].clone_ref(py)); - } else { - pending_list[c0].push(graph.graph[node].clone_ref(py)); - } - } else if colors.len() == 2 { - let c0 = colors[0]; - let c1 = colors[1]; - ensure_vector_has_index!(pending_list, block_id, c0); - ensure_vector_has_index!(pending_list, block_id, c1); - - if block_id[c0].is_some() - && block_id[c1].is_some() - && block_id[c0] == block_id[c1] - { - block_list[block_id[c0].unwrap_or_default()] - .push(graph.graph[node].clone_ref(py)); - } else { - let mut new_block: Vec = - Vec::with_capacity(pending_list[c0].len() + pending_list[c1].len() + 1); - - // Clears pending lits and add to new block - new_block.append(&mut pending_list[c0]); - new_block.append(&mut pending_list[c1]); - - new_block.push(graph.graph[node].clone_ref(py)); - - // Create new block, assign its id to color pair - block_id[c0] = Some(block_list.len()); - block_id[c1] = Some(block_list.len()); - block_list.push(new_block); - } - } - } else { - for color in colors { - ensure_vector_has_index!(pending_list, block_id, color); - if let Some(color_block_id) = block_id[color] { - block_list[color_block_id].append(&mut pending_list[color]); - } - block_id[color] = None; - pending_list[color].clear(); - } - } - } - } - + // Use the closures as arguments for core_collect_bicolor_runs + let block_list = core_collect_bicolor_runs(dag, filter_fn_closure, color_fn_closure); Ok(block_list) }