From 55f3b1a4e3a031841ee2c4db25fb72969c992942 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elena=20Pe=C3=B1a=20Tapia?= Date: Wed, 8 May 2024 10:10:36 +0200 Subject: [PATCH 01/16] Attempt with graph.node_weight --- rustworkx-core/src/collect_bicolor_runs.rs | 202 +++++++++++++++++++++ rustworkx-core/src/lib.rs | 1 + 2 files changed, 203 insertions(+) create mode 100644 rustworkx-core/src/collect_bicolor_runs.rs diff --git a/rustworkx-core/src/collect_bicolor_runs.rs b/rustworkx-core/src/collect_bicolor_runs.rs new file mode 100644 index 000000000..98126725f --- /dev/null +++ b/rustworkx-core/src/collect_bicolor_runs.rs @@ -0,0 +1,202 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +use std::cmp::Eq; +use std::error::Error; +use std::hash::Hash; +use std::fmt::{Debug, Display, Formatter}; + +use petgraph::algo; +use petgraph::visit::Data; +use petgraph::data::DataMap; +use petgraph::visit::{EdgeRef, GraphBase, GraphProp, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCount, NodeIndexable, Visitable, IntoEdgesDirected}; + + +#[derive(Debug)] +pub enum CollectBicolorError { + DAGWouldCycle, +} + +impl Display for CollectBicolorError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + CollectBicolorError::DAGWouldCycle => fmt_dag_would_cycle(f), + } + } +} + +impl Error for CollectBicolorError {} + +#[derive(Debug)] +pub enum CollectBicolorSimpleError { + DAGWouldCycle, + MergeError(E), //placeholder, may remove if not used +} + +impl Display for CollectBicolorSimpleError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + CollectBicolorSimpleError::DAGWouldCycle => fmt_dag_would_cycle(f), + CollectBicolorSimpleError::MergeError(ref e) => fmt_merge_error(f, e), + } + } +} + +impl Error for CollectBicolorSimpleError {} + +fn fmt_dag_would_cycle(f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "The operation would introduce a cycle.") +} + +fn fmt_merge_error(f: &mut Formatter<'_>, inner: &E) -> std::fmt::Result { + write!(f, "The prov failed with: {:?}", inner) +} +/// Collect runs that match a filter function given edge colors +/// +/// A bicolor run is a list of group of nodes connected by edges of exactly +/// two colors. In addition, all nodes in the group must match the given +/// condition. Each node in the graph can appear in only a single group +/// in the bicolor run. +/// +/// :param PyDiGraph graph: The graph to find runs in +/// :param filter_fn: The filter function to use for matching nodes. It takes +/// in one argument, the node data payload/weight object, and will return a +/// boolean whether the node matches the conditions or not. +/// If it returns ``True``, it will continue the bicolor chain. +/// If it returns ``False``, it will stop the bicolor chain. +/// If it returns ``None`` it will skip that node. +/// :param color_fn: The function that gives the color of the edge. It takes +/// in one argument, the edge data payload/weight object, and will +/// return a non-negative integer, the edge color. If the color is None, +/// the edge is ignored. +/// +/// :returns: a list of groups with exactly two edge colors, where each group +/// is a list of node data payload/weight for the nodes in the bicolor run +/// :rtype: list +pub fn collect_bicolor_runs( + graph: G, + filter_fn: F, + color_fn: C, +) -> Result>, CollectBicolorSimpleError> //OG type: PyResult>> +where + E: Error, + // add option because of line 135 + F: FnMut(&Option<&::NodeWeight>) -> Result, CollectBicolorSimpleError>, //OG input: &PyObject, OG return: PyResult> + C: FnMut(&::EdgeWeight) -> Result, CollectBicolorSimpleError>, //OG input: &PyObject, OG return: PyResult> + G: NodeIndexable //can take node index type and convert to usize. It restricts node index type. + + IntoNodeIdentifiers //turn graph into list of nodes + + IntoNeighborsDirected // toposort + + IntoEdgesDirected + + Visitable // toposort + + GraphProp // gives access to whether graph is directed + + NodeCount + + DataMap, + ::NodeId: Eq + Hash, +{ + let mut pending_list = Vec::new(); //OG type: Vec> + let mut block_id = Vec::new(); //OG type: Vec> + let mut block_list = Vec::new(); //OG type: Vec> -> return + + let filter_node = |node: &Option<&::NodeWeight>| -> Result, CollectBicolorSimpleError>{ + let res = filter_fn(node); + res + }; + + let color_edge = |edge: &::EdgeWeight| -> Result, CollectBicolorSimpleError>{ + let res = color_fn(edge); + res + }; + + let nodes = match algo::toposort(&graph, None){ + Ok(nodes) => nodes, + Err(_err) => return Err(CollectBicolorSimpleError::DAGWouldCycle) + }; + + // Utility for ensuring pending_list has the color index + macro_rules! ensure_vector_has_index { + ($pending_list: expr, $block_id: expr, $color: expr) => { + if $color >= $pending_list.len() { + $pending_list.resize($color + 1, Vec::new()); + $block_id.resize($color + 1, None); + } + }; + } + + // tried unsuccessfully &NodeIndexable::from_index(&graph, node) + for node in nodes { + if let Some(is_match) = filter_node(&graph.node_weight(node))? { + let raw_edges = graph + .edges_directed(node, petgraph::Direction::Outgoing); + + // Remove all edges that do not yield errors from color_fn + let colors = raw_edges + .map(|edge| { + let edge_weight = edge.weight(); + color_edge(edge_weight) + }) + .collect::>, _>>()?; + + // Remove null edges from color_fn + let colors = colors.into_iter().flatten().collect::>(); + + if colors.len() <= 2 && is_match { + if colors.len() == 1 { + let c0 = colors[0]; + ensure_vector_has_index!(pending_list, block_id, c0); + if let Some(c0_block_id) = block_id[c0] { + block_list[c0_block_id].push(graph.node_weight(node)); + } else { + pending_list[c0].push(graph.node_weight(node)); + } + } else if colors.len() == 2 { + let c0 = colors[0]; + let c1 = colors[1]; + ensure_vector_has_index!(pending_list, block_id, c0); + ensure_vector_has_index!(pending_list, block_id, c1); + + if block_id[c0].is_some() + && block_id[c1].is_some() + && block_id[c0] == block_id[c1] + { + block_list[block_id[c0].unwrap_or_default()] + .push(graph.node_weight(node)); + } else { + let mut new_block: Vec::NodeWeight>> = + Vec::with_capacity(pending_list[c0].len() + pending_list[c1].len() + 1); + + // Clears pending lits and add to new block + new_block.append(&mut pending_list[c0]); + new_block.append(&mut pending_list[c1]); + + new_block.push(graph.node_weight(node)); + + // Create new block, assign its id to color pair + block_id[c0] = Some(block_list.len()); + block_id[c1] = Some(block_list.len()); + block_list.push(new_block); + } + } + } else { + for color in colors { + ensure_vector_has_index!(pending_list, block_id, color); + if let Some(color_block_id) = block_id[color] { + block_list[color_block_id].append(&mut pending_list[color]); + } + block_id[color] = None; + pending_list[color].clear(); + } + } + } + } + + Ok(block_list) +} \ No newline at end of file diff --git a/rustworkx-core/src/lib.rs b/rustworkx-core/src/lib.rs index e5d38eb58..709aad493 100644 --- a/rustworkx-core/src/lib.rs +++ b/rustworkx-core/src/lib.rs @@ -78,6 +78,7 @@ pub mod coloring; pub mod connectivity; pub mod generators; pub mod line_graph; +pub mod collect_bicolor_runs; /// Module for maximum weight matching algorithms. pub mod max_weight_matching; From 59f27c9bdc87156cbc3514ce2d469d49f29f0792 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elena=20Pe=C3=B1a=20Tapia?= Date: Wed, 8 May 2024 14:10:54 +0200 Subject: [PATCH 02/16] Progress until hitting error. --- rustworkx-core/src/collect_bicolor_runs.rs | 41 +++++++++++----------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/rustworkx-core/src/collect_bicolor_runs.rs b/rustworkx-core/src/collect_bicolor_runs.rs index 98126725f..64808992e 100644 --- a/rustworkx-core/src/collect_bicolor_runs.rs +++ b/rustworkx-core/src/collect_bicolor_runs.rs @@ -18,9 +18,11 @@ use std::fmt::{Debug, Display, Formatter}; use petgraph::algo; use petgraph::visit::Data; use petgraph::data::DataMap; -use petgraph::visit::{EdgeRef, GraphBase, GraphProp, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCount, NodeIndexable, Visitable, IntoEdgesDirected}; +use petgraph::visit::{EdgeRef, GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers, NodeIndexable, Visitable, IntoEdgesDirected}; +// Taken from Kevin's PR, but we probably don't need the enum (no MergeError either) +// TODO: clean up once the code compiles #[derive(Debug)] pub enum CollectBicolorError { DAGWouldCycle, @@ -86,32 +88,32 @@ pub fn collect_bicolor_runs( graph: G, filter_fn: F, color_fn: C, -) -> Result>, CollectBicolorSimpleError> //OG type: PyResult>> +) -> Result::NodeWeight>>, CollectBicolorSimpleError> //OG type: PyResult>> where E: Error, - // add option because of line 135 + // add Option to input type because of line 135 F: FnMut(&Option<&::NodeWeight>) -> Result, CollectBicolorSimpleError>, //OG input: &PyObject, OG return: PyResult> C: FnMut(&::EdgeWeight) -> Result, CollectBicolorSimpleError>, //OG input: &PyObject, OG return: PyResult> - G: NodeIndexable //can take node index type and convert to usize. It restricts node index type. - + IntoNodeIdentifiers //turn graph into list of nodes - + IntoNeighborsDirected // toposort - + IntoEdgesDirected - + Visitable // toposort - + GraphProp // gives access to whether graph is directed - + NodeCount - + DataMap, + G: NodeIndexable // can take node index type and convert to usize. It restricts node index type. + + IntoNodeIdentifiers // used in toposort. Turns graph into list of nodes + + IntoNeighborsDirected // used in toposort + + IntoEdgesDirected // used in line 138 + + Visitable // used in toposort + + DataMap, // used to access node weights ::NodeId: Eq + Hash, { - let mut pending_list = Vec::new(); //OG type: Vec> - let mut block_id = Vec::new(); //OG type: Vec> - let mut block_list = Vec::new(); //OG type: Vec> -> return + let mut pending_list: Vec::NodeWeight>> = Vec::new(); //OG type: Vec> + let mut block_id: Vec> = Vec::new(); //OG type: Vec> + let mut block_list: Vec::NodeWeight>> = Vec::new(); //OG type: Vec> -> return let filter_node = |node: &Option<&::NodeWeight>| -> Result, CollectBicolorSimpleError>{ + // TODO: just return the output of filter_fn let res = filter_fn(node); res }; let color_edge = |edge: &::EdgeWeight| -> Result, CollectBicolorSimpleError>{ + // TODO: just return the output of color_fn let res = color_fn(edge); res }; @@ -131,7 +133,6 @@ where }; } - // tried unsuccessfully &NodeIndexable::from_index(&graph, node) for node in nodes { if let Some(is_match) = filter_node(&graph.node_weight(node))? { let raw_edges = graph @@ -153,9 +154,9 @@ where let c0 = colors[0]; ensure_vector_has_index!(pending_list, block_id, c0); if let Some(c0_block_id) = block_id[c0] { - block_list[c0_block_id].push(graph.node_weight(node)); + block_list[c0_block_id].push(graph.node_weight(node).expect("REASON")); } else { - pending_list[c0].push(graph.node_weight(node)); + pending_list[c0].push(graph.node_weight(node).expect("REASON")); } } else if colors.len() == 2 { let c0 = colors[0]; @@ -168,16 +169,16 @@ where && block_id[c0] == block_id[c1] { block_list[block_id[c0].unwrap_or_default()] - .push(graph.node_weight(node)); + .push(graph.node_weight(node).expect("REASON")); } else { - let mut new_block: Vec::NodeWeight>> = + let mut new_block: Vec<&::NodeWeight> = Vec::with_capacity(pending_list[c0].len() + pending_list[c1].len() + 1); // Clears pending lits and add to new block new_block.append(&mut pending_list[c0]); new_block.append(&mut pending_list[c1]); - new_block.push(graph.node_weight(node)); + new_block.push(graph.node_weight(node).expect("REASON")); // Create new block, assign its id to color pair block_id[c0] = Some(block_list.len()); From b57b9d7716dc5b55047d6f76d80acfbfcdd89d35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elena=20Pe=C3=B1a=20Tapia?= Date: Wed, 8 May 2024 16:39:30 +0200 Subject: [PATCH 03/16] Return node ids instead of node weights. Code compiles!!! --- rustworkx-core/src/collect_bicolor_runs.rs | 94 +++++++++------------- rustworkx-core/src/lib.rs | 2 +- 2 files changed, 41 insertions(+), 55 deletions(-) diff --git a/rustworkx-core/src/collect_bicolor_runs.rs b/rustworkx-core/src/collect_bicolor_runs.rs index 64808992e..cd5750b19 100644 --- a/rustworkx-core/src/collect_bicolor_runs.rs +++ b/rustworkx-core/src/collect_bicolor_runs.rs @@ -12,55 +12,42 @@ use std::cmp::Eq; use std::error::Error; -use std::hash::Hash; use std::fmt::{Debug, Display, Formatter}; +use std::hash::Hash; use petgraph::algo; -use petgraph::visit::Data; use petgraph::data::DataMap; -use petgraph::visit::{EdgeRef, GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers, NodeIndexable, Visitable, IntoEdgesDirected}; - +use petgraph::visit::Data; +use petgraph::visit::{ + EdgeRef, GraphBase, IntoEdgesDirected, IntoNeighborsDirected, IntoNodeIdentifiers, + NodeIndexable, Visitable, +}; -// Taken from Kevin's PR, but we probably don't need the enum (no MergeError either) +/// Define custom error classes for collect_bicolor_runs // TODO: clean up once the code compiles #[derive(Debug)] -pub enum CollectBicolorError { +pub enum CollectBicolorError { DAGWouldCycle, + CallableError(E), //placeholder, may remove if not used } -impl Display for CollectBicolorError { +impl Display for CollectBicolorError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { CollectBicolorError::DAGWouldCycle => fmt_dag_would_cycle(f), + CollectBicolorError::CallableError(ref e) => fmt_callable_error(f, e), } } } -impl Error for CollectBicolorError {} - -#[derive(Debug)] -pub enum CollectBicolorSimpleError { - DAGWouldCycle, - MergeError(E), //placeholder, may remove if not used -} - -impl Display for CollectBicolorSimpleError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - CollectBicolorSimpleError::DAGWouldCycle => fmt_dag_would_cycle(f), - CollectBicolorSimpleError::MergeError(ref e) => fmt_merge_error(f, e), - } - } -} - -impl Error for CollectBicolorSimpleError {} +impl Error for CollectBicolorError {} fn fmt_dag_would_cycle(f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "The operation would introduce a cycle.") } -fn fmt_merge_error(f: &mut Formatter<'_>, inner: &E) -> std::fmt::Result { - write!(f, "The prov failed with: {:?}", inner) +fn fmt_callable_error(f: &mut Formatter<'_>, inner: &E) -> std::fmt::Result { + write!(f, "The function failed with: {:?}", inner) } /// Collect runs that match a filter function given edge colors /// @@ -88,13 +75,15 @@ pub fn collect_bicolor_runs( graph: G, filter_fn: F, color_fn: C, -) -> Result::NodeWeight>>, CollectBicolorSimpleError> //OG type: PyResult>> +) -> Result>, CollectBicolorError> +//OG type: PyResult>> where E: Error, // add Option to input type because of line 135 - F: FnMut(&Option<&::NodeWeight>) -> Result, CollectBicolorSimpleError>, //OG input: &PyObject, OG return: PyResult> - C: FnMut(&::EdgeWeight) -> Result, CollectBicolorSimpleError>, //OG input: &PyObject, OG return: PyResult> - G: NodeIndexable // can take node index type and convert to usize. It restricts node index type. + F: Fn(&Option<&::NodeWeight>) -> Result, CollectBicolorError>, //OG input: &PyObject, OG return: PyResult> + C: Fn(&::EdgeWeight) -> Result, CollectBicolorError>, //OG input: &PyObject, OG return: PyResult> + G: NodeIndexable + // can take node index type and convert to usize. It restricts node index type. + IntoNodeIdentifiers // used in toposort. Turns graph into list of nodes + IntoNeighborsDirected // used in toposort + IntoEdgesDirected // used in line 138 @@ -102,25 +91,23 @@ where + DataMap, // used to access node weights ::NodeId: Eq + Hash, { - let mut pending_list: Vec::NodeWeight>> = Vec::new(); //OG type: Vec> + let mut pending_list: Vec> = Vec::new(); //OG type: Vec> let mut block_id: Vec> = Vec::new(); //OG type: Vec> - let mut block_list: Vec::NodeWeight>> = Vec::new(); //OG type: Vec> -> return + let mut block_list: Vec> = Vec::new(); //OG type: Vec> -> return - let filter_node = |node: &Option<&::NodeWeight>| -> Result, CollectBicolorSimpleError>{ - // TODO: just return the output of filter_fn - let res = filter_fn(node); - res - }; + let filter_node = + |node: &Option<&::NodeWeight>| -> Result, CollectBicolorError> { + filter_fn(node) + }; - let color_edge = |edge: &::EdgeWeight| -> Result, CollectBicolorSimpleError>{ - // TODO: just return the output of color_fn - let res = color_fn(edge); - res - }; + let color_edge = + |edge: &::EdgeWeight| -> Result, CollectBicolorError> { + color_fn(edge) + }; - let nodes = match algo::toposort(&graph, None){ + let nodes = match algo::toposort(&graph, None) { Ok(nodes) => nodes, - Err(_err) => return Err(CollectBicolorSimpleError::DAGWouldCycle) + Err(_err) => return Err(CollectBicolorError::DAGWouldCycle), }; // Utility for ensuring pending_list has the color index @@ -135,8 +122,7 @@ where for node in nodes { if let Some(is_match) = filter_node(&graph.node_weight(node))? { - let raw_edges = graph - .edges_directed(node, petgraph::Direction::Outgoing); + let raw_edges = graph.edges_directed(node, petgraph::Direction::Outgoing); // Remove all edges that do not yield errors from color_fn let colors = raw_edges @@ -149,14 +135,15 @@ where // Remove null edges from color_fn let colors = colors.into_iter().flatten().collect::>(); + // &NodeIndexable::from_index(&graph, node) if colors.len() <= 2 && is_match { if colors.len() == 1 { let c0 = colors[0]; ensure_vector_has_index!(pending_list, block_id, c0); if let Some(c0_block_id) = block_id[c0] { - block_list[c0_block_id].push(graph.node_weight(node).expect("REASON")); + block_list[c0_block_id].push(node); } else { - pending_list[c0].push(graph.node_weight(node).expect("REASON")); + pending_list[c0].push(node); } } else if colors.len() == 2 { let c0 = colors[0]; @@ -168,17 +155,16 @@ where && block_id[c1].is_some() && block_id[c0] == block_id[c1] { - block_list[block_id[c0].unwrap_or_default()] - .push(graph.node_weight(node).expect("REASON")); + block_list[block_id[c0].unwrap_or_default()].push(node); } else { - let mut new_block: Vec<&::NodeWeight> = + let mut new_block: Vec = Vec::with_capacity(pending_list[c0].len() + pending_list[c1].len() + 1); // Clears pending lits and add to new block new_block.append(&mut pending_list[c0]); new_block.append(&mut pending_list[c1]); - new_block.push(graph.node_weight(node).expect("REASON")); + new_block.push(node); // Create new block, assign its id to color pair block_id[c0] = Some(block_list.len()); @@ -200,4 +186,4 @@ where } Ok(block_list) -} \ No newline at end of file +} diff --git a/rustworkx-core/src/lib.rs b/rustworkx-core/src/lib.rs index 709aad493..4fc997edb 100644 --- a/rustworkx-core/src/lib.rs +++ b/rustworkx-core/src/lib.rs @@ -73,12 +73,12 @@ pub type Result = core::result::Result; pub mod bipartite_coloring; /// Module for centrality algorithms. pub mod centrality; +pub mod collect_bicolor_runs; /// Module for coloring algorithms. pub mod coloring; pub mod connectivity; pub mod generators; pub mod line_graph; -pub mod collect_bicolor_runs; /// Module for maximum weight matching algorithms. pub mod max_weight_matching; From d6a2894c93654aab089cc16fb54395c41da4bafd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elena=20Pe=C3=B1a=20Tapia?= Date: Mon, 27 May 2024 11:36:40 +0200 Subject: [PATCH 04/16] Move collect_bicolor_runs.rs to rustworkx-core/src/dag_algo.rs --- rustworkx-core/src/collect_bicolor_runs.rs | 189 --------------------- rustworkx-core/src/dag_algo.rs | 172 ++++++++++++++++++- rustworkx-core/src/lib.rs | 4 +- 3 files changed, 172 insertions(+), 193 deletions(-) delete mode 100644 rustworkx-core/src/collect_bicolor_runs.rs diff --git a/rustworkx-core/src/collect_bicolor_runs.rs b/rustworkx-core/src/collect_bicolor_runs.rs deleted file mode 100644 index cd5750b19..000000000 --- a/rustworkx-core/src/collect_bicolor_runs.rs +++ /dev/null @@ -1,189 +0,0 @@ -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. - -use std::cmp::Eq; -use std::error::Error; -use std::fmt::{Debug, Display, Formatter}; -use std::hash::Hash; - -use petgraph::algo; -use petgraph::data::DataMap; -use petgraph::visit::Data; -use petgraph::visit::{ - EdgeRef, GraphBase, IntoEdgesDirected, IntoNeighborsDirected, IntoNodeIdentifiers, - NodeIndexable, Visitable, -}; - -/// Define custom error classes for collect_bicolor_runs -// TODO: clean up once the code compiles -#[derive(Debug)] -pub enum CollectBicolorError { - DAGWouldCycle, - CallableError(E), //placeholder, may remove if not used -} - -impl Display for CollectBicolorError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - CollectBicolorError::DAGWouldCycle => fmt_dag_would_cycle(f), - CollectBicolorError::CallableError(ref e) => fmt_callable_error(f, e), - } - } -} - -impl Error for CollectBicolorError {} - -fn fmt_dag_would_cycle(f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "The operation would introduce a cycle.") -} - -fn fmt_callable_error(f: &mut Formatter<'_>, inner: &E) -> std::fmt::Result { - write!(f, "The function failed with: {:?}", inner) -} -/// Collect runs that match a filter function given edge colors -/// -/// A bicolor run is a list of group of nodes connected by edges of exactly -/// two colors. In addition, all nodes in the group must match the given -/// condition. Each node in the graph can appear in only a single group -/// in the bicolor run. -/// -/// :param PyDiGraph graph: The graph to find runs in -/// :param filter_fn: The filter function to use for matching nodes. It takes -/// in one argument, the node data payload/weight object, and will return a -/// boolean whether the node matches the conditions or not. -/// If it returns ``True``, it will continue the bicolor chain. -/// If it returns ``False``, it will stop the bicolor chain. -/// If it returns ``None`` it will skip that node. -/// :param color_fn: The function that gives the color of the edge. It takes -/// in one argument, the edge data payload/weight object, and will -/// return a non-negative integer, the edge color. If the color is None, -/// the edge is ignored. -/// -/// :returns: a list of groups with exactly two edge colors, where each group -/// is a list of node data payload/weight for the nodes in the bicolor run -/// :rtype: list -pub fn collect_bicolor_runs( - graph: G, - filter_fn: F, - color_fn: C, -) -> Result>, CollectBicolorError> -//OG type: PyResult>> -where - E: Error, - // add Option to input type because of line 135 - F: Fn(&Option<&::NodeWeight>) -> Result, CollectBicolorError>, //OG input: &PyObject, OG return: PyResult> - C: Fn(&::EdgeWeight) -> Result, CollectBicolorError>, //OG input: &PyObject, OG return: PyResult> - G: NodeIndexable - // can take node index type and convert to usize. It restricts node index type. - + IntoNodeIdentifiers // used in toposort. Turns graph into list of nodes - + IntoNeighborsDirected // used in toposort - + IntoEdgesDirected // used in line 138 - + Visitable // used in toposort - + DataMap, // used to access node weights - ::NodeId: Eq + Hash, -{ - let mut pending_list: Vec> = Vec::new(); //OG type: Vec> - let mut block_id: Vec> = Vec::new(); //OG type: Vec> - let mut block_list: Vec> = Vec::new(); //OG type: Vec> -> return - - let filter_node = - |node: &Option<&::NodeWeight>| -> Result, CollectBicolorError> { - filter_fn(node) - }; - - let color_edge = - |edge: &::EdgeWeight| -> Result, CollectBicolorError> { - color_fn(edge) - }; - - let nodes = match algo::toposort(&graph, None) { - Ok(nodes) => nodes, - Err(_err) => return Err(CollectBicolorError::DAGWouldCycle), - }; - - // Utility for ensuring pending_list has the color index - macro_rules! ensure_vector_has_index { - ($pending_list: expr, $block_id: expr, $color: expr) => { - if $color >= $pending_list.len() { - $pending_list.resize($color + 1, Vec::new()); - $block_id.resize($color + 1, None); - } - }; - } - - for node in nodes { - if let Some(is_match) = filter_node(&graph.node_weight(node))? { - let raw_edges = graph.edges_directed(node, petgraph::Direction::Outgoing); - - // Remove all edges that do not yield errors from color_fn - let colors = raw_edges - .map(|edge| { - let edge_weight = edge.weight(); - color_edge(edge_weight) - }) - .collect::>, _>>()?; - - // Remove null edges from color_fn - let colors = colors.into_iter().flatten().collect::>(); - - // &NodeIndexable::from_index(&graph, node) - if colors.len() <= 2 && is_match { - if colors.len() == 1 { - let c0 = colors[0]; - ensure_vector_has_index!(pending_list, block_id, c0); - if let Some(c0_block_id) = block_id[c0] { - block_list[c0_block_id].push(node); - } else { - pending_list[c0].push(node); - } - } else if colors.len() == 2 { - let c0 = colors[0]; - let c1 = colors[1]; - ensure_vector_has_index!(pending_list, block_id, c0); - ensure_vector_has_index!(pending_list, block_id, c1); - - if block_id[c0].is_some() - && block_id[c1].is_some() - && block_id[c0] == block_id[c1] - { - block_list[block_id[c0].unwrap_or_default()].push(node); - } else { - let mut new_block: Vec = - Vec::with_capacity(pending_list[c0].len() + pending_list[c1].len() + 1); - - // Clears pending lits and add to new block - new_block.append(&mut pending_list[c0]); - new_block.append(&mut pending_list[c1]); - - new_block.push(node); - - // Create new block, assign its id to color pair - block_id[c0] = Some(block_list.len()); - block_id[c1] = Some(block_list.len()); - block_list.push(new_block); - } - } - } else { - for color in colors { - ensure_vector_has_index!(pending_list, block_id, color); - if let Some(color_block_id) = block_id[color] { - block_list[color_block_id].append(&mut pending_list[color]); - } - block_id[color] = None; - pending_list[color].clear(); - } - } - } - } - - Ok(block_list) -} diff --git a/rustworkx-core/src/dag_algo.rs b/rustworkx-core/src/dag_algo.rs index fb629069c..736344e51 100644 --- a/rustworkx-core/src/dag_algo.rs +++ b/rustworkx-core/src/dag_algo.rs @@ -12,14 +12,17 @@ use std::cmp::{Eq, Ordering}; use std::collections::BinaryHeap; +use std::error::Error; +use std::fmt::{Debug, Display, Formatter}; use std::hash::Hash; use hashbrown::HashMap; use petgraph::algo; +use petgraph::data::DataMap; use petgraph::visit::{ - EdgeRef, GraphBase, GraphProp, IntoEdgesDirected, IntoNeighborsDirected, IntoNodeIdentifiers, - NodeCount, Visitable, + Data, EdgeRef, GraphBase, GraphProp, IntoEdgesDirected, IntoNeighborsDirected, + IntoNodeIdentifiers, NodeIndexable, NodeCount, Visitable, }; use petgraph::Directed; @@ -421,6 +424,171 @@ mod test_longest_path { } } +/// Define custom error classes for collect_bicolor_runs +// TODO: clean up once the code compiles +#[derive(Debug)] +pub enum CollectBicolorError { + DAGWouldCycle, + CallableError(E), //placeholder, may remove if not used +} + +impl Display for CollectBicolorError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + CollectBicolorError::DAGWouldCycle => fmt_dag_would_cycle(f), + CollectBicolorError::CallableError(ref e) => fmt_callable_error(f, e), + } + } +} + +impl Error for CollectBicolorError {} + +fn fmt_dag_would_cycle(f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "The operation would introduce a cycle.") +} + +fn fmt_callable_error(f: &mut Formatter<'_>, inner: &E) -> std::fmt::Result { + write!(f, "The function failed with: {:?}", inner) +} +/// Collect runs that match a filter function given edge colors +/// +/// A bicolor run is a list of group of nodes connected by edges of exactly +/// two colors. In addition, all nodes in the group must match the given +/// condition. Each node in the graph can appear in only a single group +/// in the bicolor run. +/// +/// :param PyDiGraph graph: The graph to find runs in +/// :param filter_fn: The filter function to use for matching nodes. It takes +/// in one argument, the node data payload/weight object, and will return a +/// boolean whether the node matches the conditions or not. +/// If it returns ``True``, it will continue the bicolor chain. +/// If it returns ``False``, it will stop the bicolor chain. +/// If it returns ``None`` it will skip that node. +/// :param color_fn: The function that gives the color of the edge. It takes +/// in one argument, the edge data payload/weight object, and will +/// return a non-negative integer, the edge color. If the color is None, +/// the edge is ignored. +/// +/// :returns: a list of groups with exactly two edge colors, where each group +/// is a list of node data payload/weight for the nodes in the bicolor run +/// :rtype: list +pub fn collect_bicolor_runs( + graph: G, + filter_fn: F, + color_fn: C, +) -> Result>, CollectBicolorError> +//OG type: PyResult>> +where + E: Error, + // add Option to input type because of line 135 + F: Fn(&Option<&::NodeWeight>) -> Result, CollectBicolorError>, //OG input: &PyObject, OG return: PyResult> + C: Fn(&::EdgeWeight) -> Result, CollectBicolorError>, //OG input: &PyObject, OG return: PyResult> + G: NodeIndexable + // can take node index type and convert to usize. It restricts node index type. + + IntoNodeIdentifiers // used in toposort. Turns graph into list of nodes + + IntoNeighborsDirected // used in toposort + + IntoEdgesDirected // used in line 138 + + Visitable // used in toposort + + DataMap, // used to access node weights + ::NodeId: Eq + Hash, +{ + let mut pending_list: Vec> = Vec::new(); //OG type: Vec> + let mut block_id: Vec> = Vec::new(); //OG type: Vec> + let mut block_list: Vec> = Vec::new(); //OG type: Vec> -> return + + let filter_node = + |node: &Option<&::NodeWeight>| -> Result, CollectBicolorError> { + filter_fn(node) + }; + + let color_edge = + |edge: &::EdgeWeight| -> Result, CollectBicolorError> { + color_fn(edge) + }; + + let nodes = match algo::toposort(&graph, None) { + Ok(nodes) => nodes, + Err(_err) => return Err(CollectBicolorError::DAGWouldCycle), + }; + + // Utility for ensuring pending_list has the color index + macro_rules! ensure_vector_has_index { + ($pending_list: expr, $block_id: expr, $color: expr) => { + if $color >= $pending_list.len() { + $pending_list.resize($color + 1, Vec::new()); + $block_id.resize($color + 1, None); + } + }; + } + + for node in nodes { + if let Some(is_match) = filter_node(&graph.node_weight(node))? { + let raw_edges = graph.edges_directed(node, petgraph::Direction::Outgoing); + + // Remove all edges that do not yield errors from color_fn + let colors = raw_edges + .map(|edge| { + let edge_weight = edge.weight(); + color_edge(edge_weight) + }) + .collect::>, _>>()?; + + // Remove null edges from color_fn + let colors = colors.into_iter().flatten().collect::>(); + + // &NodeIndexable::from_index(&graph, node) + if colors.len() <= 2 && is_match { + if colors.len() == 1 { + let c0 = colors[0]; + ensure_vector_has_index!(pending_list, block_id, c0); + if let Some(c0_block_id) = block_id[c0] { + block_list[c0_block_id].push(node); + } else { + pending_list[c0].push(node); + } + } else if colors.len() == 2 { + let c0 = colors[0]; + let c1 = colors[1]; + ensure_vector_has_index!(pending_list, block_id, c0); + ensure_vector_has_index!(pending_list, block_id, c1); + + if block_id[c0].is_some() + && block_id[c1].is_some() + && block_id[c0] == block_id[c1] + { + block_list[block_id[c0].unwrap_or_default()].push(node); + } else { + let mut new_block: Vec = + Vec::with_capacity(pending_list[c0].len() + pending_list[c1].len() + 1); + + // Clears pending lits and add to new block + new_block.append(&mut pending_list[c0]); + new_block.append(&mut pending_list[c1]); + + new_block.push(node); + + // Create new block, assign its id to color pair + block_id[c0] = Some(block_list.len()); + block_id[c1] = Some(block_list.len()); + block_list.push(new_block); + } + } + } else { + for color in colors { + ensure_vector_has_index!(pending_list, block_id, color); + if let Some(color_block_id) = block_id[color] { + block_list[color_block_id].append(&mut pending_list[color]); + } + block_id[color] = None; + pending_list[color].clear(); + } + } + } + } + + Ok(block_list) +} + // pub fn lexicographical_topological_sort( // dag: G, // mut key: F, diff --git a/rustworkx-core/src/lib.rs b/rustworkx-core/src/lib.rs index 7dce5a3cd..373f9a32d 100644 --- a/rustworkx-core/src/lib.rs +++ b/rustworkx-core/src/lib.rs @@ -95,14 +95,14 @@ pub mod err; pub mod bipartite_coloring; /// Module for centrality algorithms. pub mod centrality; -pub mod collect_bicolor_runs; /// Module for coloring algorithms. pub mod coloring; pub mod connectivity; -pub mod dag_algo; pub mod generators; pub mod graph_ext; pub mod line_graph; +/// Module for algorithms that work on DAGs. +pub mod dag_algo; /// Module for maximum weight matching algorithms. pub mod max_weight_matching; pub mod planar; From 22821fd7cbc3272ab014eea1a0296652d352e58a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elena=20Pe=C3=B1a=20Tapia?= Date: Tue, 28 May 2024 11:01:30 +0200 Subject: [PATCH 05/16] Add pyo3 wrapper for core collect_bicolor_runs, modify test, small fixes. --- rustworkx-core/src/dag_algo.rs | 275 ++++++++++----------- src/dag_algo/mod.rs | 139 ++++------- tests/digraph/test_collect_bicolor_runs.py | 2 +- 3 files changed, 184 insertions(+), 232 deletions(-) diff --git a/rustworkx-core/src/dag_algo.rs b/rustworkx-core/src/dag_algo.rs index 736344e51..6cd2bc2bb 100644 --- a/rustworkx-core/src/dag_algo.rs +++ b/rustworkx-core/src/dag_algo.rs @@ -316,122 +316,15 @@ where Ok(Some((path, path_weight))) } -#[cfg(test)] -mod test_longest_path { - use super::*; - use petgraph::graph::DiGraph; - use petgraph::stable_graph::StableDiGraph; - - #[test] - fn test_empty_graph() { - let graph: DiGraph<(), ()> = DiGraph::new(); - let weight_fn = |_: petgraph::graph::EdgeReference<()>| Ok::(0); - let result = longest_path(&graph, weight_fn); - assert_eq!(result, Ok(Some((vec![], 0)))); - } - - #[test] - fn test_single_node_graph() { - let mut graph: DiGraph<(), ()> = DiGraph::new(); - let n0 = graph.add_node(()); - let weight_fn = |_: petgraph::graph::EdgeReference<()>| Ok::(0); - let result = longest_path(&graph, weight_fn); - assert_eq!(result, Ok(Some((vec![n0], 0)))); - } - - #[test] - fn test_dag_with_multiple_paths() { - let mut graph: DiGraph<(), i32> = DiGraph::new(); - let n0 = graph.add_node(()); - let n1 = graph.add_node(()); - let n2 = graph.add_node(()); - let n3 = graph.add_node(()); - let n4 = graph.add_node(()); - let n5 = graph.add_node(()); - graph.add_edge(n0, n1, 3); - graph.add_edge(n0, n2, 2); - graph.add_edge(n1, n2, 1); - graph.add_edge(n1, n3, 4); - graph.add_edge(n2, n3, 2); - graph.add_edge(n3, n4, 2); - graph.add_edge(n2, n5, 1); - graph.add_edge(n4, n5, 3); - let weight_fn = |edge: petgraph::graph::EdgeReference| Ok::(*edge.weight()); - let result = longest_path(&graph, weight_fn); - assert_eq!(result, Ok(Some((vec![n0, n1, n3, n4, n5], 12)))); - } - - #[test] - fn test_graph_with_cycle() { - let mut graph: DiGraph<(), i32> = DiGraph::new(); - let n0 = graph.add_node(()); - let n1 = graph.add_node(()); - graph.add_edge(n0, n1, 1); - graph.add_edge(n1, n0, 1); // Creates a cycle - - let weight_fn = |edge: petgraph::graph::EdgeReference| Ok::(*edge.weight()); - let result = longest_path(&graph, weight_fn); - assert_eq!(result, Ok(None)); - } - - #[test] - fn test_negative_weights() { - let mut graph: DiGraph<(), i32> = DiGraph::new(); - let n0 = graph.add_node(()); - let n1 = graph.add_node(()); - let n2 = graph.add_node(()); - graph.add_edge(n0, n1, -1); - graph.add_edge(n0, n2, 2); - graph.add_edge(n1, n2, -2); - let weight_fn = |edge: petgraph::graph::EdgeReference| Ok::(*edge.weight()); - let result = longest_path(&graph, weight_fn); - assert_eq!(result, Ok(Some((vec![n0, n2], 2)))); - } - - #[test] - fn test_longest_path_in_stable_digraph() { - let mut graph: StableDiGraph<(), i32> = StableDiGraph::new(); - let n0 = graph.add_node(()); - let n1 = graph.add_node(()); - let n2 = graph.add_node(()); - graph.add_edge(n0, n1, 1); - graph.add_edge(n0, n2, 3); - graph.add_edge(n1, n2, 1); - let weight_fn = - |edge: petgraph::stable_graph::EdgeReference<'_, i32>| Ok::(*edge.weight()); - let result = longest_path(&graph, weight_fn); - assert_eq!(result, Ok(Some((vec![n0, n2], 3)))); - } - - #[test] - fn test_error_handling() { - let mut graph: DiGraph<(), i32> = DiGraph::new(); - let n0 = graph.add_node(()); - let n1 = graph.add_node(()); - let n2 = graph.add_node(()); - graph.add_edge(n0, n1, 1); - graph.add_edge(n0, n2, 2); - graph.add_edge(n1, n2, 1); - let weight_fn = |edge: petgraph::graph::EdgeReference| { - if *edge.weight() == 2 { - Err("Error: edge weight is 2") - } else { - Ok::(*edge.weight()) - } - }; - let result = longest_path(&graph, weight_fn); - assert_eq!(result, Err("Error: edge weight is 2")); - } -} - /// Define custom error classes for collect_bicolor_runs -// TODO: clean up once the code compiles #[derive(Debug)] pub enum CollectBicolorError { DAGWouldCycle, - CallableError(E), //placeholder, may remove if not used + CallableError(E) } +impl Error for CollectBicolorError {} + impl Display for CollectBicolorError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { @@ -441,8 +334,6 @@ impl Display for CollectBicolorError { } } -impl Error for CollectBicolorError {} - fn fmt_dag_would_cycle(f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "The operation would introduce a cycle.") } @@ -450,54 +341,55 @@ fn fmt_dag_would_cycle(f: &mut Formatter<'_>) -> std::fmt::Result { fn fmt_callable_error(f: &mut Formatter<'_>, inner: &E) -> std::fmt::Result { write!(f, "The function failed with: {:?}", inner) } -/// Collect runs that match a filter function given edge colors +/// Collect runs that match a filter function given edge colors. /// -/// A bicolor run is a list of group of nodes connected by edges of exactly +/// A bicolor run is a list of groups of nodes connected by edges of exactly /// two colors. In addition, all nodes in the group must match the given /// condition. Each node in the graph can appear in only a single group /// in the bicolor run. /// -/// :param PyDiGraph graph: The graph to find runs in -/// :param filter_fn: The filter function to use for matching nodes. It takes +/// # Arguments: +/// +/// * `dag`: The DAG to find bicolor runs in +/// * `filter_fn`: The filter function to use for matching nodes. It takes /// in one argument, the node data payload/weight object, and will return a /// boolean whether the node matches the conditions or not. -/// If it returns ``True``, it will continue the bicolor chain. -/// If it returns ``False``, it will stop the bicolor chain. +/// If it returns ``true``, it will continue the bicolor chain. +/// If it returns ``false``, it will stop the bicolor chain. /// If it returns ``None`` it will skip that node. -/// :param color_fn: The function that gives the color of the edge. It takes +/// * `color_fn`: The function that gives the color of the edge. It takes /// in one argument, the edge data payload/weight object, and will /// return a non-negative integer, the edge color. If the color is None, /// the edge is ignored. /// -/// :returns: a list of groups with exactly two edge colors, where each group +/// # Returns: +/// +/// * `Vec>`: a list of groups with exactly two edge colors, where each group /// is a list of node data payload/weight for the nodes in the bicolor run -/// :rtype: list +/// * `CollectBicolorError>` if there is an error computing the bicolor runs pub fn collect_bicolor_runs( graph: G, filter_fn: F, color_fn: C, ) -> Result>, CollectBicolorError> -//OG type: PyResult>> where E: Error, - // add Option to input type because of line 135 - F: Fn(&Option<&::NodeWeight>) -> Result, CollectBicolorError>, //OG input: &PyObject, OG return: PyResult> - C: Fn(&::EdgeWeight) -> Result, CollectBicolorError>, //OG input: &PyObject, OG return: PyResult> + F: Fn(&::NodeWeight) -> Result, CollectBicolorError>, + C: Fn(&::EdgeWeight) -> Result, CollectBicolorError>, G: NodeIndexable - // can take node index type and convert to usize. It restricts node index type. - + IntoNodeIdentifiers // used in toposort. Turns graph into list of nodes - + IntoNeighborsDirected // used in toposort - + IntoEdgesDirected // used in line 138 - + Visitable // used in toposort - + DataMap, // used to access node weights + + IntoNodeIdentifiers + + IntoNeighborsDirected + + IntoEdgesDirected + + Visitable + + DataMap, ::NodeId: Eq + Hash, { - let mut pending_list: Vec> = Vec::new(); //OG type: Vec> - let mut block_id: Vec> = Vec::new(); //OG type: Vec> - let mut block_list: Vec> = Vec::new(); //OG type: Vec> -> return + let mut pending_list: Vec> = Vec::new(); + let mut block_id: Vec> = Vec::new(); + let mut block_list: Vec> = Vec::new(); let filter_node = - |node: &Option<&::NodeWeight>| -> Result, CollectBicolorError> { + |node: &::NodeWeight| -> Result, CollectBicolorError> { filter_fn(node) }; @@ -522,7 +414,7 @@ where } for node in nodes { - if let Some(is_match) = filter_node(&graph.node_weight(node))? { + if let Some(is_match) = filter_node(&graph.node_weight(node).expect("Invalid NodeId"))? { let raw_edges = graph.edges_directed(node, petgraph::Direction::Outgoing); // Remove all edges that do not yield errors from color_fn @@ -536,7 +428,6 @@ where // Remove null edges from color_fn let colors = colors.into_iter().flatten().collect::>(); - // &NodeIndexable::from_index(&graph, node) if colors.len() <= 2 && is_match { if colors.len() == 1 { let c0 = colors[0]; @@ -589,6 +480,114 @@ where Ok(block_list) } +#[cfg(test)] +mod test_longest_path { + use super::*; + use petgraph::graph::DiGraph; + use petgraph::stable_graph::StableDiGraph; + + #[test] + fn test_empty_graph() { + let graph: DiGraph<(), ()> = DiGraph::new(); + let weight_fn = |_: petgraph::graph::EdgeReference<()>| Ok::(0); + let result = longest_path(&graph, weight_fn); + assert_eq!(result, Ok(Some((vec![], 0)))); + } + + #[test] + fn test_single_node_graph() { + let mut graph: DiGraph<(), ()> = DiGraph::new(); + let n0 = graph.add_node(()); + let weight_fn = |_: petgraph::graph::EdgeReference<()>| Ok::(0); + let result = longest_path(&graph, weight_fn); + assert_eq!(result, Ok(Some((vec![n0], 0)))); + } + + #[test] + fn test_dag_with_multiple_paths() { + let mut graph: DiGraph<(), i32> = DiGraph::new(); + let n0 = graph.add_node(()); + let n1 = graph.add_node(()); + let n2 = graph.add_node(()); + let n3 = graph.add_node(()); + let n4 = graph.add_node(()); + let n5 = graph.add_node(()); + graph.add_edge(n0, n1, 3); + graph.add_edge(n0, n2, 2); + graph.add_edge(n1, n2, 1); + graph.add_edge(n1, n3, 4); + graph.add_edge(n2, n3, 2); + graph.add_edge(n3, n4, 2); + graph.add_edge(n2, n5, 1); + graph.add_edge(n4, n5, 3); + let weight_fn = |edge: petgraph::graph::EdgeReference| Ok::(*edge.weight()); + let result = longest_path(&graph, weight_fn); + assert_eq!(result, Ok(Some((vec![n0, n1, n3, n4, n5], 12)))); + } + + #[test] + fn test_graph_with_cycle() { + let mut graph: DiGraph<(), i32> = DiGraph::new(); + let n0 = graph.add_node(()); + let n1 = graph.add_node(()); + graph.add_edge(n0, n1, 1); + graph.add_edge(n1, n0, 1); // Creates a cycle + + let weight_fn = |edge: petgraph::graph::EdgeReference| Ok::(*edge.weight()); + let result = longest_path(&graph, weight_fn); + assert_eq!(result, Ok(None)); + } + + #[test] + fn test_negative_weights() { + let mut graph: DiGraph<(), i32> = DiGraph::new(); + let n0 = graph.add_node(()); + let n1 = graph.add_node(()); + let n2 = graph.add_node(()); + graph.add_edge(n0, n1, -1); + graph.add_edge(n0, n2, 2); + graph.add_edge(n1, n2, -2); + let weight_fn = |edge: petgraph::graph::EdgeReference| Ok::(*edge.weight()); + let result = longest_path(&graph, weight_fn); + assert_eq!(result, Ok(Some((vec![n0, n2], 2)))); + } + + #[test] + fn test_longest_path_in_stable_digraph() { + let mut graph: StableDiGraph<(), i32> = StableDiGraph::new(); + let n0 = graph.add_node(()); + let n1 = graph.add_node(()); + let n2 = graph.add_node(()); + graph.add_edge(n0, n1, 1); + graph.add_edge(n0, n2, 3); + graph.add_edge(n1, n2, 1); + let weight_fn = + |edge: petgraph::stable_graph::EdgeReference<'_, i32>| Ok::(*edge.weight()); + let result = longest_path(&graph, weight_fn); + assert_eq!(result, Ok(Some((vec![n0, n2], 3)))); + } + + #[test] + fn test_error_handling() { + let mut graph: DiGraph<(), i32> = DiGraph::new(); + let n0 = graph.add_node(()); + let n1 = graph.add_node(()); + let n2 = graph.add_node(()); + graph.add_edge(n0, n1, 1); + graph.add_edge(n0, n2, 2); + graph.add_edge(n1, n2, 1); + let weight_fn = |edge: petgraph::graph::EdgeReference| { + if *edge.weight() == 2 { + Err("Error: edge weight is 2") + } else { + Ok::(*edge.weight()) + } + }; + let result = longest_path(&graph, weight_fn); + assert_eq!(result, Err("Error: edge weight is 2")); + } +} + // pub fn lexicographical_topological_sort( // dag: G, // mut key: F, @@ -766,4 +765,4 @@ mod test_lexicographical_topological_sort { let result = lexicographical_topological_sort(&graph, sort_fn, false, Some(&initial)); assert_eq!(result, Ok(Some(vec![nodes[7]]))); } -} +} \ No newline at end of file diff --git a/src/dag_algo/mod.rs b/src/dag_algo/mod.rs index f3138384c..65c7d8087 100644 --- a/src/dag_algo/mod.rs +++ b/src/dag_algo/mod.rs @@ -18,8 +18,9 @@ use rustworkx_core::dictmap::InitWithHasher; use super::iterators::NodeIndices; use crate::{digraph, DAGHasCycle, InvalidNode, StablePyGraph}; -use rustworkx_core::dag_algo::lexicographical_topological_sort as core_lexico_topo_sort; +use rustworkx_core::dag_algo::{CollectBicolorError, lexicographical_topological_sort as core_lexico_topo_sort}; use rustworkx_core::dag_algo::longest_path as core_longest_path; +use rustworkx_core::dag_algo::collect_bicolor_runs as core_collect_bicolor_runs; use rustworkx_core::traversal::dfs_edges; use pyo3::exceptions::PyValueError; @@ -603,7 +604,18 @@ pub fn collect_runs( Ok(out_list) } -/// Collect runs that match a filter function given edge colors +/// Define custom error conversion logic for collect_bicolor_runs. +fn convert_error(err: CollectBicolorError) -> PyErr { +// Note that we cannot implement From> for PyErr +// because nor PyErr nor CollectBicolorError are defined in this crate, +// so we use .map_err(convert_error) to convert a CollectBicolorError to PyErr instead. + match err { + CollectBicolorError::DAGWouldCycle => PyErr::new::("DAG would cycle"), + CollectBicolorError::CallableError(err) => err, + } +} + +/// Collect runs that match a filter function given edge colors. /// /// A bicolor run is a list of group of nodes connected by edges of exactly /// two colors. In addition, all nodes in the group must match the given @@ -633,103 +645,44 @@ pub fn collect_bicolor_runs( filter_fn: PyObject, color_fn: PyObject, ) -> PyResult>> { - let mut pending_list: Vec> = Vec::new(); - let mut block_id: Vec> = Vec::new(); - let mut block_list: Vec> = Vec::new(); - - let filter_node = |node: &PyObject| -> PyResult> { - let res = filter_fn.call1(py, (node,))?; - res.extract(py) - }; - let color_edge = |edge: &PyObject| -> PyResult> { - let res = color_fn.call1(py, (edge,))?; - res.extract(py) - }; + let dag = &graph.graph; - let nodes = match algo::toposort(&graph.graph, None) { - Ok(nodes) => nodes, - Err(_err) => return Err(DAGHasCycle::new_err("Sort encountered a cycle")), + // Wrap filter_fn to return Result, CollectBicolorError> + let filter_fn_wrapper = + |node: &PyObject| -> Result, CollectBicolorError> { + match filter_fn.call1(py, (node,)) { + Ok(res) => res.extract(py).map_err(CollectBicolorError::CallableError), + Err(err) => Err(CollectBicolorError::CallableError(err)), + } }; - // Utility for ensuring pending_list has the color index - macro_rules! ensure_vector_has_index { - ($pending_list: expr, $block_id: expr, $color: expr) => { - if $color >= $pending_list.len() { - $pending_list.resize($color + 1, Vec::new()); - $block_id.resize($color + 1, None); - } - }; - } - - for node in nodes { - if let Some(is_match) = filter_node(&graph.graph[node])? { - let raw_edges = graph - .graph - .edges_directed(node, petgraph::Direction::Outgoing); - - // Remove all edges that do not yield errors from color_fn - let colors = raw_edges - .map(|edge| { - let edge_weight = edge.weight(); - color_edge(edge_weight) - }) - .collect::>>>()?; - - // Remove null edges from color_fn - let colors = colors.into_iter().flatten().collect::>(); - - if colors.len() <= 2 && is_match { - if colors.len() == 1 { - let c0 = colors[0]; - ensure_vector_has_index!(pending_list, block_id, c0); - if let Some(c0_block_id) = block_id[c0] { - block_list[c0_block_id].push(graph.graph[node].clone_ref(py)); - } else { - pending_list[c0].push(graph.graph[node].clone_ref(py)); - } - } else if colors.len() == 2 { - let c0 = colors[0]; - let c1 = colors[1]; - ensure_vector_has_index!(pending_list, block_id, c0); - ensure_vector_has_index!(pending_list, block_id, c1); - - if block_id[c0].is_some() - && block_id[c1].is_some() - && block_id[c0] == block_id[c1] - { - block_list[block_id[c0].unwrap_or_default()] - .push(graph.graph[node].clone_ref(py)); - } else { - let mut new_block: Vec = - Vec::with_capacity(pending_list[c0].len() + pending_list[c1].len() + 1); - - // Clears pending lits and add to new block - new_block.append(&mut pending_list[c0]); - new_block.append(&mut pending_list[c1]); - - new_block.push(graph.graph[node].clone_ref(py)); - - // Create new block, assign its id to color pair - block_id[c0] = Some(block_list.len()); - block_id[c1] = Some(block_list.len()); - block_list.push(new_block); - } - } - } else { - for color in colors { - ensure_vector_has_index!(pending_list, block_id, color); - if let Some(color_block_id) = block_id[color] { - block_list[color_block_id].append(&mut pending_list[color]); - } - block_id[color] = None; - pending_list[color].clear(); - } - } + // Wrap color_fn to return Result, CollectBicolorError> + let color_fn_wrapper = + |edge: &PyObject| -> Result, CollectBicolorError> { + match color_fn.call1(py, (edge,)) { + Ok(res) => res.extract(py).map_err(CollectBicolorError::CallableError), + Err(err) => Err(CollectBicolorError::CallableError(err)), } - } + }; - Ok(block_list) + // Map CollectBicolorError to PyErr using custom convert_error function + let block_list = + core_collect_bicolor_runs::<&StablePyGraph, _, _, (), PyErr>( + dag, + filter_fn_wrapper, + color_fn_wrapper + ).map_err(convert_error)?; + + // Convert the result list from Vec> to Vec> + let py_block_list: Vec> = block_list.into_iter().map(|index_list| { + index_list.into_iter().map(|node_index| { + let node_weight = dag.node_weight(node_index).expect("Invalid NodeId"); + node_weight.into_py(py) + }).collect() + }).collect(); + + Ok(py_block_list) } /// Returns the transitive reduction of a directed acyclic graph diff --git a/tests/digraph/test_collect_bicolor_runs.py b/tests/digraph/test_collect_bicolor_runs.py index b0a903e15..f5ff15b20 100644 --- a/tests/digraph/test_collect_bicolor_runs.py +++ b/tests/digraph/test_collect_bicolor_runs.py @@ -19,7 +19,7 @@ class TestCollectBicolorRuns(unittest.TestCase): def test_cycle(self): dag = rustworkx.PyDiGraph() dag.extend_from_edge_list([(0, 1), (1, 2), (2, 0)]) - with self.assertRaises(rustworkx.DAGHasCycle): + with self.assertRaises(ValueError): rustworkx.collect_bicolor_runs(dag, lambda _: True, lambda _: None) def test_filter_function_inner_exception(self): From b77b269f06db22149de41785108d463062eef03e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elena=20Pe=C3=B1a=20Tapia?= Date: Wed, 29 May 2024 13:21:22 +0200 Subject: [PATCH 06/16] Fix DagHasCycle error to match previous type --- rustworkx-core/src/dag_algo.rs | 10 +++++----- src/dag_algo/mod.rs | 2 +- tests/digraph/test_collect_bicolor_runs.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/rustworkx-core/src/dag_algo.rs b/rustworkx-core/src/dag_algo.rs index 6cd2bc2bb..0894717b6 100644 --- a/rustworkx-core/src/dag_algo.rs +++ b/rustworkx-core/src/dag_algo.rs @@ -319,7 +319,7 @@ where /// Define custom error classes for collect_bicolor_runs #[derive(Debug)] pub enum CollectBicolorError { - DAGWouldCycle, + DAGHasCycle, CallableError(E) } @@ -328,14 +328,14 @@ impl Error for CollectBicolorError {} impl Display for CollectBicolorError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - CollectBicolorError::DAGWouldCycle => fmt_dag_would_cycle(f), + CollectBicolorError::DAGHasCycle => fmt_dag_has_cycle(f), CollectBicolorError::CallableError(ref e) => fmt_callable_error(f, e), } } } -fn fmt_dag_would_cycle(f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "The operation would introduce a cycle.") +fn fmt_dag_has_cycle(f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "Sort encountered a cycle") } fn fmt_callable_error(f: &mut Formatter<'_>, inner: &E) -> std::fmt::Result { @@ -400,7 +400,7 @@ where let nodes = match algo::toposort(&graph, None) { Ok(nodes) => nodes, - Err(_err) => return Err(CollectBicolorError::DAGWouldCycle), + Err(_err) => return Err(CollectBicolorError::DAGHasCycle), }; // Utility for ensuring pending_list has the color index diff --git a/src/dag_algo/mod.rs b/src/dag_algo/mod.rs index 65c7d8087..a0ca92ed0 100644 --- a/src/dag_algo/mod.rs +++ b/src/dag_algo/mod.rs @@ -610,7 +610,7 @@ fn convert_error(err: CollectBicolorError) -> PyErr { // because nor PyErr nor CollectBicolorError are defined in this crate, // so we use .map_err(convert_error) to convert a CollectBicolorError to PyErr instead. match err { - CollectBicolorError::DAGWouldCycle => PyErr::new::("DAG would cycle"), + CollectBicolorError::DAGHasCycle => PyErr::new::("Sort encountered a cycle"), CollectBicolorError::CallableError(err) => err, } } diff --git a/tests/digraph/test_collect_bicolor_runs.py b/tests/digraph/test_collect_bicolor_runs.py index f5ff15b20..b0a903e15 100644 --- a/tests/digraph/test_collect_bicolor_runs.py +++ b/tests/digraph/test_collect_bicolor_runs.py @@ -19,7 +19,7 @@ class TestCollectBicolorRuns(unittest.TestCase): def test_cycle(self): dag = rustworkx.PyDiGraph() dag.extend_from_edge_list([(0, 1), (1, 2), (2, 0)]) - with self.assertRaises(ValueError): + with self.assertRaises(rustworkx.DAGHasCycle): rustworkx.collect_bicolor_runs(dag, lambda _: True, lambda _: None) def test_filter_function_inner_exception(self): From c98b27977fb9b123b99cf67934a18501da67d6bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elena=20Pe=C3=B1a=20Tapia?= Date: Wed, 29 May 2024 13:30:42 +0200 Subject: [PATCH 07/16] Clean up unused parameters and functions, simplify error handling, add unit tests and reno. --- ...collect-bicolor-runs-ff95d7fdbc546e2e.yaml | 6 + rustworkx-core/src/dag_algo.rs | 295 ++++++++++++++---- rustworkx-core/src/lib.rs | 4 +- src/dag_algo/mod.rs | 71 ++--- 4 files changed, 272 insertions(+), 104 deletions(-) create mode 100644 releasenotes/notes/migrate-collect-bicolor-runs-ff95d7fdbc546e2e.yaml diff --git a/releasenotes/notes/migrate-collect-bicolor-runs-ff95d7fdbc546e2e.yaml b/releasenotes/notes/migrate-collect-bicolor-runs-ff95d7fdbc546e2e.yaml new file mode 100644 index 000000000..300d75832 --- /dev/null +++ b/releasenotes/notes/migrate-collect-bicolor-runs-ff95d7fdbc546e2e.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + Added a new function ``collect_bicolor_runs`` to rustworkx-core's ``dag_algo`` module. + Previously, the ``collect_bicolor_runs`` functionality for DAGs was only exposed + via the Python interface. Now Rust users can take advantage of this functionality in rustworkx-core. diff --git a/rustworkx-core/src/dag_algo.rs b/rustworkx-core/src/dag_algo.rs index 0894717b6..c0be386a9 100644 --- a/rustworkx-core/src/dag_algo.rs +++ b/rustworkx-core/src/dag_algo.rs @@ -12,8 +12,6 @@ use std::cmp::{Eq, Ordering}; use std::collections::BinaryHeap; -use std::error::Error; -use std::fmt::{Debug, Display, Formatter}; use std::hash::Hash; use hashbrown::HashMap; @@ -22,7 +20,7 @@ use petgraph::algo; use petgraph::data::DataMap; use petgraph::visit::{ Data, EdgeRef, GraphBase, GraphProp, IntoEdgesDirected, IntoNeighborsDirected, - IntoNodeIdentifiers, NodeIndexable, NodeCount, Visitable, + IntoNodeIdentifiers, NodeCount, Visitable, }; use petgraph::Directed; @@ -316,31 +314,6 @@ where Ok(Some((path, path_weight))) } -/// Define custom error classes for collect_bicolor_runs -#[derive(Debug)] -pub enum CollectBicolorError { - DAGHasCycle, - CallableError(E) -} - -impl Error for CollectBicolorError {} - -impl Display for CollectBicolorError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - CollectBicolorError::DAGHasCycle => fmt_dag_has_cycle(f), - CollectBicolorError::CallableError(ref e) => fmt_callable_error(f, e), - } - } -} - -fn fmt_dag_has_cycle(f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "Sort encountered a cycle") -} - -fn fmt_callable_error(f: &mut Formatter<'_>, inner: &E) -> std::fmt::Result { - write!(f, "The function failed with: {:?}", inner) -} /// Collect runs that match a filter function given edge colors. /// /// A bicolor run is a list of groups of nodes connected by edges of exactly @@ -366,41 +339,56 @@ fn fmt_callable_error(f: &mut Formatter<'_>, inner: &E) -> std::fmt::R /// /// * `Vec>`: a list of groups with exactly two edge colors, where each group /// is a list of node data payload/weight for the nodes in the bicolor run -/// * `CollectBicolorError>` if there is an error computing the bicolor runs -pub fn collect_bicolor_runs( +/// * `None` if a cycle is found in the graph +/// * Raises an error if found computing the bicolor runs +/// +/// # Example: +/// +/// ```rust +/// use rustworkx_core::dag_algo::collect_bicolor_runs; +/// use petgraph::graph::DiGraph; +/// use std::error::Error; +/// +/// let mut graph = DiGraph::new(); +/// let n0 = graph.add_node(0); +/// let n1 = graph.add_node(1); +/// let n2 = graph.add_node(2); +/// let n3 = graph.add_node(3); +/// graph.add_edge(n0, n2, 0); +/// graph.add_edge(n1, n2, 0); +/// graph.add_edge(n2, n3, 0); +/// graph.add_edge(n2, n3, 0); +/// // filter_fn and color_fn must share error type +/// fn filter_fn(node: &i32) -> Result, Box> { +/// Ok(Some(*node > 1)) +/// } +/// fn color_fn(edge: &i32) -> Result, Box> { +/// Ok(Some(*edge as usize)) +/// } +/// let result = collect_bicolor_runs(&graph, filter_fn, color_fn).unwrap(); +/// ``` +pub fn collect_bicolor_runs( graph: G, filter_fn: F, color_fn: C, -) -> Result>, CollectBicolorError> +) -> Result>>, E> where - E: Error, - F: Fn(&::NodeWeight) -> Result, CollectBicolorError>, - C: Fn(&::EdgeWeight) -> Result, CollectBicolorError>, - G: NodeIndexable - + IntoNodeIdentifiers - + IntoNeighborsDirected - + IntoEdgesDirected - + Visitable - + DataMap, + F: Fn(&::NodeWeight) -> Result, E>, + C: Fn(&::EdgeWeight) -> Result, E>, + G: IntoNodeIdentifiers // Used in toposort + + IntoNeighborsDirected // Used in toposort + + IntoEdgesDirected // Used for .edges_directed + + Visitable // Used in toposort + + DataMap, // Used for .node_weight ::NodeId: Eq + Hash, { let mut pending_list: Vec> = Vec::new(); let mut block_id: Vec> = Vec::new(); let mut block_list: Vec> = Vec::new(); - let filter_node = - |node: &::NodeWeight| -> Result, CollectBicolorError> { - filter_fn(node) - }; - - let color_edge = - |edge: &::EdgeWeight| -> Result, CollectBicolorError> { - color_fn(edge) - }; - - let nodes = match algo::toposort(&graph, None) { + let nodes = match algo::toposort(graph, None) { Ok(nodes) => nodes, - Err(_err) => return Err(CollectBicolorError::DAGHasCycle), + Err(_) => return Ok(None), // Return None if the graph contains a cycle }; // Utility for ensuring pending_list has the color index @@ -414,14 +402,14 @@ where } for node in nodes { - if let Some(is_match) = filter_node(&graph.node_weight(node).expect("Invalid NodeId"))? { + if let Some(is_match) = filter_fn(graph.node_weight(node).expect("Invalid NodeId"))? { let raw_edges = graph.edges_directed(node, petgraph::Direction::Outgoing); - // Remove all edges that do not yield errors from color_fn + // Remove all edges that yield errors from color_fn let colors = raw_edges .map(|edge| { let edge_weight = edge.weight(); - color_edge(edge_weight) + color_fn(edge_weight) }) .collect::>, _>>()?; @@ -477,9 +465,10 @@ where } } - Ok(block_list) + Ok(Some(block_list)) } +/// Tests for longest_path #[cfg(test)] mod test_longest_path { use super::*; @@ -588,6 +577,7 @@ mod test_longest_path { } } +/// Tests for lexicographical_topological_sort // pub fn lexicographical_topological_sort( // dag: G, // mut key: F, @@ -765,4 +755,195 @@ mod test_lexicographical_topological_sort { let result = lexicographical_topological_sort(&graph, sort_fn, false, Some(&initial)); assert_eq!(result, Ok(Some(vec![nodes[7]]))); } -} \ No newline at end of file +} + +/// Tests for collect_bicolor_runs +#[cfg(test)] +mod test_collect_bicolor_runs { + + use super::*; + use petgraph::graph::{DiGraph, NodeIndex}; + use std::error::Error; + + fn test_filter_fn(node: &i32) -> Result, Box> { + Ok(Some(*node > 1)) + } + + fn test_color_fn(edge: &i32) -> Result, Box> { + Ok(Some(*edge as usize)) + } + + #[test] + fn test_cycle() { + let mut graph = DiGraph::new(); + let n0 = graph.add_node(2); + let n1 = graph.add_node(2); + let n2 = graph.add_node(2); + graph.add_edge(n0, n1, 1); + graph.add_edge(n1, n2, 1); + graph.add_edge(n2, n0, 1); + + let result = match collect_bicolor_runs(&graph, test_filter_fn, test_color_fn) { + Ok(Some(_value)) => "Not None", + Ok(None) => "None", + Err(_) => "Error", + }; + assert_eq!(result, "None") + } + + #[test] + fn test_filter_function_inner_exception() { + let mut graph = DiGraph::new(); + graph.add_node(0); + + fn fail_function(node: &i32) -> Result, Box> { + if *node > 0 { + Ok(Some(true)) + } else { + Err(Box::from("Failed!")) + } + } + + let result = match collect_bicolor_runs(&graph, fail_function, test_color_fn) { + Ok(Some(_value)) => "Not None", + Ok(None) => "None", + Err(_) => "Error", + }; + assert_eq!(result, "Error") + } + + #[test] + fn test_empty() { + let graph = DiGraph::new(); + let result = collect_bicolor_runs(&graph, test_filter_fn, test_color_fn).unwrap(); + let expected: Vec> = vec![]; + assert_eq!(result, Some(expected)) + } + + #[test] + fn test_two_colors() { + // Based on the following graph from the Python unit tests: + // Input: + // ┌─────────────┐ ┌─────────────┐ + // │ │ │ │ + // │ q0 │ │ q1 │ + // │ │ │ │ + // └───┬─────────┘ └──────┬──────┘ + // │ ┌─────────────┐ │ + // q0 │ │ │ │ q1 + // │ │ │ │ + // └─────────►│ cx │◄────────┘ + // ┌──────────┤ ├─────────┐ + // │ │ │ │ + // q0 │ └─────────────┘ │ q1 + // │ │ + // │ ┌─────────────┐ │ + // │ │ │ │ + // └─────────►│ cz │◄────────┘ + // ┌─────────┤ ├─────────┐ + // │ └─────────────┘ │ + // q0 │ │ q1 + // │ │ + // ┌───▼─────────┐ ┌──────▼──────┐ + // │ │ │ │ + // │ q0 │ │ q1 │ + // │ │ │ │ + // └─────────────┘ └─────────────┘ + // + // Expected: [[cx, cz]] + let mut graph = DiGraph::new(); + // The node weight will correspond to the type of node + // All edges have the same weight in this example + let n0 = graph.add_node(0); //q0 + let n1 = graph.add_node(1); //q1 + let n2 = graph.add_node(2); //cx + let n3 = graph.add_node(3); //cz + let n4 = graph.add_node(0); //q0_1 + let n5 = graph.add_node(1); //q1_1 + graph.add_edge(n0, n2, 0); //q0 -> cx + graph.add_edge(n1, n2, 0); //q1 -> cx + graph.add_edge(n2, n3, 0); //cx -> cz + graph.add_edge(n2, n3, 0); //cx -> cz + graph.add_edge(n3, n4, 0); //cz -> q0_1 + graph.add_edge(n3, n5, 0); //cz -> q1_1 + + let result = collect_bicolor_runs(&graph, test_filter_fn, test_color_fn).unwrap(); + let expected: Vec> = vec![vec![n2, n3]]; //[[cx, cz]] + assert_eq!(result, Some(expected)) + } + #[test] + fn test_two_colors_with_pending() { + // Based on the following graph from the Python unit tests: + // Input: + // ┌─────────────┐ + // │ │ + // │ q0 │ + // │ │ + // └───┬─────────┘ + // | q0 + // │ + // ┌───▼─────────┐ + // │ │ + // │ h │ + // │ │ + // └───┬─────────┘ + // | q0 + // │ ┌─────────────┐ + // │ │ │ + // │ │ q1 │ + // │ │ │ + // | └──────┬──────┘ + // │ ┌─────────────┐ │ + // q0 │ │ │ │ q1 + // │ │ │ │ + // └─────────►│ cx │◄────────┘ + // ┌──────────┤ ├─────────┐ + // │ │ │ │ + // q0 │ └─────────────┘ │ q1 + // │ │ + // │ ┌─────────────┐ │ + // │ │ │ │ + // └─────────►│ cz │◄────────┘ + // ┌─────────┤ ├─────────┐ + // │ └─────────────┘ │ + // q0 │ │ q1 + // │ │ + // ┌───▼─────────┐ ┌──────▼──────┐ + // │ │ │ │ + // │ q0 │ │ y │ + // │ │ │ │ + // └─────────────┘ └─────────────┘ + // | q1 + // │ + // ┌───▼─────────┐ + // │ │ + // │ q1 │ + // │ │ + // └─────────────┘ + // + // Expected: [[h, cx, cz, y]] + let mut graph = DiGraph::new(); + // The node weight will correspond to the type of node + // All edges have the same weight in this example + let n0 = graph.add_node(0); //q0 + let n1 = graph.add_node(1); //q1 + let n2 = graph.add_node(2); //h + let n3 = graph.add_node(3); //cx + let n4 = graph.add_node(4); //cz + let n5 = graph.add_node(5); //y + let n6 = graph.add_node(0); //q0_1 + let n7 = graph.add_node(1); //q1_1 + graph.add_edge(n0, n2, 0); //q0 -> h + graph.add_edge(n2, n3, 0); //h -> cx + graph.add_edge(n1, n3, 0); //q1 -> cx + graph.add_edge(n3, n4, 0); //cx -> cz + graph.add_edge(n3, n4, 0); //cx -> cz + graph.add_edge(n4, n6, 0); //cz -> q0_1 + graph.add_edge(n4, n5, 0); //cz -> y + graph.add_edge(n5, n7, 0); //y -> q1_1 + + let result = collect_bicolor_runs(&graph, test_filter_fn, test_color_fn).unwrap(); + let expected: Vec> = vec![vec![n2, n3, n4, n5]]; //[[h, cx, cz, y]] + assert_eq!(result, Some(expected)) + } +} diff --git a/rustworkx-core/src/lib.rs b/rustworkx-core/src/lib.rs index 373f9a32d..fc5d6f5df 100644 --- a/rustworkx-core/src/lib.rs +++ b/rustworkx-core/src/lib.rs @@ -98,11 +98,11 @@ pub mod centrality; /// Module for coloring algorithms. pub mod coloring; pub mod connectivity; +/// Module for algorithms that work on DAGs. +pub mod dag_algo; pub mod generators; pub mod graph_ext; pub mod line_graph; -/// Module for algorithms that work on DAGs. -pub mod dag_algo; /// Module for maximum weight matching algorithms. pub mod max_weight_matching; pub mod planar; diff --git a/src/dag_algo/mod.rs b/src/dag_algo/mod.rs index a0ca92ed0..9ac62f00d 100644 --- a/src/dag_algo/mod.rs +++ b/src/dag_algo/mod.rs @@ -18,9 +18,9 @@ use rustworkx_core::dictmap::InitWithHasher; use super::iterators::NodeIndices; use crate::{digraph, DAGHasCycle, InvalidNode, StablePyGraph}; -use rustworkx_core::dag_algo::{CollectBicolorError, lexicographical_topological_sort as core_lexico_topo_sort}; -use rustworkx_core::dag_algo::longest_path as core_longest_path; use rustworkx_core::dag_algo::collect_bicolor_runs as core_collect_bicolor_runs; +use rustworkx_core::dag_algo::lexicographical_topological_sort as core_lexico_topo_sort; +use rustworkx_core::dag_algo::longest_path as core_longest_path; use rustworkx_core::traversal::dfs_edges; use pyo3::exceptions::PyValueError; @@ -604,17 +604,6 @@ pub fn collect_runs( Ok(out_list) } -/// Define custom error conversion logic for collect_bicolor_runs. -fn convert_error(err: CollectBicolorError) -> PyErr { -// Note that we cannot implement From> for PyErr -// because nor PyErr nor CollectBicolorError are defined in this crate, -// so we use .map_err(convert_error) to convert a CollectBicolorError to PyErr instead. - match err { - CollectBicolorError::DAGHasCycle => PyErr::new::("Sort encountered a cycle"), - CollectBicolorError::CallableError(err) => err, - } -} - /// Collect runs that match a filter function given edge colors. /// /// A bicolor run is a list of group of nodes connected by edges of exactly @@ -645,44 +634,36 @@ pub fn collect_bicolor_runs( filter_fn: PyObject, color_fn: PyObject, ) -> PyResult>> { - let dag = &graph.graph; - // Wrap filter_fn to return Result, CollectBicolorError> - let filter_fn_wrapper = - |node: &PyObject| -> Result, CollectBicolorError> { - match filter_fn.call1(py, (node,)) { - Ok(res) => res.extract(py).map_err(CollectBicolorError::CallableError), - Err(err) => Err(CollectBicolorError::CallableError(err)), - } + let filter_fn_wrapper = |node: &PyObject| -> Result, PyErr> { + let res = filter_fn.call1(py, (node,))?; + res.extract(py) }; - // Wrap color_fn to return Result, CollectBicolorError> - let color_fn_wrapper = - |edge: &PyObject| -> Result, CollectBicolorError> { - match color_fn.call1(py, (edge,)) { - Ok(res) => res.extract(py).map_err(CollectBicolorError::CallableError), - Err(err) => Err(CollectBicolorError::CallableError(err)), - } + let color_fn_wrapper = |edge: &PyObject| -> Result, PyErr> { + let res = color_fn.call1(py, (edge,))?; + res.extract(py) + }; + + let block_list = match core_collect_bicolor_runs(dag, filter_fn_wrapper, color_fn_wrapper) { + Ok(Some(block_list)) => block_list + .into_iter() + .map(|index_list| { + index_list + .into_iter() + .map(|node_index| { + let node_weight = dag.node_weight(node_index).expect("Invalid NodeId"); + node_weight.into_py(py) + }) + .collect() + }) + .collect(), + Ok(None) => return Err(DAGHasCycle::new_err("The graph contains a cycle")), + Err(e) => return Err(e), }; - // Map CollectBicolorError to PyErr using custom convert_error function - let block_list = - core_collect_bicolor_runs::<&StablePyGraph, _, _, (), PyErr>( - dag, - filter_fn_wrapper, - color_fn_wrapper - ).map_err(convert_error)?; - - // Convert the result list from Vec> to Vec> - let py_block_list: Vec> = block_list.into_iter().map(|index_list| { - index_list.into_iter().map(|node_index| { - let node_weight = dag.node_weight(node_index).expect("Invalid NodeId"); - node_weight.into_py(py) - }).collect() - }).collect(); - - Ok(py_block_list) + Ok(block_list) } /// Returns the transitive reduction of a directed acyclic graph From 842e5bffce88502fec9a9e52355cb62789b6aa12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elena=20Pe=C3=B1a=20Tapia?= Date: Tue, 4 Jun 2024 13:27:51 +0200 Subject: [PATCH 08/16] Use PyResult, fix comments --- rustworkx-core/src/dag_algo.rs | 162 ++++++++++++++++----------------- src/dag_algo/mod.rs | 4 +- 2 files changed, 83 insertions(+), 83 deletions(-) diff --git a/rustworkx-core/src/dag_algo.rs b/rustworkx-core/src/dag_algo.rs index c0be386a9..3aec4fd68 100644 --- a/rustworkx-core/src/dag_algo.rs +++ b/rustworkx-core/src/dag_algo.rs @@ -468,7 +468,7 @@ where Ok(Some(block_list)) } -/// Tests for longest_path +// Tests for longest_path #[cfg(test)] mod test_longest_path { use super::*; @@ -577,7 +577,7 @@ mod test_longest_path { } } -/// Tests for lexicographical_topological_sort +// Tests for lexicographical_topological_sort // pub fn lexicographical_topological_sort( // dag: G, // mut key: F, @@ -757,7 +757,7 @@ mod test_lexicographical_topological_sort { } } -/// Tests for collect_bicolor_runs +// Tests for collect_bicolor_runs #[cfg(test)] mod test_collect_bicolor_runs { @@ -822,35 +822,35 @@ mod test_collect_bicolor_runs { #[test] fn test_two_colors() { - // Based on the following graph from the Python unit tests: - // Input: - // ┌─────────────┐ ┌─────────────┐ - // │ │ │ │ - // │ q0 │ │ q1 │ - // │ │ │ │ - // └───┬─────────┘ └──────┬──────┘ - // │ ┌─────────────┐ │ - // q0 │ │ │ │ q1 - // │ │ │ │ - // └─────────►│ cx │◄────────┘ - // ┌──────────┤ ├─────────┐ - // │ │ │ │ - // q0 │ └─────────────┘ │ q1 - // │ │ - // │ ┌─────────────┐ │ - // │ │ │ │ - // └─────────►│ cz │◄────────┘ - // ┌─────────┤ ├─────────┐ - // │ └─────────────┘ │ - // q0 │ │ q1 - // │ │ - // ┌───▼─────────┐ ┌──────▼──────┐ - // │ │ │ │ - // │ q0 │ │ q1 │ - // │ │ │ │ - // └─────────────┘ └─────────────┘ - // - // Expected: [[cx, cz]] + /* Based on the following graph from the Python unit tests: + Input: + ┌─────────────┐ ┌─────────────┐ + │ │ │ │ + │ q0 │ │ q1 │ + │ │ │ │ + └───┬─────────┘ └──────┬──────┘ + │ ┌─────────────┐ │ + q0 │ │ │ │ q1 + │ │ │ │ + └─────────►│ cx │◄────────┘ + ┌──────────┤ ├─────────┐ + │ │ │ │ + q0 │ └─────────────┘ │ q1 + │ │ + │ ┌─────────────┐ │ + │ │ │ │ + └─────────►│ cz │◄────────┘ + ┌─────────┤ ├─────────┐ + │ └─────────────┘ │ + q0 │ │ q1 + │ │ + ┌───▼─────────┐ ┌──────▼──────┐ + │ │ │ │ + │ q0 │ │ q1 │ + │ │ │ │ + └─────────────┘ └─────────────┘ + Expected: [[cx, cz]] + */ let mut graph = DiGraph::new(); // The node weight will correspond to the type of node // All edges have the same weight in this example @@ -873,55 +873,55 @@ mod test_collect_bicolor_runs { } #[test] fn test_two_colors_with_pending() { - // Based on the following graph from the Python unit tests: - // Input: - // ┌─────────────┐ - // │ │ - // │ q0 │ - // │ │ - // └───┬─────────┘ - // | q0 - // │ - // ┌───▼─────────┐ - // │ │ - // │ h │ - // │ │ - // └───┬─────────┘ - // | q0 - // │ ┌─────────────┐ - // │ │ │ - // │ │ q1 │ - // │ │ │ - // | └──────┬──────┘ - // │ ┌─────────────┐ │ - // q0 │ │ │ │ q1 - // │ │ │ │ - // └─────────►│ cx │◄────────┘ - // ┌──────────┤ ├─────────┐ - // │ │ │ │ - // q0 │ └─────────────┘ │ q1 - // │ │ - // │ ┌─────────────┐ │ - // │ │ │ │ - // └─────────►│ cz │◄────────┘ - // ┌─────────┤ ├─────────┐ - // │ └─────────────┘ │ - // q0 │ │ q1 - // │ │ - // ┌───▼─────────┐ ┌──────▼──────┐ - // │ │ │ │ - // │ q0 │ │ y │ - // │ │ │ │ - // └─────────────┘ └─────────────┘ - // | q1 - // │ - // ┌───▼─────────┐ - // │ │ - // │ q1 │ - // │ │ - // └─────────────┘ - // - // Expected: [[h, cx, cz, y]] + /* Based on the following graph from the Python unit tests: + Input: + ┌─────────────┐ + │ │ + │ q0 │ + │ │ + └───┬─────────┘ + | q0 + │ + ┌───▼─────────┐ + │ │ + │ h │ + │ │ + └───┬─────────┘ + | q0 + │ ┌─────────────┐ + │ │ │ + │ │ q1 │ + │ │ │ + | └──────┬──────┘ + │ ┌─────────────┐ │ + q0 │ │ │ │ q1 + │ │ │ │ + └─────────►│ cx │◄────────┘ + ┌──────────┤ ├─────────┐ + │ │ │ │ + q0 │ └─────────────┘ │ q1 + │ │ + │ ┌─────────────┐ │ + │ │ │ │ + └─────────►│ cz │◄────────┘ + ┌─────────┤ ├─────────┐ + │ └─────────────┘ │ + q0 │ │ q1 + │ │ + ┌───▼─────────┐ ┌──────▼──────┐ + │ │ │ │ + │ q0 │ │ y │ + │ │ │ │ + └─────────────┘ └─────────────┘ + | q1 + │ + ┌───▼─────────┐ + │ │ + │ q1 │ + │ │ + └─────────────┘ + Expected: [[h, cx, cz, y]] + */ let mut graph = DiGraph::new(); // The node weight will correspond to the type of node // All edges have the same weight in this example diff --git a/src/dag_algo/mod.rs b/src/dag_algo/mod.rs index 9ac62f00d..6343958d6 100644 --- a/src/dag_algo/mod.rs +++ b/src/dag_algo/mod.rs @@ -636,12 +636,12 @@ pub fn collect_bicolor_runs( ) -> PyResult>> { let dag = &graph.graph; - let filter_fn_wrapper = |node: &PyObject| -> Result, PyErr> { + let filter_fn_wrapper = |node: &PyObject| -> PyResult> { let res = filter_fn.call1(py, (node,))?; res.extract(py) }; - let color_fn_wrapper = |edge: &PyObject| -> Result, PyErr> { + let color_fn_wrapper = |edge: &PyObject| -> PyResult> { let res = color_fn.call1(py, (edge,))?; res.extract(py) }; From cd5b752732eb6d53ab6487c38e2a15382c652184 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elena=20Pe=C3=B1a=20Tapia?= Date: Thu, 6 Jun 2024 12:01:40 +0200 Subject: [PATCH 09/16] Apply feedback from Matt's review, make tests make more sense. --- rustworkx-core/src/dag_algo.rs | 48 +++++++++++----------- tests/digraph/test_collect_bicolor_runs.py | 1 + 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/rustworkx-core/src/dag_algo.rs b/rustworkx-core/src/dag_algo.rs index 3aec4fd68..6712e3466 100644 --- a/rustworkx-core/src/dag_algo.rs +++ b/rustworkx-core/src/dag_algo.rs @@ -346,26 +346,32 @@ where /// /// ```rust /// use rustworkx_core::dag_algo::collect_bicolor_runs; -/// use petgraph::graph::DiGraph; +/// use petgraph::graph::{DiGraph, NodeIndex}; +/// use std::convert::Infallible; /// use std::error::Error; -/// /// let mut graph = DiGraph::new(); /// let n0 = graph.add_node(0); -/// let n1 = graph.add_node(1); -/// let n2 = graph.add_node(2); -/// let n3 = graph.add_node(3); +/// let n1 = graph.add_node(0); +/// let n2 = graph.add_node(1); +/// let n3 = graph.add_node(1); +/// let n4 = graph.add_node(0); +/// let n5 = graph.add_node(0); /// graph.add_edge(n0, n2, 0); -/// graph.add_edge(n1, n2, 0); -/// graph.add_edge(n2, n3, 0); +/// graph.add_edge(n1, n2, 1); /// graph.add_edge(n2, n3, 0); +/// graph.add_edge(n2, n3, 1); +/// graph.add_edge(n3, n4, 0); +/// graph.add_edge(n3, n5, 1); /// // filter_fn and color_fn must share error type -/// fn filter_fn(node: &i32) -> Result, Box> { -/// Ok(Some(*node > 1)) +/// fn filter_fn(node: &i32) -> Result, Infallible> { +/// Ok(Some(*node > 0)) /// } -/// fn color_fn(edge: &i32) -> Result, Box> { +/// fn color_fn(edge: &i32) -> Result, Infallible> { /// Ok(Some(*edge as usize)) /// } /// let result = collect_bicolor_runs(&graph, filter_fn, color_fn).unwrap(); +/// let expected: Vec> = vec![vec![n2, n3]]; +/// assert_eq!(result, Some(expected)) /// ``` pub fn collect_bicolor_runs( graph: G, @@ -578,13 +584,6 @@ mod test_longest_path { } // Tests for lexicographical_topological_sort -// pub fn lexicographical_topological_sort( -// dag: G, -// mut key: F, -// reverse: bool, -// initial: Option<&[G::NodeId]>, -// ) -> Result>, E> - #[cfg(test)] mod test_lexicographical_topological_sort { use super::*; @@ -861,16 +860,17 @@ mod test_collect_bicolor_runs { let n4 = graph.add_node(0); //q0_1 let n5 = graph.add_node(1); //q1_1 graph.add_edge(n0, n2, 0); //q0 -> cx - graph.add_edge(n1, n2, 0); //q1 -> cx - graph.add_edge(n2, n3, 0); //cx -> cz + graph.add_edge(n1, n2, 1); //q1 -> cx graph.add_edge(n2, n3, 0); //cx -> cz + graph.add_edge(n2, n3, 1); //cx -> cz graph.add_edge(n3, n4, 0); //cz -> q0_1 - graph.add_edge(n3, n5, 0); //cz -> q1_1 + graph.add_edge(n3, n5, 1); //cz -> q1_1 let result = collect_bicolor_runs(&graph, test_filter_fn, test_color_fn).unwrap(); let expected: Vec> = vec![vec![n2, n3]]; //[[cx, cz]] assert_eq!(result, Some(expected)) } + #[test] fn test_two_colors_with_pending() { /* Based on the following graph from the Python unit tests: @@ -935,12 +935,12 @@ mod test_collect_bicolor_runs { let n7 = graph.add_node(1); //q1_1 graph.add_edge(n0, n2, 0); //q0 -> h graph.add_edge(n2, n3, 0); //h -> cx - graph.add_edge(n1, n3, 0); //q1 -> cx - graph.add_edge(n3, n4, 0); //cx -> cz + graph.add_edge(n1, n3, 1); //q1 -> cx graph.add_edge(n3, n4, 0); //cx -> cz + graph.add_edge(n3, n4, 1); //cx -> cz graph.add_edge(n4, n6, 0); //cz -> q0_1 - graph.add_edge(n4, n5, 0); //cz -> y - graph.add_edge(n5, n7, 0); //y -> q1_1 + graph.add_edge(n4, n5, 1); //cz -> y + graph.add_edge(n5, n7, 1); //y -> q1_1 let result = collect_bicolor_runs(&graph, test_filter_fn, test_color_fn).unwrap(); let expected: Vec> = vec![vec![n2, n3, n4, n5]]; //[[h, cx, cz, y]] diff --git a/tests/digraph/test_collect_bicolor_runs.py b/tests/digraph/test_collect_bicolor_runs.py index b0a903e15..b1c32297c 100644 --- a/tests/digraph/test_collect_bicolor_runs.py +++ b/tests/digraph/test_collect_bicolor_runs.py @@ -99,6 +99,7 @@ def filter_function(node): return None def color_function(node): + print("node name:", node) if "q" in node: return int(node[1:]) else: From f60cf6b25169d79dc248ccd59fe114154a7b687a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elena=20Pe=C3=B1a=20Tapia?= <57907331+ElePT@users.noreply.github.com> Date: Thu, 6 Jun 2024 13:09:10 +0200 Subject: [PATCH 10/16] Add suggestion from Matt's code review Co-authored-by: Matthew Treinish --- rustworkx-core/src/dag_algo.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/rustworkx-core/src/dag_algo.rs b/rustworkx-core/src/dag_algo.rs index 6712e3466..394f4e1ee 100644 --- a/rustworkx-core/src/dag_algo.rs +++ b/rustworkx-core/src/dag_algo.rs @@ -379,13 +379,12 @@ pub fn collect_bicolor_runs( color_fn: C, ) -> Result>>, E> where - F: Fn(&::NodeWeight) -> Result, E>, - C: Fn(&::EdgeWeight) -> Result, E>, + F: Fn(G::NodeId) -> Result, E>, + C: Fn(G::EdgeId) -> Result, E>, G: IntoNodeIdentifiers // Used in toposort + IntoNeighborsDirected // Used in toposort + IntoEdgesDirected // Used for .edges_directed + Visitable // Used in toposort - + DataMap, // Used for .node_weight ::NodeId: Eq + Hash, { let mut pending_list: Vec> = Vec::new(); From 23ba36b4ac7ebae07af04595804a11b345b7dc04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elena=20Pe=C3=B1a=20Tapia?= <57907331+ElePT@users.noreply.github.com> Date: Thu, 6 Jun 2024 13:09:38 +0200 Subject: [PATCH 11/16] Remove stray print Co-authored-by: Matthew Treinish --- tests/digraph/test_collect_bicolor_runs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/digraph/test_collect_bicolor_runs.py b/tests/digraph/test_collect_bicolor_runs.py index b1c32297c..b0a903e15 100644 --- a/tests/digraph/test_collect_bicolor_runs.py +++ b/tests/digraph/test_collect_bicolor_runs.py @@ -99,7 +99,6 @@ def filter_function(node): return None def color_function(node): - print("node name:", node) if "q" in node: return int(node[1:]) else: From c256e27721c9b836899a64cfd93b2e71993765ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elena=20Pe=C3=B1a=20Tapia?= Date: Thu, 6 Jun 2024 14:28:34 +0200 Subject: [PATCH 12/16] Change callbacks to take ids instead of weights, rename some variables in tests for clarity. --- ...collect-bicolor-runs-ff95d7fdbc546e2e.yaml | 2 +- rustworkx-core/src/dag_algo.rs | 105 +++++++++++------- src/dag_algo/mod.rs | 10 +- tests/digraph/test_collect_bicolor_runs.py | 24 ++-- 4 files changed, 83 insertions(+), 58 deletions(-) diff --git a/releasenotes/notes/migrate-collect-bicolor-runs-ff95d7fdbc546e2e.yaml b/releasenotes/notes/migrate-collect-bicolor-runs-ff95d7fdbc546e2e.yaml index 300d75832..51705416b 100644 --- a/releasenotes/notes/migrate-collect-bicolor-runs-ff95d7fdbc546e2e.yaml +++ b/releasenotes/notes/migrate-collect-bicolor-runs-ff95d7fdbc546e2e.yaml @@ -3,4 +3,4 @@ features: - | Added a new function ``collect_bicolor_runs`` to rustworkx-core's ``dag_algo`` module. Previously, the ``collect_bicolor_runs`` functionality for DAGs was only exposed - via the Python interface. Now Rust users can take advantage of this functionality in rustworkx-core. + via the Python interface. Now Rust users can take advantage of this functionality in ``rustworkx-core``. diff --git a/rustworkx-core/src/dag_algo.rs b/rustworkx-core/src/dag_algo.rs index 394f4e1ee..81105d186 100644 --- a/rustworkx-core/src/dag_algo.rs +++ b/rustworkx-core/src/dag_algo.rs @@ -19,7 +19,7 @@ use hashbrown::HashMap; use petgraph::algo; use petgraph::data::DataMap; use petgraph::visit::{ - Data, EdgeRef, GraphBase, GraphProp, IntoEdgesDirected, IntoNeighborsDirected, + EdgeRef, GraphBase, GraphProp, IntoEdgeReferences, IntoEdgesDirected, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCount, Visitable, }; use petgraph::Directed; @@ -362,13 +362,12 @@ where /// graph.add_edge(n2, n3, 1); /// graph.add_edge(n3, n4, 0); /// graph.add_edge(n3, n5, 1); -/// // filter_fn and color_fn must share error type -/// fn filter_fn(node: &i32) -> Result, Infallible> { -/// Ok(Some(*node > 0)) -/// } -/// fn color_fn(edge: &i32) -> Result, Infallible> { -/// Ok(Some(*edge as usize)) -/// } +/// let filter_fn = |node_id| -> Result, Infallible> { +/// Ok(Some(*graph.node_weight(node_id).unwrap() > 0)) +/// }; +/// let color_fn = |edge_id| -> Result, Infallible> { +/// Ok(Some(*graph.edge_weight(edge_id).unwrap() as usize)) +/// }; /// let result = collect_bicolor_runs(&graph, filter_fn, color_fn).unwrap(); /// let expected: Vec> = vec![vec![n2, n3]]; /// assert_eq!(result, Some(expected)) @@ -379,13 +378,16 @@ pub fn collect_bicolor_runs( color_fn: C, ) -> Result>>, E> where - F: Fn(G::NodeId) -> Result, E>, - C: Fn(G::EdgeId) -> Result, E>, + F: Fn(::NodeId) -> Result, E>, + C: Fn(::EdgeId) -> Result, E>, G: IntoNodeIdentifiers // Used in toposort + IntoNeighborsDirected // Used in toposort + IntoEdgesDirected // Used for .edges_directed + + IntoEdgeReferences + Visitable // Used in toposort + + DataMap, // Used for .node_weight ::NodeId: Eq + Hash, + ::EdgeId: Eq + Hash, { let mut pending_list: Vec> = Vec::new(); let mut block_id: Vec> = Vec::new(); @@ -407,15 +409,12 @@ where } for node in nodes { - if let Some(is_match) = filter_fn(graph.node_weight(node).expect("Invalid NodeId"))? { + if let Some(is_match) = filter_fn(node)? { let raw_edges = graph.edges_directed(node, petgraph::Direction::Outgoing); // Remove all edges that yield errors from color_fn let colors = raw_edges - .map(|edge| { - let edge_weight = edge.weight(); - color_fn(edge_weight) - }) + .map(|edge| color_fn(edge.id())) .collect::>, _>>()?; // Remove null edges from color_fn @@ -760,27 +759,23 @@ mod test_lexicographical_topological_sort { mod test_collect_bicolor_runs { use super::*; - use petgraph::graph::{DiGraph, NodeIndex}; + use petgraph::graph::{DiGraph, EdgeIndex, NodeIndex}; use std::error::Error; - fn test_filter_fn(node: &i32) -> Result, Box> { - Ok(Some(*node > 1)) - } - - fn test_color_fn(edge: &i32) -> Result, Box> { - Ok(Some(*edge as usize)) - } - #[test] fn test_cycle() { let mut graph = DiGraph::new(); - let n0 = graph.add_node(2); - let n1 = graph.add_node(2); - let n2 = graph.add_node(2); + let n0 = graph.add_node(0); + let n1 = graph.add_node(0); + let n2 = graph.add_node(0); graph.add_edge(n0, n1, 1); graph.add_edge(n1, n2, 1); graph.add_edge(n2, n0, 1); + let test_filter_fn = + |_node_id: NodeIndex| -> Result, Box> { Ok(Some(true)) }; + let test_color_fn = + |_edge_id: EdgeIndex| -> Result, Box> { Ok(Some(1)) }; let result = match collect_bicolor_runs(&graph, test_filter_fn, test_color_fn) { Ok(Some(_value)) => "Not None", Ok(None) => "None", @@ -794,14 +789,18 @@ mod test_collect_bicolor_runs { let mut graph = DiGraph::new(); graph.add_node(0); - fn fail_function(node: &i32) -> Result, Box> { - if *node > 0 { + let fail_function = |node_id: NodeIndex| -> Result, Box> { + let node_weight: &i32 = graph.node_weight(node_id).expect("Invalid NodeId"); + if *node_weight > 0 { Ok(Some(true)) } else { Err(Box::from("Failed!")) } - } - + }; + let test_color_fn = |edge_id: EdgeIndex| -> Result, Box> { + let edge_weight: &i32 = graph.edge_weight(edge_id).expect("Invalid Edge"); + Ok(Some(*edge_weight as usize)) + }; let result = match collect_bicolor_runs(&graph, fail_function, test_color_fn) { Ok(Some(_value)) => "Not None", Ok(None) => "None", @@ -813,6 +812,14 @@ mod test_collect_bicolor_runs { #[test] fn test_empty() { let graph = DiGraph::new(); + let test_filter_fn = |node_id: NodeIndex| -> Result, Box> { + let node_weight: &i32 = graph.node_weight(node_id).expect("Invalid NodeId"); + Ok(Some(*node_weight > 1)) + }; + let test_color_fn = |edge_id: EdgeIndex| -> Result, Box> { + let edge_weight: &i32 = graph.edge_weight(edge_id).expect("Invalid Edge"); + Ok(Some(*edge_weight as usize)) + }; let result = collect_bicolor_runs(&graph, test_filter_fn, test_color_fn).unwrap(); let expected: Vec> = vec![]; assert_eq!(result, Some(expected)) @@ -850,8 +857,6 @@ mod test_collect_bicolor_runs { Expected: [[cx, cz]] */ let mut graph = DiGraph::new(); - // The node weight will correspond to the type of node - // All edges have the same weight in this example let n0 = graph.add_node(0); //q0 let n1 = graph.add_node(1); //q1 let n2 = graph.add_node(2); //cx @@ -865,6 +870,16 @@ mod test_collect_bicolor_runs { graph.add_edge(n3, n4, 0); //cz -> q0_1 graph.add_edge(n3, n5, 1); //cz -> q1_1 + // Filter out q0, q1, q0_1 and q1_1 + let test_filter_fn = |node_id: NodeIndex| -> Result, Box> { + let node_weight: &i32 = graph.node_weight(node_id).expect("Invalid NodeId"); + Ok(Some(*node_weight > 0)) + }; + // The edge color will match its weight + let test_color_fn = |edge_id: EdgeIndex| -> Result, Box> { + let edge_weight: &i32 = graph.edge_weight(edge_id).expect("Invalid Edge"); + Ok(Some(*edge_weight as usize)) + }; let result = collect_bicolor_runs(&graph, test_filter_fn, test_color_fn).unwrap(); let expected: Vec> = vec![vec![n2, n3]]; //[[cx, cz]] assert_eq!(result, Some(expected)) @@ -922,16 +937,14 @@ mod test_collect_bicolor_runs { Expected: [[h, cx, cz, y]] */ let mut graph = DiGraph::new(); - // The node weight will correspond to the type of node - // All edges have the same weight in this example let n0 = graph.add_node(0); //q0 - let n1 = graph.add_node(1); //q1 - let n2 = graph.add_node(2); //h - let n3 = graph.add_node(3); //cx - let n4 = graph.add_node(4); //cz - let n5 = graph.add_node(5); //y + let n1 = graph.add_node(0); //q1 + let n2 = graph.add_node(1); //h + let n3 = graph.add_node(1); //cx + let n4 = graph.add_node(1); //cz + let n5 = graph.add_node(1); //y let n6 = graph.add_node(0); //q0_1 - let n7 = graph.add_node(1); //q1_1 + let n7 = graph.add_node(0); //q1_1 graph.add_edge(n0, n2, 0); //q0 -> h graph.add_edge(n2, n3, 0); //h -> cx graph.add_edge(n1, n3, 1); //q1 -> cx @@ -941,6 +954,16 @@ mod test_collect_bicolor_runs { graph.add_edge(n4, n5, 1); //cz -> y graph.add_edge(n5, n7, 1); //y -> q1_1 + // Filter out q0, q1, q0_1 and q1_1 + let test_filter_fn = |node_id: NodeIndex| -> Result, Box> { + let node_weight: &i32 = graph.node_weight(node_id).expect("Invalid NodeId"); + Ok(Some(*node_weight > 0)) + }; + // The edge color will match its weight + let test_color_fn = |edge_id: EdgeIndex| -> Result, Box> { + let edge_weight: &i32 = graph.edge_weight(edge_id).expect("Invalid Edge"); + Ok(Some(*edge_weight as usize)) + }; let result = collect_bicolor_runs(&graph, test_filter_fn, test_color_fn).unwrap(); let expected: Vec> = vec![vec![n2, n3, n4, n5]]; //[[h, cx, cz, y]] assert_eq!(result, Some(expected)) diff --git a/src/dag_algo/mod.rs b/src/dag_algo/mod.rs index 6343958d6..99ccc86c9 100644 --- a/src/dag_algo/mod.rs +++ b/src/dag_algo/mod.rs @@ -636,13 +636,15 @@ pub fn collect_bicolor_runs( ) -> PyResult>> { let dag = &graph.graph; - let filter_fn_wrapper = |node: &PyObject| -> PyResult> { - let res = filter_fn.call1(py, (node,))?; + let filter_fn_wrapper = |node_index| -> Result, PyErr> { + let node_weight = dag.node_weight(node_index).expect("Invalid NodeId"); + let res = filter_fn.call1(py, (node_weight,))?; res.extract(py) }; - let color_fn_wrapper = |edge: &PyObject| -> PyResult> { - let res = color_fn.call1(py, (edge,))?; + let color_fn_wrapper = |edge_index| -> Result, PyErr> { + let edge_weight = dag.edge_weight(edge_index).expect("Invalid EdgeId"); + let res = color_fn.call1(py, (edge_weight,))?; res.extract(py) }; diff --git a/tests/digraph/test_collect_bicolor_runs.py b/tests/digraph/test_collect_bicolor_runs.py index b0a903e15..8a47c6740 100644 --- a/tests/digraph/test_collect_bicolor_runs.py +++ b/tests/digraph/test_collect_bicolor_runs.py @@ -98,9 +98,9 @@ def filter_function(node): else: return None - def color_function(node): - if "q" in node: - return int(node[1:]) + def color_function(edge): + if "q" in edge: + return int(edge[1:]) else: return None @@ -187,9 +187,9 @@ def filter_function(node): else: return None - def color_function(node): - if "q" in node: - return int(node[1:]) + def color_function(edge): + if "q" in edge: + return int(edge[1:]) else: return None @@ -264,9 +264,9 @@ def filter_function(node): else: return None - def color_function(node): - if "q" in node: - return int(node[1:]) + def color_function(edge): + if "q" in edge: + return int(edge[1:]) else: return None @@ -338,9 +338,9 @@ def filter_function(node): else: return None - def color_function(node): - if "q" in node: - return int(node[1:]) + def color_function(edge): + if "q" in edge: + return int(edge[1:]) else: return None From 0222319eca09f6e13b575cf359293f894b62267b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elena=20Pe=C3=B1a=20Tapia?= Date: Thu, 6 Jun 2024 14:42:54 +0200 Subject: [PATCH 13/16] Fix test --- rustworkx-core/src/dag_algo.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rustworkx-core/src/dag_algo.rs b/rustworkx-core/src/dag_algo.rs index 81105d186..2277b5ff6 100644 --- a/rustworkx-core/src/dag_algo.rs +++ b/rustworkx-core/src/dag_algo.rs @@ -858,11 +858,11 @@ mod test_collect_bicolor_runs { */ let mut graph = DiGraph::new(); let n0 = graph.add_node(0); //q0 - let n1 = graph.add_node(1); //q1 - let n2 = graph.add_node(2); //cx - let n3 = graph.add_node(3); //cz + let n1 = graph.add_node(0); //q1 + let n2 = graph.add_node(1); //cx + let n3 = graph.add_node(1); //cz let n4 = graph.add_node(0); //q0_1 - let n5 = graph.add_node(1); //q1_1 + let n5 = graph.add_node(0); //q1_1 graph.add_edge(n0, n2, 0); //q0 -> cx graph.add_edge(n1, n2, 1); //q1 -> cx graph.add_edge(n2, n3, 0); //cx -> cz From 030eee78e71902f0e007cb23f850eb0948d8912c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elena=20Pe=C3=B1a=20Tapia?= <57907331+ElePT@users.noreply.github.com> Date: Mon, 10 Jun 2024 11:44:55 +0200 Subject: [PATCH 14/16] Apply suggestions from code review Co-authored-by: Matthew Treinish --- rustworkx-core/src/dag_algo.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/rustworkx-core/src/dag_algo.rs b/rustworkx-core/src/dag_algo.rs index 2277b5ff6..77e249c33 100644 --- a/rustworkx-core/src/dag_algo.rs +++ b/rustworkx-core/src/dag_algo.rs @@ -349,6 +349,7 @@ where /// use petgraph::graph::{DiGraph, NodeIndex}; /// use std::convert::Infallible; /// use std::error::Error; +/// /// let mut graph = DiGraph::new(); /// let n0 = graph.add_node(0); /// let n1 = graph.add_node(0); @@ -362,12 +363,15 @@ where /// graph.add_edge(n2, n3, 1); /// graph.add_edge(n3, n4, 0); /// graph.add_edge(n3, n5, 1); +/// /// let filter_fn = |node_id| -> Result, Infallible> { /// Ok(Some(*graph.node_weight(node_id).unwrap() > 0)) /// }; +/// /// let color_fn = |edge_id| -> Result, Infallible> { /// Ok(Some(*graph.edge_weight(edge_id).unwrap() as usize)) /// }; +/// /// let result = collect_bicolor_runs(&graph, filter_fn, color_fn).unwrap(); /// let expected: Vec> = vec![vec![n2, n3]]; /// assert_eq!(result, Some(expected)) From 9c262e03bc9179a2c0896034d8a43937832bd28e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elena=20Pe=C3=B1a=20Tapia?= Date: Mon, 10 Jun 2024 11:47:06 +0200 Subject: [PATCH 15/16] Get rid of unused trait --- rustworkx-core/src/dag_algo.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/rustworkx-core/src/dag_algo.rs b/rustworkx-core/src/dag_algo.rs index 77e249c33..a60c1e478 100644 --- a/rustworkx-core/src/dag_algo.rs +++ b/rustworkx-core/src/dag_algo.rs @@ -19,7 +19,7 @@ use hashbrown::HashMap; use petgraph::algo; use petgraph::data::DataMap; use petgraph::visit::{ - EdgeRef, GraphBase, GraphProp, IntoEdgeReferences, IntoEdgesDirected, IntoNeighborsDirected, + EdgeRef, GraphBase, GraphProp, IntoEdgesDirected, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCount, Visitable, }; use petgraph::Directed; @@ -387,7 +387,6 @@ where G: IntoNodeIdentifiers // Used in toposort + IntoNeighborsDirected // Used in toposort + IntoEdgesDirected // Used for .edges_directed - + IntoEdgeReferences + Visitable // Used in toposort + DataMap, // Used for .node_weight ::NodeId: Eq + Hash, From 7c0f7bcf3cbba3e1db4ee3f55f00c1bcb4e1631f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elena=20Pe=C3=B1a=20Tapia?= Date: Mon, 10 Jun 2024 12:08:39 +0200 Subject: [PATCH 16/16] Fix fmt --- rustworkx-core/src/dag_algo.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rustworkx-core/src/dag_algo.rs b/rustworkx-core/src/dag_algo.rs index a60c1e478..081802fa7 100644 --- a/rustworkx-core/src/dag_algo.rs +++ b/rustworkx-core/src/dag_algo.rs @@ -19,8 +19,8 @@ use hashbrown::HashMap; use petgraph::algo; use petgraph::data::DataMap; use petgraph::visit::{ - EdgeRef, GraphBase, GraphProp, IntoEdgesDirected, IntoNeighborsDirected, - IntoNodeIdentifiers, NodeCount, Visitable, + EdgeRef, GraphBase, GraphProp, IntoEdgesDirected, IntoNeighborsDirected, IntoNodeIdentifiers, + NodeCount, Visitable, }; use petgraph::Directed;