Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move collect_bicolor_runs() to rustworkx-core #1186

Merged
merged 18 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
features:
- |
Added a new function ``collect_bicolor_runs`` to rustworkx-core's ``dag_algo`` module.
Previously, the ``collect_bicolor_runs`` functionality for DAGs was only exposed
via the Python interface. Now Rust users can take advantage of this functionality in rustworkx-core.
352 changes: 350 additions & 2 deletions rustworkx-core/src/dag_algo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ use std::hash::Hash;
use hashbrown::HashMap;

use petgraph::algo;
use petgraph::data::DataMap;
use petgraph::visit::{
EdgeRef, GraphBase, GraphProp, IntoEdgesDirected, IntoNeighborsDirected, IntoNodeIdentifiers,
NodeCount, Visitable,
Data, EdgeRef, GraphBase, GraphProp, IntoEdgesDirected, IntoNeighborsDirected,
IntoNodeIdentifiers, NodeCount, Visitable,
};
use petgraph::Directed;

Expand Down Expand Up @@ -313,6 +314,161 @@ where
Ok(Some((path, path_weight)))
}

/// Collect runs that match a filter function given edge colors.
///
/// A bicolor run is a list of groups of nodes connected by edges of exactly
/// two colors. In addition, all nodes in the group must match the given
/// condition. Each node in the graph can appear in only a single group
/// in the bicolor run.
///
/// # Arguments:
///
/// * `dag`: The DAG to find bicolor runs in
/// * `filter_fn`: The filter function to use for matching nodes. It takes
/// in one argument, the node data payload/weight object, and will return a
/// boolean whether the node matches the conditions or not.
/// If it returns ``true``, it will continue the bicolor chain.
/// If it returns ``false``, it will stop the bicolor chain.
/// If it returns ``None`` it will skip that node.
/// * `color_fn`: The function that gives the color of the edge. It takes
/// in one argument, the edge data payload/weight object, and will
/// return a non-negative integer, the edge color. If the color is None,
/// the edge is ignored.
///
/// # Returns:
///
/// * `Vec<Vec<G::NodeId>>`: a list of groups with exactly two edge colors, where each group
/// is a list of node data payload/weight for the nodes in the bicolor run
/// * `None` if a cycle is found in the graph
/// * Raises an error if found computing the bicolor runs
///
/// # Example:
///
/// ```rust
/// use rustworkx_core::dag_algo::collect_bicolor_runs;
/// use petgraph::graph::DiGraph;
/// use std::error::Error;
///
/// let mut graph = DiGraph::new();
ElePT marked this conversation as resolved.
Show resolved Hide resolved
/// let n0 = graph.add_node(0);
/// let n1 = graph.add_node(1);
/// let n2 = graph.add_node(2);
/// let n3 = graph.add_node(3);
/// graph.add_edge(n0, n2, 0);
/// graph.add_edge(n1, n2, 0);
/// graph.add_edge(n2, n3, 0);
/// graph.add_edge(n2, n3, 0);
/// // filter_fn and color_fn must share error type
/// fn filter_fn(node: &i32) -> Result<Option<bool>, Box<dyn Error>> {
ElePT marked this conversation as resolved.
Show resolved Hide resolved
/// Ok(Some(*node > 1))
/// }
/// fn color_fn(edge: &i32) -> Result<Option<usize>, Box<dyn Error>> {
/// Ok(Some(*edge as usize))
/// }
/// let result = collect_bicolor_runs(&graph, filter_fn, color_fn).unwrap();
ElePT marked this conversation as resolved.
Show resolved Hide resolved
ElePT marked this conversation as resolved.
Show resolved Hide resolved
/// ```
pub fn collect_bicolor_runs<G, F, C, E>(
graph: G,
filter_fn: F,
color_fn: C,
) -> Result<Option<Vec<Vec<G::NodeId>>>, E>
where
F: Fn(&<G as Data>::NodeWeight) -> Result<Option<bool>, E>,
C: Fn(&<G as Data>::EdgeWeight) -> Result<Option<usize>, E>,
G: IntoNodeIdentifiers // Used in toposort
+ IntoNeighborsDirected // Used in toposort
+ IntoEdgesDirected // Used for .edges_directed
+ Visitable // Used in toposort
+ DataMap, // Used for .node_weight
ElePT marked this conversation as resolved.
Show resolved Hide resolved
<G as GraphBase>::NodeId: Eq + Hash,
{
let mut pending_list: Vec<Vec<G::NodeId>> = Vec::new();
let mut block_id: Vec<Option<usize>> = Vec::new();
let mut block_list: Vec<Vec<G::NodeId>> = Vec::new();

let nodes = match algo::toposort(graph, None) {
Ok(nodes) => nodes,
Err(_) => return Ok(None), // Return None if the graph contains a cycle
};

// Utility for ensuring pending_list has the color index
macro_rules! ensure_vector_has_index {
($pending_list: expr, $block_id: expr, $color: expr) => {
if $color >= $pending_list.len() {
$pending_list.resize($color + 1, Vec::new());
$block_id.resize($color + 1, None);
}
};
}

for node in nodes {
if let Some(is_match) = filter_fn(graph.node_weight(node).expect("Invalid NodeId"))? {
let raw_edges = graph.edges_directed(node, petgraph::Direction::Outgoing);

// Remove all edges that yield errors from color_fn
let colors = raw_edges
.map(|edge| {
let edge_weight = edge.weight();
color_fn(edge_weight)
})
.collect::<Result<Vec<Option<usize>>, _>>()?;

// Remove null edges from color_fn
let colors = colors.into_iter().flatten().collect::<Vec<usize>>();

if colors.len() <= 2 && is_match {
if colors.len() == 1 {
let c0 = colors[0];
ensure_vector_has_index!(pending_list, block_id, c0);
if let Some(c0_block_id) = block_id[c0] {
block_list[c0_block_id].push(node);
} else {
pending_list[c0].push(node);
}
} else if colors.len() == 2 {
let c0 = colors[0];
let c1 = colors[1];
ensure_vector_has_index!(pending_list, block_id, c0);
ensure_vector_has_index!(pending_list, block_id, c1);

if block_id[c0].is_some()
&& block_id[c1].is_some()
&& block_id[c0] == block_id[c1]
{
block_list[block_id[c0].unwrap_or_default()].push(node);
} else {
let mut new_block: Vec<G::NodeId> =
Vec::with_capacity(pending_list[c0].len() + pending_list[c1].len() + 1);

// Clears pending lits and add to new block
new_block.append(&mut pending_list[c0]);
new_block.append(&mut pending_list[c1]);

new_block.push(node);

// Create new block, assign its id to color pair
block_id[c0] = Some(block_list.len());
block_id[c1] = Some(block_list.len());
block_list.push(new_block);
}
}
} else {
for color in colors {
ensure_vector_has_index!(pending_list, block_id, color);
if let Some(color_block_id) = block_id[color] {
block_list[color_block_id].append(&mut pending_list[color]);
}
block_id[color] = None;
pending_list[color].clear();
}
}
}
}

Ok(Some(block_list))
}

// Tests for longest_path
#[cfg(test)]
mod test_longest_path {
use super::*;
Expand Down Expand Up @@ -421,6 +577,7 @@ mod test_longest_path {
}
}

// Tests for lexicographical_topological_sort
// pub fn lexicographical_topological_sort<G, F, E>(
ElePT marked this conversation as resolved.
Show resolved Hide resolved
// dag: G,
// mut key: F,
Expand Down Expand Up @@ -599,3 +756,194 @@ mod test_lexicographical_topological_sort {
assert_eq!(result, Ok(Some(vec![nodes[7]])));
}
}

// Tests for collect_bicolor_runs
#[cfg(test)]
mod test_collect_bicolor_runs {
ElePT marked this conversation as resolved.
Show resolved Hide resolved

use super::*;
use petgraph::graph::{DiGraph, NodeIndex};
use std::error::Error;

fn test_filter_fn(node: &i32) -> Result<Option<bool>, Box<dyn Error>> {
Ok(Some(*node > 1))
}

fn test_color_fn(edge: &i32) -> Result<Option<usize>, Box<dyn Error>> {
Ok(Some(*edge as usize))
}

#[test]
fn test_cycle() {
let mut graph = DiGraph::new();
let n0 = graph.add_node(2);
let n1 = graph.add_node(2);
let n2 = graph.add_node(2);
graph.add_edge(n0, n1, 1);
graph.add_edge(n1, n2, 1);
graph.add_edge(n2, n0, 1);

let result = match collect_bicolor_runs(&graph, test_filter_fn, test_color_fn) {
Ok(Some(_value)) => "Not None",
Ok(None) => "None",
Err(_) => "Error",
};
assert_eq!(result, "None")
}

#[test]
fn test_filter_function_inner_exception() {
let mut graph = DiGraph::new();
graph.add_node(0);

fn fail_function(node: &i32) -> Result<Option<bool>, Box<dyn Error>> {
if *node > 0 {
Ok(Some(true))
} else {
Err(Box::from("Failed!"))
}
}

let result = match collect_bicolor_runs(&graph, fail_function, test_color_fn) {
Ok(Some(_value)) => "Not None",
Ok(None) => "None",
Err(_) => "Error",
};
assert_eq!(result, "Error")
}

#[test]
fn test_empty() {
let graph = DiGraph::new();
let result = collect_bicolor_runs(&graph, test_filter_fn, test_color_fn).unwrap();
let expected: Vec<Vec<NodeIndex>> = vec![];
assert_eq!(result, Some(expected))
}

#[test]
fn test_two_colors() {
/* Based on the following graph from the Python unit tests:
Input:
┌─────────────┐ ┌─────────────┐
│ │ │ │
│ q0 │ │ q1 │
│ │ │ │
└───┬─────────┘ └──────┬──────┘
│ ┌─────────────┐ │
q0 │ │ │ │ q1
│ │ │ │
└─────────►│ cx │◄────────┘
┌──────────┤ ├─────────┐
│ │ │ │
q0 │ └─────────────┘ │ q1
│ │
│ ┌─────────────┐ │
│ │ │ │
└─────────►│ cz │◄────────┘
┌─────────┤ ├─────────┐
│ └─────────────┘ │
q0 │ │ q1
│ │
┌───▼─────────┐ ┌──────▼──────┐
│ │ │ │
│ q0 │ │ q1 │
│ │ │ │
└─────────────┘ └─────────────┘
Expected: [[cx, cz]]
*/
let mut graph = DiGraph::new();
// The node weight will correspond to the type of node
// All edges have the same weight in this example
let n0 = graph.add_node(0); //q0
let n1 = graph.add_node(1); //q1
let n2 = graph.add_node(2); //cx
let n3 = graph.add_node(3); //cz
let n4 = graph.add_node(0); //q0_1
let n5 = graph.add_node(1); //q1_1
graph.add_edge(n0, n2, 0); //q0 -> cx
graph.add_edge(n1, n2, 0); //q1 -> cx
graph.add_edge(n2, n3, 0); //cx -> cz
graph.add_edge(n2, n3, 0); //cx -> cz
graph.add_edge(n3, n4, 0); //cz -> q0_1
graph.add_edge(n3, n5, 0); //cz -> q1_1

let result = collect_bicolor_runs(&graph, test_filter_fn, test_color_fn).unwrap();
let expected: Vec<Vec<NodeIndex>> = vec![vec![n2, n3]]; //[[cx, cz]]
assert_eq!(result, Some(expected))
}
#[test]
fn test_two_colors_with_pending() {
/* Based on the following graph from the Python unit tests:
Input:
┌─────────────┐
│ │
│ q0 │
│ │
└───┬─────────┘
| q0
┌───▼─────────┐
│ │
│ h │
│ │
└───┬─────────┘
| q0
│ ┌─────────────┐
│ │ │
│ │ q1 │
│ │ │
| └──────┬──────┘
│ ┌─────────────┐ │
q0 │ │ │ │ q1
│ │ │ │
└─────────►│ cx │◄────────┘
┌──────────┤ ├─────────┐
│ │ │ │
q0 │ └─────────────┘ │ q1
│ │
│ ┌─────────────┐ │
│ │ │ │
└─────────►│ cz │◄────────┘
┌─────────┤ ├─────────┐
│ └─────────────┘ │
q0 │ │ q1
│ │
┌───▼─────────┐ ┌──────▼──────┐
│ │ │ │
│ q0 │ │ y │
│ │ │ │
└─────────────┘ └─────────────┘
| q1
┌───▼─────────┐
│ │
│ q1 │
│ │
└─────────────┘
Expected: [[h, cx, cz, y]]
*/
let mut graph = DiGraph::new();
// The node weight will correspond to the type of node
// All edges have the same weight in this example
let n0 = graph.add_node(0); //q0
let n1 = graph.add_node(1); //q1
let n2 = graph.add_node(2); //h
let n3 = graph.add_node(3); //cx
let n4 = graph.add_node(4); //cz
let n5 = graph.add_node(5); //y
let n6 = graph.add_node(0); //q0_1
let n7 = graph.add_node(1); //q1_1
graph.add_edge(n0, n2, 0); //q0 -> h
graph.add_edge(n2, n3, 0); //h -> cx
graph.add_edge(n1, n3, 0); //q1 -> cx
graph.add_edge(n3, n4, 0); //cx -> cz
graph.add_edge(n3, n4, 0); //cx -> cz
graph.add_edge(n4, n6, 0); //cz -> q0_1
graph.add_edge(n4, n5, 0); //cz -> y
graph.add_edge(n5, n7, 0); //y -> q1_1

let result = collect_bicolor_runs(&graph, test_filter_fn, test_color_fn).unwrap();
let expected: Vec<Vec<NodeIndex>> = vec![vec![n2, n3, n4, n5]]; //[[h, cx, cz, y]]
assert_eq!(result, Some(expected))
}
}
1 change: 1 addition & 0 deletions rustworkx-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ pub mod centrality;
/// Module for coloring algorithms.
pub mod coloring;
pub mod connectivity;
/// Module for algorithms that work on DAGs.
pub mod dag_algo;
pub mod generators;
pub mod graph_ext;
Expand Down
Loading