Skip to content

Commit

Permalink
Clean up unused parameters and functions, simplify error handling, ad…
Browse files Browse the repository at this point in the history
…d unit tests and reno.
  • Loading branch information
ElePT committed Jun 4, 2024
1 parent b77b269 commit c98b279
Show file tree
Hide file tree
Showing 4 changed files with 272 additions and 104 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
features:
- |
Added a new function ``collect_bicolor_runs`` to rustworkx-core's ``dag_algo`` module.
Previously, the ``collect_bicolor_runs`` functionality for DAGs was only exposed
via the Python interface. Now Rust users can take advantage of this functionality in rustworkx-core.
295 changes: 238 additions & 57 deletions rustworkx-core/src/dag_algo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@

use std::cmp::{Eq, Ordering};
use std::collections::BinaryHeap;
use std::error::Error;
use std::fmt::{Debug, Display, Formatter};
use std::hash::Hash;

use hashbrown::HashMap;
Expand All @@ -22,7 +20,7 @@ use petgraph::algo;
use petgraph::data::DataMap;
use petgraph::visit::{
Data, EdgeRef, GraphBase, GraphProp, IntoEdgesDirected, IntoNeighborsDirected,
IntoNodeIdentifiers, NodeIndexable, NodeCount, Visitable,
IntoNodeIdentifiers, NodeCount, Visitable,
};
use petgraph::Directed;

Expand Down Expand Up @@ -316,31 +314,6 @@ where
Ok(Some((path, path_weight)))
}

/// Define custom error classes for collect_bicolor_runs
#[derive(Debug)]
pub enum CollectBicolorError<E: Error> {
DAGHasCycle,
CallableError(E)
}

impl<E: Error> Error for CollectBicolorError<E> {}

impl<E: Error> Display for CollectBicolorError<E> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
CollectBicolorError::DAGHasCycle => fmt_dag_has_cycle(f),
CollectBicolorError::CallableError(ref e) => fmt_callable_error(f, e),
}
}
}

fn fmt_dag_has_cycle(f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "Sort encountered a cycle")
}

fn fmt_callable_error<E: Error>(f: &mut Formatter<'_>, inner: &E) -> std::fmt::Result {
write!(f, "The function failed with: {:?}", inner)
}
/// Collect runs that match a filter function given edge colors.
///
/// A bicolor run is a list of groups of nodes connected by edges of exactly
Expand All @@ -366,41 +339,56 @@ fn fmt_callable_error<E: Error>(f: &mut Formatter<'_>, inner: &E) -> std::fmt::R
///
/// * `Vec<Vec<G::NodeId>>`: a list of groups with exactly two edge colors, where each group
/// is a list of node data payload/weight for the nodes in the bicolor run
/// * `CollectBicolorError<E>>` if there is an error computing the bicolor runs
pub fn collect_bicolor_runs<G, F, C, B, E>(
/// * `None` if a cycle is found in the graph
/// * Raises an error if found computing the bicolor runs
///
/// # Example:
///
/// ```rust
/// use rustworkx_core::dag_algo::collect_bicolor_runs;
/// use petgraph::graph::DiGraph;
/// use std::error::Error;
///
/// let mut graph = DiGraph::new();
/// let n0 = graph.add_node(0);
/// let n1 = graph.add_node(1);
/// let n2 = graph.add_node(2);
/// let n3 = graph.add_node(3);
/// graph.add_edge(n0, n2, 0);
/// graph.add_edge(n1, n2, 0);
/// graph.add_edge(n2, n3, 0);
/// graph.add_edge(n2, n3, 0);
/// // filter_fn and color_fn must share error type
/// fn filter_fn(node: &i32) -> Result<Option<bool>, Box<dyn Error>> {
/// Ok(Some(*node > 1))
/// }
/// fn color_fn(edge: &i32) -> Result<Option<usize>, Box<dyn Error>> {
/// Ok(Some(*edge as usize))
/// }
/// let result = collect_bicolor_runs(&graph, filter_fn, color_fn).unwrap();
/// ```
pub fn collect_bicolor_runs<G, F, C, E>(
graph: G,
filter_fn: F,
color_fn: C,
) -> Result<Vec<Vec<G::NodeId>>, CollectBicolorError<E>>
) -> Result<Option<Vec<Vec<G::NodeId>>>, E>
where
E: Error,
F: Fn(&<G as Data>::NodeWeight) -> Result<Option<bool>, CollectBicolorError<E>>,
C: Fn(&<G as Data>::EdgeWeight) -> Result<Option<usize>, CollectBicolorError<E>>,
G: NodeIndexable
+ IntoNodeIdentifiers
+ IntoNeighborsDirected
+ IntoEdgesDirected
+ Visitable
+ DataMap,
F: Fn(&<G as Data>::NodeWeight) -> Result<Option<bool>, E>,
C: Fn(&<G as Data>::EdgeWeight) -> Result<Option<usize>, E>,
G: IntoNodeIdentifiers // Used in toposort
+ IntoNeighborsDirected // Used in toposort
+ IntoEdgesDirected // Used for .edges_directed
+ Visitable // Used in toposort
+ DataMap, // Used for .node_weight
<G as GraphBase>::NodeId: Eq + Hash,
{
let mut pending_list: Vec<Vec<G::NodeId>> = Vec::new();
let mut block_id: Vec<Option<usize>> = Vec::new();
let mut block_list: Vec<Vec<G::NodeId>> = Vec::new();

let filter_node =
|node: &<G as Data>::NodeWeight| -> Result<Option<bool>, CollectBicolorError<E>> {
filter_fn(node)
};

let color_edge =
|edge: &<G as Data>::EdgeWeight| -> Result<Option<usize>, CollectBicolorError<E>> {
color_fn(edge)
};

let nodes = match algo::toposort(&graph, None) {
let nodes = match algo::toposort(graph, None) {
Ok(nodes) => nodes,
Err(_err) => return Err(CollectBicolorError::DAGHasCycle),
Err(_) => return Ok(None), // Return None if the graph contains a cycle
};

// Utility for ensuring pending_list has the color index
Expand All @@ -414,14 +402,14 @@ where
}

for node in nodes {
if let Some(is_match) = filter_node(&graph.node_weight(node).expect("Invalid NodeId"))? {
if let Some(is_match) = filter_fn(graph.node_weight(node).expect("Invalid NodeId"))? {
let raw_edges = graph.edges_directed(node, petgraph::Direction::Outgoing);

// Remove all edges that do not yield errors from color_fn
// Remove all edges that yield errors from color_fn
let colors = raw_edges
.map(|edge| {
let edge_weight = edge.weight();
color_edge(edge_weight)
color_fn(edge_weight)
})
.collect::<Result<Vec<Option<usize>>, _>>()?;

Expand Down Expand Up @@ -477,9 +465,10 @@ where
}
}

Ok(block_list)
Ok(Some(block_list))
}

/// Tests for longest_path
#[cfg(test)]
mod test_longest_path {
use super::*;
Expand Down Expand Up @@ -588,6 +577,7 @@ mod test_longest_path {
}
}

/// Tests for lexicographical_topological_sort
// pub fn lexicographical_topological_sort<G, F, E>(
// dag: G,
// mut key: F,
Expand Down Expand Up @@ -765,4 +755,195 @@ mod test_lexicographical_topological_sort {
let result = lexicographical_topological_sort(&graph, sort_fn, false, Some(&initial));
assert_eq!(result, Ok(Some(vec![nodes[7]])));
}
}
}

/// Tests for collect_bicolor_runs
#[cfg(test)]
mod test_collect_bicolor_runs {

use super::*;
use petgraph::graph::{DiGraph, NodeIndex};
use std::error::Error;

fn test_filter_fn(node: &i32) -> Result<Option<bool>, Box<dyn Error>> {
Ok(Some(*node > 1))
}

fn test_color_fn(edge: &i32) -> Result<Option<usize>, Box<dyn Error>> {
Ok(Some(*edge as usize))
}

#[test]
fn test_cycle() {
let mut graph = DiGraph::new();
let n0 = graph.add_node(2);
let n1 = graph.add_node(2);
let n2 = graph.add_node(2);
graph.add_edge(n0, n1, 1);
graph.add_edge(n1, n2, 1);
graph.add_edge(n2, n0, 1);

let result = match collect_bicolor_runs(&graph, test_filter_fn, test_color_fn) {
Ok(Some(_value)) => "Not None",
Ok(None) => "None",
Err(_) => "Error",
};
assert_eq!(result, "None")
}

#[test]
fn test_filter_function_inner_exception() {
let mut graph = DiGraph::new();
graph.add_node(0);

fn fail_function(node: &i32) -> Result<Option<bool>, Box<dyn Error>> {
if *node > 0 {
Ok(Some(true))
} else {
Err(Box::from("Failed!"))
}
}

let result = match collect_bicolor_runs(&graph, fail_function, test_color_fn) {
Ok(Some(_value)) => "Not None",
Ok(None) => "None",
Err(_) => "Error",
};
assert_eq!(result, "Error")
}

#[test]
fn test_empty() {
let graph = DiGraph::new();
let result = collect_bicolor_runs(&graph, test_filter_fn, test_color_fn).unwrap();
let expected: Vec<Vec<NodeIndex>> = vec![];
assert_eq!(result, Some(expected))
}

#[test]
fn test_two_colors() {
// Based on the following graph from the Python unit tests:
// Input:
// ┌─────────────┐ ┌─────────────┐
// │ │ │ │
// │ q0 │ │ q1 │
// │ │ │ │
// └───┬─────────┘ └──────┬──────┘
// │ ┌─────────────┐ │
// q0 │ │ │ │ q1
// │ │ │ │
// └─────────►│ cx │◄────────┘
// ┌──────────┤ ├─────────┐
// │ │ │ │
// q0 │ └─────────────┘ │ q1
// │ │
// │ ┌─────────────┐ │
// │ │ │ │
// └─────────►│ cz │◄────────┘
// ┌─────────┤ ├─────────┐
// │ └─────────────┘ │
// q0 │ │ q1
// │ │
// ┌───▼─────────┐ ┌──────▼──────┐
// │ │ │ │
// │ q0 │ │ q1 │
// │ │ │ │
// └─────────────┘ └─────────────┘
//
// Expected: [[cx, cz]]
let mut graph = DiGraph::new();
// The node weight will correspond to the type of node
// All edges have the same weight in this example
let n0 = graph.add_node(0); //q0
let n1 = graph.add_node(1); //q1
let n2 = graph.add_node(2); //cx
let n3 = graph.add_node(3); //cz
let n4 = graph.add_node(0); //q0_1
let n5 = graph.add_node(1); //q1_1
graph.add_edge(n0, n2, 0); //q0 -> cx
graph.add_edge(n1, n2, 0); //q1 -> cx
graph.add_edge(n2, n3, 0); //cx -> cz
graph.add_edge(n2, n3, 0); //cx -> cz
graph.add_edge(n3, n4, 0); //cz -> q0_1
graph.add_edge(n3, n5, 0); //cz -> q1_1

let result = collect_bicolor_runs(&graph, test_filter_fn, test_color_fn).unwrap();
let expected: Vec<Vec<NodeIndex>> = vec![vec![n2, n3]]; //[[cx, cz]]
assert_eq!(result, Some(expected))
}
#[test]
fn test_two_colors_with_pending() {
// Based on the following graph from the Python unit tests:
// Input:
// ┌─────────────┐
// │ │
// │ q0 │
// │ │
// └───┬─────────┘
// | q0
// │
// ┌───▼─────────┐
// │ │
// │ h │
// │ │
// └───┬─────────┘
// | q0
// │ ┌─────────────┐
// │ │ │
// │ │ q1 │
// │ │ │
// | └──────┬──────┘
// │ ┌─────────────┐ │
// q0 │ │ │ │ q1
// │ │ │ │
// └─────────►│ cx │◄────────┘
// ┌──────────┤ ├─────────┐
// │ │ │ │
// q0 │ └─────────────┘ │ q1
// │ │
// │ ┌─────────────┐ │
// │ │ │ │
// └─────────►│ cz │◄────────┘
// ┌─────────┤ ├─────────┐
// │ └─────────────┘ │
// q0 │ │ q1
// │ │
// ┌───▼─────────┐ ┌──────▼──────┐
// │ │ │ │
// │ q0 │ │ y │
// │ │ │ │
// └─────────────┘ └─────────────┘
// | q1
// │
// ┌───▼─────────┐
// │ │
// │ q1 │
// │ │
// └─────────────┘
//
// Expected: [[h, cx, cz, y]]
let mut graph = DiGraph::new();
// The node weight will correspond to the type of node
// All edges have the same weight in this example
let n0 = graph.add_node(0); //q0
let n1 = graph.add_node(1); //q1
let n2 = graph.add_node(2); //h
let n3 = graph.add_node(3); //cx
let n4 = graph.add_node(4); //cz
let n5 = graph.add_node(5); //y
let n6 = graph.add_node(0); //q0_1
let n7 = graph.add_node(1); //q1_1
graph.add_edge(n0, n2, 0); //q0 -> h
graph.add_edge(n2, n3, 0); //h -> cx
graph.add_edge(n1, n3, 0); //q1 -> cx
graph.add_edge(n3, n4, 0); //cx -> cz
graph.add_edge(n3, n4, 0); //cx -> cz
graph.add_edge(n4, n6, 0); //cz -> q0_1
graph.add_edge(n4, n5, 0); //cz -> y
graph.add_edge(n5, n7, 0); //y -> q1_1

let result = collect_bicolor_runs(&graph, test_filter_fn, test_color_fn).unwrap();
let expected: Vec<Vec<NodeIndex>> = vec![vec![n2, n3, n4, n5]]; //[[h, cx, cz, y]]
assert_eq!(result, Some(expected))
}
}
4 changes: 2 additions & 2 deletions rustworkx-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,11 @@ pub mod centrality;
/// Module for coloring algorithms.
pub mod coloring;
pub mod connectivity;
/// Module for algorithms that work on DAGs.
pub mod dag_algo;
pub mod generators;
pub mod graph_ext;
pub mod line_graph;
/// Module for algorithms that work on DAGs.
pub mod dag_algo;
/// Module for maximum weight matching algorithms.
pub mod max_weight_matching;
pub mod planar;
Expand Down
Loading

0 comments on commit c98b279

Please sign in to comment.