From c7a7d53dbdbee4597cced6a44de0bbdf7b927898 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elena=20Pe=C3=B1a=20Tapia?= <57907331+ElePT@users.noreply.github.com> Date: Mon, 10 Jun 2024 13:32:19 +0200 Subject: [PATCH] Move `collect_bicolor_runs()` to rustworkx-core (#1186) * Attempt with graph.node_weight * Progress until hitting error. * Return node ids instead of node weights. Code compiles!!! * Move collect_bicolor_runs.rs to rustworkx-core/src/dag_algo.rs * Add pyo3 wrapper for core collect_bicolor_runs, modify test, small fixes. * Fix DagHasCycle error to match previous type * Clean up unused parameters and functions, simplify error handling, add unit tests and reno. * Use PyResult, fix comments * Apply feedback from Matt's review, make tests make more sense. * Add suggestion from Matt's code review Co-authored-by: Matthew Treinish * Remove stray print Co-authored-by: Matthew Treinish * Change callbacks to take ids instead of weights, rename some variables in tests for clarity. * Fix test * Apply suggestions from code review Co-authored-by: Matthew Treinish * Get rid of unused trait * Fix fmt --------- Co-authored-by: Matthew Treinish --- ...collect-bicolor-runs-ff95d7fdbc546e2e.yaml | 6 + rustworkx-core/src/dag_algo.rs | 387 +++++++++++++++++- rustworkx-core/src/lib.rs | 1 + src/dag_algo/mod.rs | 112 ++--- tests/digraph/test_collect_bicolor_runs.py | 24 +- 5 files changed, 423 insertions(+), 107 deletions(-) create mode 100644 releasenotes/notes/migrate-collect-bicolor-runs-ff95d7fdbc546e2e.yaml diff --git a/releasenotes/notes/migrate-collect-bicolor-runs-ff95d7fdbc546e2e.yaml b/releasenotes/notes/migrate-collect-bicolor-runs-ff95d7fdbc546e2e.yaml new file mode 100644 index 000000000..51705416b --- /dev/null +++ b/releasenotes/notes/migrate-collect-bicolor-runs-ff95d7fdbc546e2e.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + Added a new function ``collect_bicolor_runs`` to rustworkx-core's ``dag_algo`` module. 
+ Previously, the ``collect_bicolor_runs`` functionality for DAGs was only exposed + via the Python interface. Now Rust users can take advantage of this functionality in ``rustworkx-core``. diff --git a/rustworkx-core/src/dag_algo.rs b/rustworkx-core/src/dag_algo.rs index fb629069c..081802fa7 100644 --- a/rustworkx-core/src/dag_algo.rs +++ b/rustworkx-core/src/dag_algo.rs @@ -17,6 +17,7 @@ use std::hash::Hash; use hashbrown::HashMap; use petgraph::algo; +use petgraph::data::DataMap; use petgraph::visit::{ EdgeRef, GraphBase, GraphProp, IntoEdgesDirected, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCount, Visitable, @@ -313,6 +314,168 @@ where Ok(Some((path, path_weight))) } +/// Collect runs that match a filter function given edge colors. +/// +/// A bicolor run is a list of groups of nodes connected by edges of exactly +/// two colors. In addition, all nodes in the group must match the given +/// condition. Each node in the graph can appear in only a single group +/// in the bicolor run. +/// +/// # Arguments: +/// +/// * `graph`: The DAG to find bicolor runs in +/// * `filter_fn`: The filter function to use for matching nodes. It takes +/// in one argument, the node id, and will return an optional +/// boolean indicating whether the node matches the conditions or not. +/// If it returns ``Some(true)``, it will continue the bicolor chain. +/// If it returns ``Some(false)``, it will stop the bicolor chain. +/// If it returns ``None``, it will skip that node. +/// * `color_fn`: The function that gives the color of the edge. It takes +/// in one argument, the edge id, and will +/// return a non-negative integer, the edge color. If the color is ``None``, +/// the edge is ignored.
+/// +/// # Returns: +/// +/// * `Vec<Vec<G::NodeId>>`: a list of groups with exactly two edge colors, where each group +/// is a list of node ids for the nodes in the bicolor run +/// * `None` if a cycle is found in the graph +/// * An error if `filter_fn` or `color_fn` returns one while computing the bicolor runs +/// +/// # Example: +/// +/// ```rust +/// use rustworkx_core::dag_algo::collect_bicolor_runs; +/// use petgraph::graph::{DiGraph, NodeIndex}; +/// use std::convert::Infallible; +/// use std::error::Error; +/// +/// let mut graph = DiGraph::new(); +/// let n0 = graph.add_node(0); +/// let n1 = graph.add_node(0); +/// let n2 = graph.add_node(1); +/// let n3 = graph.add_node(1); +/// let n4 = graph.add_node(0); +/// let n5 = graph.add_node(0); +/// graph.add_edge(n0, n2, 0); +/// graph.add_edge(n1, n2, 1); +/// graph.add_edge(n2, n3, 0); +/// graph.add_edge(n2, n3, 1); +/// graph.add_edge(n3, n4, 0); +/// graph.add_edge(n3, n5, 1); +/// +/// let filter_fn = |node_id| -> Result, Infallible> { +/// Ok(Some(*graph.node_weight(node_id).unwrap() > 0)) +/// }; +/// +/// let color_fn = |edge_id| -> Result, Infallible> { +/// Ok(Some(*graph.edge_weight(edge_id).unwrap() as usize)) +/// }; +/// +/// let result = collect_bicolor_runs(&graph, filter_fn, color_fn).unwrap(); +/// let expected: Vec> = vec![vec![n2, n3]]; +/// assert_eq!(result, Some(expected)) +/// ``` +pub fn collect_bicolor_runs( + graph: G, + filter_fn: F, + color_fn: C, +) -> Result>>, E> +where + F: Fn(::NodeId) -> Result, E>, + C: Fn(::EdgeId) -> Result, E>, + G: IntoNodeIdentifiers // Used in toposort + + IntoNeighborsDirected // Used in toposort + + IntoEdgesDirected // Used for .edges_directed + + Visitable // Used in toposort + + DataMap, // Used for .node_weight + ::NodeId: Eq + Hash, + ::EdgeId: Eq + Hash, +{ + let mut pending_list: Vec> = Vec::new(); + let mut block_id: Vec> = Vec::new(); + let mut block_list: Vec> = Vec::new(); + + let nodes = match algo::toposort(graph, None) { + Ok(nodes) => nodes, + Err(_) =>
return Ok(None), // Return None if the graph contains a cycle + }; + + // Utility for ensuring pending_list has the color index + macro_rules! ensure_vector_has_index { + ($pending_list: expr, $block_id: expr, $color: expr) => { + if $color >= $pending_list.len() { + $pending_list.resize($color + 1, Vec::new()); + $block_id.resize($color + 1, None); + } + }; + } + + for node in nodes { + if let Some(is_match) = filter_fn(node)? { + let raw_edges = graph.edges_directed(node, petgraph::Direction::Outgoing); + + // Color every outgoing edge; the first error from color_fn is propagated + let colors = raw_edges + .map(|edge| color_fn(edge.id())) + .collect::>, _>>()?; + + // Drop edges for which color_fn returned None + let colors = colors.into_iter().flatten().collect::>(); + + if colors.len() <= 2 && is_match { + if colors.len() == 1 { + let c0 = colors[0]; + ensure_vector_has_index!(pending_list, block_id, c0); + if let Some(c0_block_id) = block_id[c0] { + block_list[c0_block_id].push(node); + } else { + pending_list[c0].push(node); + } + } else if colors.len() == 2 { + let c0 = colors[0]; + let c1 = colors[1]; + ensure_vector_has_index!(pending_list, block_id, c0); + ensure_vector_has_index!(pending_list, block_id, c1); + + if block_id[c0].is_some() + && block_id[c1].is_some() + && block_id[c0] == block_id[c1] + { + block_list[block_id[c0].unwrap_or_default()].push(node); + } else { + let mut new_block: Vec = + Vec::with_capacity(pending_list[c0].len() + pending_list[c1].len() + 1); + + // Clear the pending lists and move their nodes into the new block + new_block.append(&mut pending_list[c0]); + new_block.append(&mut pending_list[c1]); + + new_block.push(node); + + // Create new block, assign its id to color pair + block_id[c0] = Some(block_list.len()); + block_id[c1] = Some(block_list.len()); + block_list.push(new_block); + } + } + } else { + for color in colors { + ensure_vector_has_index!(pending_list, block_id, color); + if let Some(color_block_id) = block_id[color] { + block_list[color_block_id].append(&mut
pending_list[color]); + } + block_id[color] = None; + pending_list[color].clear(); + } + } + } + } + + Ok(Some(block_list)) +} + +// Tests for longest_path #[cfg(test)] mod test_longest_path { use super::*; @@ -421,13 +584,7 @@ mod test_longest_path { } } -// pub fn lexicographical_topological_sort( -// dag: G, -// mut key: F, -// reverse: bool, -// initial: Option<&[G::NodeId]>, -// ) -> Result>, E> - +// Tests for lexicographical_topological_sort #[cfg(test)] mod test_lexicographical_topological_sort { use super::*; @@ -599,3 +756,219 @@ mod test_lexicographical_topological_sort { assert_eq!(result, Ok(Some(vec![nodes[7]]))); } } + +// Tests for collect_bicolor_runs +#[cfg(test)] +mod test_collect_bicolor_runs { + + use super::*; + use petgraph::graph::{DiGraph, EdgeIndex, NodeIndex}; + use std::error::Error; + + #[test] + fn test_cycle() { + let mut graph = DiGraph::new(); + let n0 = graph.add_node(0); + let n1 = graph.add_node(0); + let n2 = graph.add_node(0); + graph.add_edge(n0, n1, 1); + graph.add_edge(n1, n2, 1); + graph.add_edge(n2, n0, 1); + + let test_filter_fn = + |_node_id: NodeIndex| -> Result, Box> { Ok(Some(true)) }; + let test_color_fn = + |_edge_id: EdgeIndex| -> Result, Box> { Ok(Some(1)) }; + let result = match collect_bicolor_runs(&graph, test_filter_fn, test_color_fn) { + Ok(Some(_value)) => "Not None", + Ok(None) => "None", + Err(_) => "Error", + }; + assert_eq!(result, "None") + } + + #[test] + fn test_filter_function_inner_exception() { + let mut graph = DiGraph::new(); + graph.add_node(0); + + let fail_function = |node_id: NodeIndex| -> Result, Box> { + let node_weight: &i32 = graph.node_weight(node_id).expect("Invalid NodeId"); + if *node_weight > 0 { + Ok(Some(true)) + } else { + Err(Box::from("Failed!")) + } + }; + let test_color_fn = |edge_id: EdgeIndex| -> Result, Box> { + let edge_weight: &i32 = graph.edge_weight(edge_id).expect("Invalid Edge"); + Ok(Some(*edge_weight as usize)) + }; + let result = match collect_bicolor_runs(&graph, 
fail_function, test_color_fn) { + Ok(Some(_value)) => "Not None", + Ok(None) => "None", + Err(_) => "Error", + }; + assert_eq!(result, "Error") + } + + #[test] + fn test_empty() { + let graph = DiGraph::new(); + let test_filter_fn = |node_id: NodeIndex| -> Result, Box> { + let node_weight: &i32 = graph.node_weight(node_id).expect("Invalid NodeId"); + Ok(Some(*node_weight > 1)) + }; + let test_color_fn = |edge_id: EdgeIndex| -> Result, Box> { + let edge_weight: &i32 = graph.edge_weight(edge_id).expect("Invalid Edge"); + Ok(Some(*edge_weight as usize)) + }; + let result = collect_bicolor_runs(&graph, test_filter_fn, test_color_fn).unwrap(); + let expected: Vec> = vec![]; + assert_eq!(result, Some(expected)) + } + + #[test] + fn test_two_colors() { + /* Based on the following graph from the Python unit tests: + Input: + ┌─────────────┐ ┌─────────────┐ + │ │ │ │ + │ q0 │ │ q1 │ + │ │ │ │ + └───┬─────────┘ └──────┬──────┘ + │ ┌─────────────┐ │ + q0 │ │ │ │ q1 + │ │ │ │ + └─────────►│ cx │◄────────┘ + ┌──────────┤ ├─────────┐ + │ │ │ │ + q0 │ └─────────────┘ │ q1 + │ │ + │ ┌─────────────┐ │ + │ │ │ │ + └─────────►│ cz │◄────────┘ + ┌─────────┤ ├─────────┐ + │ └─────────────┘ │ + q0 │ │ q1 + │ │ + ┌───▼─────────┐ ┌──────▼──────┐ + │ │ │ │ + │ q0 │ │ q1 │ + │ │ │ │ + └─────────────┘ └─────────────┘ + Expected: [[cx, cz]] + */ + let mut graph = DiGraph::new(); + let n0 = graph.add_node(0); //q0 + let n1 = graph.add_node(0); //q1 + let n2 = graph.add_node(1); //cx + let n3 = graph.add_node(1); //cz + let n4 = graph.add_node(0); //q0_1 + let n5 = graph.add_node(0); //q1_1 + graph.add_edge(n0, n2, 0); //q0 -> cx + graph.add_edge(n1, n2, 1); //q1 -> cx + graph.add_edge(n2, n3, 0); //cx -> cz + graph.add_edge(n2, n3, 1); //cx -> cz + graph.add_edge(n3, n4, 0); //cz -> q0_1 + graph.add_edge(n3, n5, 1); //cz -> q1_1 + + // Filter out q0, q1, q0_1 and q1_1 + let test_filter_fn = |node_id: NodeIndex| -> Result, Box> { + let node_weight: &i32 = 
graph.node_weight(node_id).expect("Invalid NodeId"); + Ok(Some(*node_weight > 0)) + }; + // The edge color will match its weight + let test_color_fn = |edge_id: EdgeIndex| -> Result, Box> { + let edge_weight: &i32 = graph.edge_weight(edge_id).expect("Invalid Edge"); + Ok(Some(*edge_weight as usize)) + }; + let result = collect_bicolor_runs(&graph, test_filter_fn, test_color_fn).unwrap(); + let expected: Vec> = vec![vec![n2, n3]]; //[[cx, cz]] + assert_eq!(result, Some(expected)) + } + + #[test] + fn test_two_colors_with_pending() { + /* Based on the following graph from the Python unit tests: + Input: + ┌─────────────┐ + │ │ + │ q0 │ + │ │ + └───┬─────────┘ + | q0 + │ + ┌───▼─────────┐ + │ │ + │ h │ + │ │ + └───┬─────────┘ + | q0 + │ ┌─────────────┐ + │ │ │ + │ │ q1 │ + │ │ │ + | └──────┬──────┘ + │ ┌─────────────┐ │ + q0 │ │ │ │ q1 + │ │ │ │ + └─────────►│ cx │◄────────┘ + ┌──────────┤ ├─────────┐ + │ │ │ │ + q0 │ └─────────────┘ │ q1 + │ │ + │ ┌─────────────┐ │ + │ │ │ │ + └─────────►│ cz │◄────────┘ + ┌─────────┤ ├─────────┐ + │ └─────────────┘ │ + q0 │ │ q1 + │ │ + ┌───▼─────────┐ ┌──────▼──────┐ + │ │ │ │ + │ q0 │ │ y │ + │ │ │ │ + └─────────────┘ └─────────────┘ + | q1 + │ + ┌───▼─────────┐ + │ │ + │ q1 │ + │ │ + └─────────────┘ + Expected: [[h, cx, cz, y]] + */ + let mut graph = DiGraph::new(); + let n0 = graph.add_node(0); //q0 + let n1 = graph.add_node(0); //q1 + let n2 = graph.add_node(1); //h + let n3 = graph.add_node(1); //cx + let n4 = graph.add_node(1); //cz + let n5 = graph.add_node(1); //y + let n6 = graph.add_node(0); //q0_1 + let n7 = graph.add_node(0); //q1_1 + graph.add_edge(n0, n2, 0); //q0 -> h + graph.add_edge(n2, n3, 0); //h -> cx + graph.add_edge(n1, n3, 1); //q1 -> cx + graph.add_edge(n3, n4, 0); //cx -> cz + graph.add_edge(n3, n4, 1); //cx -> cz + graph.add_edge(n4, n6, 0); //cz -> q0_1 + graph.add_edge(n4, n5, 1); //cz -> y + graph.add_edge(n5, n7, 1); //y -> q1_1 + + // Filter out q0, q1, q0_1 and q1_1 + let test_filter_fn = |node_id: 
NodeIndex| -> Result, Box> { + let node_weight: &i32 = graph.node_weight(node_id).expect("Invalid NodeId"); + Ok(Some(*node_weight > 0)) + }; + // The edge color will match its weight + let test_color_fn = |edge_id: EdgeIndex| -> Result, Box> { + let edge_weight: &i32 = graph.edge_weight(edge_id).expect("Invalid Edge"); + Ok(Some(*edge_weight as usize)) + }; + let result = collect_bicolor_runs(&graph, test_filter_fn, test_color_fn).unwrap(); + let expected: Vec> = vec![vec![n2, n3, n4, n5]]; //[[h, cx, cz, y]] + assert_eq!(result, Some(expected)) + } +} diff --git a/rustworkx-core/src/lib.rs b/rustworkx-core/src/lib.rs index 6be610425..fc5d6f5df 100644 --- a/rustworkx-core/src/lib.rs +++ b/rustworkx-core/src/lib.rs @@ -98,6 +98,7 @@ pub mod centrality; /// Module for coloring algorithms. pub mod coloring; pub mod connectivity; +/// Module for algorithms that work on DAGs. pub mod dag_algo; pub mod generators; pub mod graph_ext; diff --git a/src/dag_algo/mod.rs b/src/dag_algo/mod.rs index f3138384c..99ccc86c9 100644 --- a/src/dag_algo/mod.rs +++ b/src/dag_algo/mod.rs @@ -18,6 +18,7 @@ use rustworkx_core::dictmap::InitWithHasher; use super::iterators::NodeIndices; use crate::{digraph, DAGHasCycle, InvalidNode, StablePyGraph}; +use rustworkx_core::dag_algo::collect_bicolor_runs as core_collect_bicolor_runs; use rustworkx_core::dag_algo::lexicographical_topological_sort as core_lexico_topo_sort; use rustworkx_core::dag_algo::longest_path as core_longest_path; use rustworkx_core::traversal::dfs_edges; @@ -603,7 +604,7 @@ pub fn collect_runs( Ok(out_list) } -/// Collect runs that match a filter function given edge colors +/// Collect runs that match a filter function given edge colors. /// /// A bicolor run is a list of group of nodes connected by edges of exactly /// two colors. 
In addition, all nodes in the group must match the given @@ -633,102 +634,37 @@ pub fn collect_bicolor_runs( filter_fn: PyObject, color_fn: PyObject, ) -> PyResult>> { - let mut pending_list: Vec> = Vec::new(); - let mut block_id: Vec> = Vec::new(); - let mut block_list: Vec> = Vec::new(); + let dag = &graph.graph; - let filter_node = |node: &PyObject| -> PyResult> { - let res = filter_fn.call1(py, (node,))?; + let filter_fn_wrapper = |node_index| -> Result, PyErr> { + let node_weight = dag.node_weight(node_index).expect("Invalid NodeId"); + let res = filter_fn.call1(py, (node_weight,))?; res.extract(py) }; - let color_edge = |edge: &PyObject| -> PyResult> { - let res = color_fn.call1(py, (edge,))?; + let color_fn_wrapper = |edge_index| -> Result, PyErr> { + let edge_weight = dag.edge_weight(edge_index).expect("Invalid EdgeId"); + let res = color_fn.call1(py, (edge_weight,))?; res.extract(py) }; - let nodes = match algo::toposort(&graph.graph, None) { - Ok(nodes) => nodes, - Err(_err) => return Err(DAGHasCycle::new_err("Sort encountered a cycle")), + let block_list = match core_collect_bicolor_runs(dag, filter_fn_wrapper, color_fn_wrapper) { + Ok(Some(block_list)) => block_list + .into_iter() + .map(|index_list| { + index_list + .into_iter() + .map(|node_index| { + let node_weight = dag.node_weight(node_index).expect("Invalid NodeId"); + node_weight.into_py(py) + }) + .collect() + }) + .collect(), + Ok(None) => return Err(DAGHasCycle::new_err("The graph contains a cycle")), + Err(e) => return Err(e), }; - // Utility for ensuring pending_list has the color index - macro_rules! ensure_vector_has_index { - ($pending_list: expr, $block_id: expr, $color: expr) => { - if $color >= $pending_list.len() { - $pending_list.resize($color + 1, Vec::new()); - $block_id.resize($color + 1, None); - } - }; - } - - for node in nodes { - if let Some(is_match) = filter_node(&graph.graph[node])? 
{ - let raw_edges = graph - .graph - .edges_directed(node, petgraph::Direction::Outgoing); - - // Remove all edges that do not yield errors from color_fn - let colors = raw_edges - .map(|edge| { - let edge_weight = edge.weight(); - color_edge(edge_weight) - }) - .collect::>>>()?; - - // Remove null edges from color_fn - let colors = colors.into_iter().flatten().collect::>(); - - if colors.len() <= 2 && is_match { - if colors.len() == 1 { - let c0 = colors[0]; - ensure_vector_has_index!(pending_list, block_id, c0); - if let Some(c0_block_id) = block_id[c0] { - block_list[c0_block_id].push(graph.graph[node].clone_ref(py)); - } else { - pending_list[c0].push(graph.graph[node].clone_ref(py)); - } - } else if colors.len() == 2 { - let c0 = colors[0]; - let c1 = colors[1]; - ensure_vector_has_index!(pending_list, block_id, c0); - ensure_vector_has_index!(pending_list, block_id, c1); - - if block_id[c0].is_some() - && block_id[c1].is_some() - && block_id[c0] == block_id[c1] - { - block_list[block_id[c0].unwrap_or_default()] - .push(graph.graph[node].clone_ref(py)); - } else { - let mut new_block: Vec = - Vec::with_capacity(pending_list[c0].len() + pending_list[c1].len() + 1); - - // Clears pending lits and add to new block - new_block.append(&mut pending_list[c0]); - new_block.append(&mut pending_list[c1]); - - new_block.push(graph.graph[node].clone_ref(py)); - - // Create new block, assign its id to color pair - block_id[c0] = Some(block_list.len()); - block_id[c1] = Some(block_list.len()); - block_list.push(new_block); - } - } - } else { - for color in colors { - ensure_vector_has_index!(pending_list, block_id, color); - if let Some(color_block_id) = block_id[color] { - block_list[color_block_id].append(&mut pending_list[color]); - } - block_id[color] = None; - pending_list[color].clear(); - } - } - } - } - Ok(block_list) } diff --git a/tests/digraph/test_collect_bicolor_runs.py b/tests/digraph/test_collect_bicolor_runs.py index b0a903e15..8a47c6740 100644 --- 
a/tests/digraph/test_collect_bicolor_runs.py +++ b/tests/digraph/test_collect_bicolor_runs.py @@ -98,9 +98,9 @@ def filter_function(node): else: return None - def color_function(node): - if "q" in node: - return int(node[1:]) + def color_function(edge): + if "q" in edge: + return int(edge[1:]) else: return None @@ -187,9 +187,9 @@ def filter_function(node): else: return None - def color_function(node): - if "q" in node: - return int(node[1:]) + def color_function(edge): + if "q" in edge: + return int(edge[1:]) else: return None @@ -264,9 +264,9 @@ def filter_function(node): else: return None - def color_function(node): - if "q" in node: - return int(node[1:]) + def color_function(edge): + if "q" in edge: + return int(edge[1:]) else: return None @@ -338,9 +338,9 @@ def filter_function(node): else: return None - def color_function(node): - if "q" in node: - return int(node[1:]) + def color_function(edge): + if "q" in edge: + return int(edge[1:]) else: return None