diff --git a/releasenotes/notes/graph-ext-3f008015c31c592e.yaml b/releasenotes/notes/graph-ext-3f008015c31c592e.yaml new file mode 100644 index 000000000..d3b336908 --- /dev/null +++ b/releasenotes/notes/graph-ext-3f008015c31c592e.yaml @@ -0,0 +1,17 @@ +--- +features: + - | + Node contraction is now supported for ``petgraph`` types ``StableGraph`` + and ``GraphMap`` in rustworkx-core. To use it, import one of the ``ContractNodes*`` traits + from ``graph_ext`` and call the ``contract_nodes`` method on your graph. + - | + All current ``petgraph`` data structures now support testing for parallel + edges in ``rustworkx-core``. To use this, import ``HasParallelEdgesDirected`` + or ``HasParallelEdgesUndirected`` depending on your graph type, and call + the ``has_parallel_edges`` method on your graph. + - | + A new trait ``NodeRemovable`` has been added to ``graph_ext`` module in + ``rustworkx-core`` which provides a consistent interface for performing + node removal operations on ``petgraph`` types ``Graph``, ``StableGraph``, + ``GraphMap``, and ``MatrixGraph``. To use it, import ``NodeRemovable`` from + ``graph_ext``. diff --git a/rustworkx-core/src/err.rs b/rustworkx-core/src/err.rs new file mode 100644 index 000000000..e48965e20 --- /dev/null +++ b/rustworkx-core/src/err.rs @@ -0,0 +1,56 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +//! This module contains common error types and trait impls. + +use std::error::Error; +use std::fmt::{Debug, Display, Formatter}; + +#[derive(Debug)] +pub enum ContractError { + DAGWouldCycle, +} + +impl Display for ContractError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + ContractError::DAGWouldCycle => fmt_dag_would_cycle(f), + } + } +} + +impl Error for ContractError {} + +#[derive(Debug)] +pub enum ContractSimpleError { + DAGWouldCycle, + MergeError(E), +} + +impl Display for ContractSimpleError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + ContractSimpleError::DAGWouldCycle => fmt_dag_would_cycle(f), + ContractSimpleError::MergeError(ref e) => fmt_merge_error(f, e), + } + } +} + +impl Error for ContractSimpleError {} + +fn fmt_dag_would_cycle(f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "The operation would introduce a cycle.") +} + +fn fmt_merge_error(f: &mut Formatter<'_>, inner: &E) -> std::fmt::Result { + write!(f, "The prov failed with: {:?}", inner) +} diff --git a/rustworkx-core/src/graph_ext/contraction.rs b/rustworkx-core/src/graph_ext/contraction.rs new file mode 100644 index 000000000..bebb5b0bc --- /dev/null +++ b/rustworkx-core/src/graph_ext/contraction.rs @@ -0,0 +1,600 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +//! This module defines graph traits for node contraction. + +use crate::dictmap::{DictMap, InitWithHasher}; +use crate::err::{ContractError, ContractSimpleError}; +use crate::graph_ext::NodeRemovable; +use indexmap::map::Entry::{Occupied, Vacant}; +use indexmap::IndexSet; +use petgraph::data::Build; +use petgraph::graphmap; +use petgraph::stable_graph; +use petgraph::visit::{Data, Dfs, EdgeRef, GraphBase, GraphProp, IntoEdgesDirected, Visitable}; +use petgraph::{Directed, Direction, Undirected}; +use std::convert::Infallible; +use std::error::Error; +use std::hash::Hash; +use std::ops::Deref; + +pub trait ContractNodesDirected: Data { + /// The error type returned by contraction. + type Error: Error; + + /// Substitute a set of nodes with a single new node. + /// + /// The specified `nodes` are removed and replaced with a new node + /// with the given `weight`. Any nodes not in the graph are ignored. + /// It is valid for `nodes` to be empty, in which case the new node + /// is added to the graph without edges. + /// + /// The contraction may result in multiple edges between nodes if + /// the underlying graph is a multi-graph. If this is not desired, + /// use [ContractNodesSimpleDirected::contract_nodes_simple]. + /// + /// If `check_cycle` is enabled and the contraction would introduce + /// a cycle, an error is returned and the graph is not modified. + /// + /// The `NodeId` of the newly created node is returned. + /// + /// # Example + /// ``` + /// use std::convert::Infallible; + /// use petgraph::prelude::*; + /// use rustworkx_core::graph_ext::*; + /// + /// // Performs the following transformation: + /// // ┌─┐ + /// // │a│ + /// // └┬┘ ┌─┐ + /// // 0 │a│ + /// // ┌▼┐ └┬┘ + /// // │b│ 0 + /// // └┬┘ ┌▼┐ + /// // 1 ───► │m│ + /// // ┌▼┐ └┬┘ + /// // │c│ 2 + /// // └┬┘ ┌▼┐ + /// // 2 │d│ + /// // ┌▼┐ └─┘ + /// // │d│ + /// // └─┘ + /// let mut dag: StableDiGraph = StableDiGraph::default(); + /// let a = dag.add_node('a'); + /// let b = dag.add_node('b'); + /// let c = dag.add_node('c'); + /// let d = dag.add_node('d'); + /// dag.add_edge(a.clone(), b.clone(), 0); + /// dag.add_edge(b.clone(), c.clone(), 1); + /// dag.add_edge(c.clone(), d.clone(), 2); + /// + /// let m = dag.contract_nodes([b, c], 'm', true).unwrap(); + /// assert_eq!(dag.edge_weight(dag.find_edge(a.clone(), m.clone()).unwrap()).unwrap(), &0); + /// assert_eq!(dag.edge_weight(dag.find_edge(m.clone(), d.clone()).unwrap()).unwrap(), &2); + /// ``` + fn contract_nodes( + &mut self, + nodes: I, + weight: Self::NodeWeight, + check_cycle: bool, + ) -> Result + where + I: IntoIterator; +} + +impl ContractNodesDirected for stable_graph::StableGraph +where + Ix: stable_graph::IndexType, + E: Clone, +{ + type Error = ContractError; + + fn contract_nodes( + &mut self, + nodes: I, + obj: Self::NodeWeight, + check_cycle: bool, + ) -> Result + where + I: IntoIterator, + { + let nodes = IndexSet::from_iter(nodes); + if check_cycle && !can_contract(self.deref(), &nodes) { + return Err(ContractError::DAGWouldCycle); + } + Ok(contract_stable(self, nodes, obj, NoCallback::None).unwrap()) + } +} + +impl ContractNodesDirected for graphmap::GraphMap +where + for<'a> N: graphmap::NodeTrait + 'a, + for<'a> E: Clone + 'a, +{ + type Error = ContractError; + + fn contract_nodes( + &mut self, + nodes: I, + obj: Self::NodeWeight, + check_cycle: bool, + ) -> Result + where + I: IntoIterator, + { + let nodes = IndexSet::from_iter(nodes); + if check_cycle && !can_contract(self.deref(), &nodes) { + return Err(ContractError::DAGWouldCycle); + } + Ok(contract_stable(self, nodes, obj, NoCallback::None).unwrap()) + } +} + +pub trait ContractNodesSimpleDirected: Data { + /// The error type returned by contraction. + type Error: Error; + + /// Substitute a set of nodes with a single new node. + /// + /// The specified `nodes` are removed and replaced with a new node + /// with the given `weight`. Any nodes not in the graph are ignored. + /// It is valid for `nodes` to be empty, in which case the new node + /// is added to the graph without edges. + /// + /// The specified function `weight_combo_fn` is used to merge + /// would-be parallel edges during contraction; this function + /// preserves simple graphs. + /// + /// If `check_cycle` is enabled and the contraction would introduce + /// a cycle, an error is returned and the graph is not modified. + /// + /// The `NodeId` of the newly created node is returned. + /// + /// # Example + /// ``` + /// use std::convert::Infallible; + /// use petgraph::prelude::*; + /// use rustworkx_core::graph_ext::*; + /// + /// // Performs the following transformation: + /// // ┌─┐ + /// // ┌─┐ │a│ + /// // ┌0─┤a├─1┐ └┬┘ + /// // │ └─┘ │ 1 + /// // ┌▼┐ ┌▼┐ ┌▼┐ + /// // │b│ │c│ ───► │m│ + /// // └┬┘ └┬┘ └┬┘ + /// // │ ┌─┐ │ 3 + /// // └2►│d│◄3┘ ┌▼┐ + /// // └─┘ │d│ + /// // └─┘ + /// let mut dag: StableDiGraph = StableDiGraph::default(); + /// let a = dag.add_node('a'); + /// let b = dag.add_node('b'); + /// let c = dag.add_node('c'); + /// let d = dag.add_node('d'); + /// dag.add_edge(a.clone(), b.clone(), 0); + /// dag.add_edge(a.clone(), c.clone(), 1); + /// dag.add_edge(b.clone(), d.clone(), 2); + /// dag.add_edge(c.clone(), d.clone(), 3); + /// + /// let m = dag.contract_nodes_simple([b, c], 'm', true, |&e1, &e2| Ok::<_, Infallible>(if e1 > e2 { e1 } else { e2 } )).unwrap(); + /// assert_eq!(dag.edge_weight(dag.find_edge(a.clone(), m.clone()).unwrap()).unwrap(), &1); + /// assert_eq!(dag.edge_weight(dag.find_edge(m.clone(), d.clone()).unwrap()).unwrap(), &3); + /// ``` + fn contract_nodes_simple( + &mut self, + nodes: I, + weight: Self::NodeWeight, + check_cycle: bool, + weight_combo_fn: F, + ) -> Result> + where + I: IntoIterator, + F: FnMut(&Self::EdgeWeight, &Self::EdgeWeight) -> Result; +} + +impl ContractNodesSimpleDirected for stable_graph::StableGraph +where + Ix: stable_graph::IndexType, + E: Clone, +{ + type Error = ContractSimpleError; + + fn contract_nodes_simple( + &mut self, + nodes: I, + weight: Self::NodeWeight, + check_cycle: bool, + weight_combo_fn: F, + ) -> Result> + where + I: IntoIterator, + F: FnMut(&Self::EdgeWeight, &Self::EdgeWeight) -> Result, + { + let nodes = IndexSet::from_iter(nodes); + if check_cycle && !can_contract(self.deref(), &nodes) { + return Err(ContractSimpleError::DAGWouldCycle); + } + contract_stable(self, nodes, weight, Some(weight_combo_fn)) + .map_err(ContractSimpleError::MergeError) + } +} + +impl ContractNodesSimpleDirected for graphmap::GraphMap +where + for<'a> N: graphmap::NodeTrait + 'a, + for<'a> E: Clone + 'a, +{ + type Error = ContractSimpleError; + + fn contract_nodes_simple( + &mut self, + nodes: I, + weight: Self::NodeWeight, + check_cycle: bool, + weight_combo_fn: F, + ) -> Result> + where + I: IntoIterator, + F: FnMut(&Self::EdgeWeight, &Self::EdgeWeight) -> Result, + { + let nodes = IndexSet::from_iter(nodes); + if check_cycle && !can_contract(self.deref(), &nodes) { + return Err(ContractSimpleError::DAGWouldCycle); + } + contract_stable(self, nodes, weight, Some(weight_combo_fn)) + .map_err(ContractSimpleError::MergeError) + } +} + +pub trait ContractNodesUndirected: Data { + /// Substitute a set of nodes with a single new node. + /// + /// The specified `nodes` are removed and replaced with a new node + /// with the given `weight`. Any nodes not in the graph are ignored. + /// It is valid for `nodes` to be empty, in which case the new node + /// is added to the graph without edges. + /// + /// The contraction may result in multiple edges between nodes if + /// the underlying graph is a multi-graph. If this is not desired, + /// use [ContractNodesSimpleUndirected::contract_nodes_simple]. + /// + /// The `NodeId` of the newly created node is returned. + /// + /// # Example + /// ``` + /// use petgraph::prelude::*; + /// use rustworkx_core::graph_ext::*; + /// + /// // Performs the following transformation: + /// // ┌─┐ + /// // │a│ + /// // └┬┘ ┌─┐ + /// // 0 │a│ + /// // ┌┴┐ └┬┘ + /// // │b│ 0 + /// // └┬┘ ┌┴┐ + /// // 1 ───► │m│ + /// // ┌┴┐ └┬┘ + /// // │c│ 2 + /// // └┬┘ ┌┴┐ + /// // 2 │d│ + /// // ┌┴┐ └─┘ + /// // │d│ + /// // └─┘ + /// let mut dag: StableUnGraph = StableUnGraph::default(); + /// let a = dag.add_node('a'); + /// let b = dag.add_node('b'); + /// let c = dag.add_node('c'); + /// let d = dag.add_node('d'); + /// dag.add_edge(a.clone(), b.clone(), 0); + /// dag.add_edge(b.clone(), c.clone(), 1); + /// dag.add_edge(c.clone(), d.clone(), 2); + /// + /// let m = dag.contract_nodes([b, c], 'm'); + /// assert_eq!(dag.edge_weight(dag.find_edge(a.clone(), m.clone()).unwrap()).unwrap(), &0); + /// assert_eq!(dag.edge_weight(dag.find_edge(m.clone(), d.clone()).unwrap()).unwrap(), &2); + /// ``` + fn contract_nodes(&mut self, nodes: I, weight: Self::NodeWeight) -> Self::NodeId + where + I: IntoIterator; +} + +impl ContractNodesUndirected for stable_graph::StableGraph +where + Ix: stable_graph::IndexType, + E: Clone, +{ + fn contract_nodes(&mut self, nodes: I, obj: Self::NodeWeight) -> Self::NodeId + where + I: IntoIterator, + { + contract_stable(self, IndexSet::from_iter(nodes), obj, NoCallback::None).unwrap() + } +} + +impl ContractNodesUndirected for graphmap::GraphMap +where + for<'a> N: graphmap::NodeTrait + 'a, + for<'a> E: Clone + 'a, +{ + fn contract_nodes(&mut self, nodes: I, obj: Self::NodeWeight) -> Self::NodeId + where + I: IntoIterator, + { + contract_stable(self, IndexSet::from_iter(nodes), obj, NoCallback::None).unwrap() + } +} + +pub trait ContractNodesSimpleUndirected: Data { + type Error: Error; + + /// Substitute a set of nodes with a single new node. + /// + /// The specified `nodes` are removed and replaced with a new node + /// with the given `weight`. Any nodes not in the graph are ignored. + /// It is valid for `nodes` to be empty, in which case the new node + /// is added to the graph without edges. + /// + /// The specified function `weight_combo_fn` is used to merge + /// would-be parallel edges during contraction; this function + /// preserves simple graphs. + /// + /// The `NodeId` of the newly created node is returned. + /// + /// # Example + /// ``` + /// use std::convert::Infallible; + /// use petgraph::prelude::*; + /// use rustworkx_core::graph_ext::*; + /// + /// // Performs the following transformation: + /// // ┌─┐ + /// // ┌─┐ │a│ + /// // ┌0─┤a├─1┐ └┬┘ + /// // │ └─┘ │ 1 + /// // ┌┴┐ ┌┴┐ ┌┴┐ + /// // │b│ │c│ ───► │m│ + /// // └┬┘ └┬┘ └┬┘ + /// // │ ┌─┐ │ 3 + /// // └2─│d├─3┘ ┌┴┐ + /// // └─┘ │d│ + /// // └─┘ + /// let mut dag: StableUnGraph = StableUnGraph::default(); + /// let a = dag.add_node('a'); + /// let b = dag.add_node('b'); + /// let c = dag.add_node('c'); + /// let d = dag.add_node('d'); + /// dag.add_edge(a.clone(), b.clone(), 0); + /// dag.add_edge(a.clone(), c.clone(), 1); + /// dag.add_edge(b.clone(), d.clone(), 2); + /// dag.add_edge(c.clone(), d.clone(), 3); + /// + /// let m = dag.contract_nodes_simple([b, c], 'm', |&e1, &e2| Ok::<_, Infallible>(if e1 > e2 { e1 } else { e2 } )).unwrap(); + /// assert_eq!(dag.edge_weight(dag.find_edge(a.clone(), m.clone()).unwrap()).unwrap(), &1); + /// assert_eq!(dag.edge_weight(dag.find_edge(m.clone(), d.clone()).unwrap()).unwrap(), &3); + /// ``` + fn contract_nodes_simple( + &mut self, + nodes: I, + weight: Self::NodeWeight, + weight_combo_fn: F, + ) -> Result> + where + I: IntoIterator, + F: FnMut(&Self::EdgeWeight, &Self::EdgeWeight) -> Result; +} + +impl ContractNodesSimpleUndirected for stable_graph::StableGraph +where + Ix: stable_graph::IndexType, + E: Clone, +{ + type Error = ContractSimpleError; + + fn contract_nodes_simple( + &mut self, + nodes: I, + weight: Self::NodeWeight, + weight_combo_fn: F, + ) -> Result> + where + I: IntoIterator, + F: FnMut(&Self::EdgeWeight, &Self::EdgeWeight) -> Result, + { + contract_stable( + self, + IndexSet::from_iter(nodes), + weight, + Some(weight_combo_fn), + ) + .map_err(ContractSimpleError::MergeError) + } +} + +impl ContractNodesSimpleUndirected for graphmap::GraphMap +where + for<'a> N: graphmap::NodeTrait + 'a, + for<'a> E: Clone + 'a, +{ + type Error = ContractSimpleError; + + fn contract_nodes_simple( + &mut self, + nodes: I, + weight: Self::NodeWeight, + weight_combo_fn: F, + ) -> Result> + where + I: IntoIterator, + F: FnMut(&Self::EdgeWeight, &Self::EdgeWeight) -> Result, + { + contract_stable( + self, + IndexSet::from_iter(nodes), + weight, + Some(weight_combo_fn), + ) + .map_err(ContractSimpleError::MergeError) + } +} + +fn merge_duplicates(xs: Vec<(K, V)>, mut merge_fn: F) -> Result, E> +where + K: Hash + Eq, + F: FnMut(&V, &V) -> Result, +{ + let mut kvs = DictMap::with_capacity(xs.len()); + for (k, v) in xs { + match kvs.entry(k) { + Occupied(entry) => { + *entry.into_mut() = merge_fn(&v, entry.get())?; + } + Vacant(entry) => { + entry.insert(v); + } + } + } + Ok(kvs.into_iter().collect::>()) +} + +fn contract_stable( + graph: &mut G, + mut nodes: IndexSet, + weight: G::NodeWeight, + weight_combo_fn: Option, +) -> Result +where + G: GraphProp + NodeRemovable + Build + Visitable, + for<'b> &'b G: + GraphBase + Data + IntoEdgesDirected, + G::NodeId: Ord + Hash, + G::EdgeWeight: Clone, + F: FnMut(&G::EdgeWeight, &G::EdgeWeight) -> Result, +{ + let node_index = graph.add_node(weight); + + // Sanitize new node index from user input. + nodes.swap_remove(&node_index); + + // Connect old node edges to the replacement. + add_edges(graph, node_index, &nodes, weight_combo_fn).unwrap(); + + // Remove nodes that have been replaced. + for index in nodes { + graph.remove_node(index); + } + + Ok(node_index) +} + +fn can_contract(graph: G, nodes: &IndexSet) -> bool +where + G: Data + Visitable + IntoEdgesDirected, + G::NodeId: Eq + Hash, +{ + // Start with successors of `nodes` that aren't in `nodes` itself. + let visit_next: Vec = nodes + .iter() + .flat_map(|n| graph.edges(*n)) + .filter_map(|edge| { + let target_node = edge.target(); + if !nodes.contains(&target_node) { + Some(target_node) + } else { + None + } + }) + .collect(); + + // Now, if we can reach any of `nodes`, there exists a path from `nodes` + // back to `nodes` of length > 1, meaning contraction is disallowed. + let mut dfs = Dfs::from_parts(visit_next, graph.visit_map()); + while let Some(node) = dfs.next(graph) { + if nodes.contains(&node) { + // we found a path back to `nodes` + return false; + } + } + true +} + +// Helper type for specifying `NoCallback::None` at callsites of `contract`. +type NoCallback = Option Result>; + +fn add_edges( + graph: &mut G, + new_node: G::NodeId, + nodes: &IndexSet, + mut weight_combo_fn: Option, +) -> Result<(), E> +where + G: GraphProp + Build + Visitable, + for<'b> &'b G: + GraphBase + Data + IntoEdgesDirected, + G::NodeId: Ord + Hash, + G::EdgeWeight: Clone, + F: FnMut(&G::EdgeWeight, &G::EdgeWeight) -> Result, +{ + // Determine and add edges for new node. + { + // Note: even when the graph is undirected, we used edges_directed because + // it gives us a consistent endpoint order. + let mut incoming_edges: Vec<(G::NodeId, G::EdgeWeight)> = nodes + .iter() + .flat_map(|i| graph.edges_directed(*i, Direction::Incoming)) + .filter_map(|edge| { + let pred = edge.source(); + if !nodes.contains(&pred) { + Some((pred, edge.weight().clone())) + } else { + None + } + }) + .collect(); + + if let Some(merge_fn) = &mut weight_combo_fn { + incoming_edges = merge_duplicates(incoming_edges, merge_fn)?; + } + + for (source, weight) in incoming_edges.into_iter() { + graph.add_edge(source, new_node, weight); + } + } + + if graph.is_directed() { + let mut outgoing_edges: Vec<(G::NodeId, G::EdgeWeight)> = nodes + .iter() + .flat_map(|&i| graph.edges_directed(i, Direction::Outgoing)) + .filter_map(|edge| { + let succ = edge.target(); + if !nodes.contains(&succ) { + Some((succ, edge.weight().clone())) + } else { + None + } + }) + .collect(); + + if let Some(merge_fn) = &mut weight_combo_fn { + outgoing_edges = merge_duplicates(outgoing_edges, merge_fn)?; + } + + for (target, weight) in outgoing_edges.into_iter() { + graph.add_edge(new_node, target, weight); + } + } + + Ok(()) +} diff --git a/rustworkx-core/src/graph_ext/mod.rs b/rustworkx-core/src/graph_ext/mod.rs new file mode 100644 index 000000000..256d6ac0a --- /dev/null +++ b/rustworkx-core/src/graph_ext/mod.rs @@ -0,0 +1,132 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +//! This module defines traits that extend PetGraph's graph +//! data structures. +//! +//! The `-Directed` and `-Undirected` trait variants are implemented as +//! applicable for directed and undirected graph types. For example, only +//! directed graph types are concerned with cycle checking and corresponding +//! error handling, so these traits provide applicable parameters and return +//! types to account for this. +//! +//! ### Node Contraction +//! +//! There are four traits related to node contraction available for different +//! graphs / configurations, including: +//! +//! - [`ContractNodesDirected`] +//! - [`ContractNodesSimpleDirected`] +//! - [`ContractNodesUndirected`] +//! - [`ContractNodesSimpleUndirected`] +//! +//! Of these, the `ContractNodesSimple-` traits provide a +//! `contract_nodes_simple` method for applicable graph types, which performs +//! node contraction without introducing parallel edges between nodes (edges +//! between any two given nodes are merged via the method's merge function). +//! These traits can be used for node contraction within simple graphs to +//! preserve this property, or on multi-graphs to ensure that the contraction +//! does not introduce additional parallel edges. +//! +//! The other `ContractNodes-` traits provide a `contract_nodes` method, which +//! happily introduces parallel edges when multiple nodes in the contraction +//! have an incoming edge from the same source node or when multiple nodes in +//! the contraction have an outgoing edge to the same target node. +//! +//! ### Multi-graph Extensions +//! +//! These traits provide additional helper methods for use with multi-graphs, +//! e.g. [`HasParallelEdgesDirected`]. +//! +//! ### Graph Extension Trait Implementations +//! +//! The following table lists the traits that are currently implemented for +//! each graph type: +//! +//! | | Graph | StableGraph | GraphMap | MatrixGraph | Csr | List | +//! | ----------------------------- | :---: | :---------: | :------: | :---------: | :---: | :---: | +//! | ContractNodesDirected | | x | x | | | | +//! | ContractNodesSimpleDirected | | x | x | | | | +//! | ContractNodesUndirected | | x | x | | | | +//! | ContractNodesSimpleUndirected | | x | x | | | | +//! | HasParallelEdgesDirected | x | x | x | x | x | x | +//! | HasParallelEdgesUndirected | x | x | x | x | x | x | +//! | NodeRemovable | x | x | x | x | | | + +use petgraph::graph::IndexType; +use petgraph::graphmap::{GraphMap, NodeTrait}; +use petgraph::matrix_graph::{MatrixGraph, Nullable}; +use petgraph::stable_graph::StableGraph; +use petgraph::visit::{Data, IntoNodeIdentifiers}; +use petgraph::{EdgeType, Graph}; + +pub mod contraction; +pub mod multigraph; + +pub use contraction::{ + ContractNodesDirected, ContractNodesSimpleDirected, ContractNodesSimpleUndirected, + ContractNodesUndirected, +}; +pub use multigraph::{HasParallelEdgesDirected, HasParallelEdgesUndirected}; + +/// A graph whose nodes may be removed. +pub trait NodeRemovable: Data { + type Output; + fn remove_node(&mut self, node: Self::NodeId) -> Self::Output; +} + +impl NodeRemovable for StableGraph +where + Ty: EdgeType, + Ix: IndexType, +{ + type Output = Option; + fn remove_node(&mut self, node: Self::NodeId) -> Option { + self.remove_node(node) + } +} + +impl NodeRemovable for Graph +where + Ty: EdgeType, + Ix: IndexType, +{ + type Output = Option; + fn remove_node(&mut self, node: Self::NodeId) -> Option { + self.remove_node(node) + } +} + +impl NodeRemovable for GraphMap +where + N: NodeTrait, + Ty: EdgeType, +{ + type Output = bool; + fn remove_node(&mut self, node: Self::NodeId) -> Self::Output { + self.remove_node(node) + } +} + +impl, Ix: IndexType> NodeRemovable + for MatrixGraph +{ + type Output = Option; + fn remove_node(&mut self, node: Self::NodeId) -> Self::Output { + for n in self.node_identifiers() { + if node == n { + return Some(self.remove_node(node)); + } + } + None + } +} diff --git a/rustworkx-core/src/graph_ext/multigraph.rs b/rustworkx-core/src/graph_ext/multigraph.rs new file mode 100644 index 000000000..ac81037be --- /dev/null +++ b/rustworkx-core/src/graph_ext/multigraph.rs @@ -0,0 +1,66 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +//! This module defines graph traits for multi-graphs. + +use hashbrown::HashSet; +use petgraph::visit::{EdgeCount, EdgeRef, GraphBase, GraphProp, IntoEdgeReferences, Visitable}; +use petgraph::{Directed, Undirected}; +use std::hash::Hash; + +pub trait HasParallelEdgesUndirected: GraphBase { + fn has_parallel_edges(&self) -> bool; +} + +impl HasParallelEdgesUndirected for G +where + G: GraphProp + Visitable + EdgeCount, + G::NodeId: Eq + Hash, + for<'b> &'b G: GraphBase + IntoEdgeReferences, +{ + fn has_parallel_edges(&self) -> bool { + let mut edges: HashSet<[Self::NodeId; 2]> = HashSet::with_capacity(2 * self.edge_count()); + for edge in self.edge_references() { + let endpoints = [edge.source(), edge.target()]; + let endpoints_rev = [edge.target(), edge.source()]; + if edges.contains(&endpoints) || edges.contains(&endpoints_rev) { + return true; + } + edges.insert(endpoints); + edges.insert(endpoints_rev); + } + false + } +} + +pub trait HasParallelEdgesDirected: GraphBase { + fn has_parallel_edges(&self) -> bool; +} + +impl HasParallelEdgesDirected for G +where + G: GraphProp + Visitable + EdgeCount, + G::NodeId: Eq + Hash, + for<'b> &'b G: GraphBase + IntoEdgeReferences, +{ + fn has_parallel_edges(&self) -> bool { + let mut edges: HashSet<[Self::NodeId; 2]> = HashSet::with_capacity(self.edge_count()); + for edge in self.edge_references() { + let endpoints = [edge.source(), edge.target()]; + if edges.contains(&endpoints) { + return true; + } + edges.insert(endpoints); + } + false + } +} diff --git a/rustworkx-core/src/lib.rs b/rustworkx-core/src/lib.rs index f3d08309c..6be610425 100644 --- a/rustworkx-core/src/lib.rs +++ b/rustworkx-core/src/lib.rs @@ -44,7 +44,7 @@ //! //! ## Algorithm Modules //! -//! The crate is organized into +//! The crate provides the following graph algorithm modules //! //! * [`centrality`](./centrality/index.html) //! * [`connectivity`](./connectivity/index.html) @@ -54,6 +54,27 @@ //! * [`traversal`](./traversal/index.html) //! * [`generators`](./generators/index.html) //! +//! ## Graph Extensions +//! +//! The crate also provides traits which extend `petgraph` types with +//! additional methods, when imported. +//! +//! For example, the +//! [`contract_nodes`][graph_ext::ContractNodesDirected::contract_nodes] method +//! becomes available for applicable graph types when the following trait is +//! imported: +//! +//! ```rust +//! use petgraph::prelude::*; +//! use rustworkx_core::graph_ext::ContractNodesDirected; +//! +//! let mut dag: StableDiGraph = StableDiGraph::default(); +//! let m = dag.contract_nodes([], 'm', true).unwrap(); +//! ``` +//! +//! See the documentation of [`graph_ext`] for a full listing of the available +//! extensions and their compatibility with petgraph types. +//! //! ## Release Notes //! //! The release notes for rustworkx-core are included as part of the rustworkx @@ -69,6 +90,7 @@ use std::convert::Infallible; /// to use needs a callback that returns [`Result`] but in your case no /// error can happen. pub type Result = core::result::Result; +pub mod err; pub mod bipartite_coloring; /// Module for centrality algorithms. @@ -78,6 +100,7 @@ pub mod coloring; pub mod connectivity; pub mod dag_algo; pub mod generators; +pub mod graph_ext; pub mod line_graph; /// Module for maximum weight matching algorithms. pub mod max_weight_matching; diff --git a/rustworkx-core/tests/graph_ext/contraction.rs b/rustworkx-core/tests/graph_ext/contraction.rs new file mode 100644 index 000000000..856beea19 --- /dev/null +++ b/rustworkx-core/tests/graph_ext/contraction.rs @@ -0,0 +1,403 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +use ahash::HashSet; +use hashbrown::HashMap; +use petgraph::data::Build; +use petgraph::visit::{ + Data, EdgeCount, EdgeRef, GraphBase, IntoEdgeReferences, IntoNodeIdentifiers, +}; +use rustworkx_core::err::ContractError; +use rustworkx_core::graph_ext::*; +use std::convert::Infallible; +use std::fmt::Debug; +use std::hash::Hash; + +mod graph_map { + use petgraph::prelude::*; + type G = DiGraphMap; + + common_test!(test_cycle_check_enabled, G); + common_test!(test_cycle_check_disabled, G); + common_test!(test_empty_nodes, G); + common_test!(test_unknown_nodes, G); + common_test!(test_cycle_path_len_gt_1, G); + common_test!(test_multiple_paths_would_cycle, G); + common_test!(test_replace_node_no_neighbors, G); + common_test!(test_keep_edges_multigraph, G); + common_test!(test_collapse_parallel_edges, G); + common_test!(test_replace_all_nodes, G); +} + +mod stable_graph { + use petgraph::prelude::*; + type G = StableDiGraph; + + common_test!(test_cycle_check_enabled, G); + common_test!(test_cycle_check_disabled, G); + common_test!(test_empty_nodes, G); + common_test!(test_unknown_nodes, G); + common_test!(test_cycle_path_len_gt_1, G); + common_test!(test_multiple_paths_would_cycle, G); + common_test!(test_replace_node_no_neighbors, G); + common_test!(test_keep_edges_multigraph, G); + common_test!(test_collapse_parallel_edges, G); + common_test!(test_replace_all_nodes, G); +} + +/// ┌─┐ ┌─┐ +/// ┌─┤a│ ┌─────────┤m│ +/// │ └─┘ │ └▲┘ +/// ┌▼┐ ┌▼┐ │ +/// │b│ ───► │b├─────────┘ +/// └┬┘ └─┘ +/// │ ┌─┐ +/// └─►┤c│ +/// └─┘ +pub fn test_cycle_check_enabled() +where + G: Default + + Data + + Build + + ContractNodesDirected, + G::NodeId: Debug, +{ + let mut dag = G::default(); + let a = dag.add_node('a'); + let b = dag.add_node('b'); + let c = dag.add_node('c'); + dag.add_edge(a, b, 1); + dag.add_edge(b, c, 2); + let result = dag.contract_nodes([a, c], 'm', true); + match result.expect_err("Cycle should cause return error.") { + ContractError::DAGWouldCycle => (), + } +} + +fn test_cycle_check_disabled() +where + G: Default + + Data + + Build + + ContractNodesDirected, + G::NodeId: Debug, +{ + let mut dag = G::default(); + let a = dag.add_node('a'); + let b = dag.add_node('b'); + let c = dag.add_node('c'); + dag.add_edge(a, b, 1); + dag.add_edge(b, c, 2); + let result = dag.contract_nodes([a, c], 'm', false); + result.expect("No error should be raised for a cycle when cycle check is disabled."); +} + +fn test_empty_nodes() +where + G: Default + + Data + + Build + + ContractNodesDirected, + G::NodeId: Debug, +{ + let mut dag = G::default(); + dag.contract_nodes([], 'm', false).unwrap(); + assert_eq!(dag.node_count(), 1); +} + +fn test_unknown_nodes() +where + G: Default + + Data + + Build + + ContractNodesDirected + + NodeRemovable, + G::NodeId: Debug + Copy, +{ + let mut dag = G::default(); + + // A -> B -> C + let a = dag.add_node('a'); + let b = dag.add_node('b'); + let c = dag.add_node('c'); + + dag.add_edge(a, b, 1); + dag.add_edge(b, c, 2); + + // Leave just A. + dag.remove_node(b); + dag.remove_node(c); + + // Replacement should ignore the unknown nodes, making + // the behavior equivalent to adding a new node in + // this case. + dag.contract_nodes([b, c], 'm', false).unwrap(); + assert_eq!(dag.node_count(), 2); +} + +/// ┌─┐ ┌─┐ +/// ┌4─┤a├─1┐ │m├──1───┐ +/// │ └─┘ │ └▲┘ │ +/// ┌▼┐ ┌▼┐ │ ┌▼┐ +/// │d│ │b│ ───► │ │b│ +/// └▲┘ └┬┘ │ └┬┘ +/// │ ┌─┐ 2 │ ┌─┐ 2 +/// └3─┤c│◄─┘ └3─┤c│◄─┘ +/// └─┘ └─┘ +fn test_cycle_path_len_gt_1() +where + G: Default + + Data + + Build + + ContractNodesDirected + + NodeRemovable, + G::NodeId: Debug + Copy, +{ + let mut dag = G::default(); + let a = dag.add_node('a'); + let b = dag.add_node('b'); + let c = dag.add_node('c'); + let d = dag.add_node('d'); + dag.add_edge(a, b, 1); + dag.add_edge(b, c, 2); + dag.add_edge(c, d, 3); + dag.add_edge(a, d, 4); + + dag.contract_nodes([a, d], 'm', true) + .expect_err("Cycle should be detected."); +} + +/// ┌─┐ ┌─┐ ┌─┐ ┌─┐ +/// ┌3─┤c│ │e├─5┐ ┌──┤c│ │e├──┐ +/// │ └▲┘ └▲┘ │ │ └▲┘ └▲┘ │ +/// ┌▼┐ 2 ┌─┐ 4 ┌▼┐ │ 2 ┌─┐ 4 │ +/// │d│ └──┤b├──┘ │f│ ───► │ └──┤b├──┘ │ +/// └─┘ └▲┘ └─┘ 3 └▲┘ 5 +/// 1 │ 1 │ +/// ┌┴┐ │ ┌┴┐ │ +/// │a│ └─────►│m│◄─────┘ +/// └─┘ └─┘ +fn test_multiple_paths_would_cycle() +where + G: Default + + Data + + Build + + ContractNodesDirected, + for<'b> &'b G: GraphBase + + Data + + IntoEdgeReferences + + IntoNodeIdentifiers, + G::NodeId: Eq + Hash + Debug + Copy, +{ + let mut dag = G::default(); + let a = dag.add_node('a'); + let b = dag.add_node('b'); + let c = dag.add_node('c'); + let d = dag.add_node('d'); + let e = dag.add_node('e'); + let f = dag.add_node('f'); + + dag.add_edge(a, b, 1); + dag.add_edge(b, c, 2); + dag.add_edge(c, d, 3); + dag.add_edge(b, e, 4); + dag.add_edge(e, f, 5); + + let result = dag.contract_nodes([a, d, f], 'm', true); + match result.expect_err("Cycle should cause return error.") { + ContractError::DAGWouldCycle => (), + } + + // Proceed, ignoring cycles. + dag.contract_nodes([a, d, f], 'm', false) + .expect("Contraction should be allowed without cycle check."); + + let edge_refs: Vec<_> = dag.edge_references().collect(); + assert_eq!(edge_refs.len(), 5, "Missing expected edge!"); + + // Build up a map of node weight to node ID and ensure + // IDs cross reference as expected between edges. + let mut seen = HashMap::new(); + for edge_ref in edge_refs.into_iter() { + match (edge_ref.source(), edge_ref.target(), edge_ref.weight()) { + (m, b, 1) => { + assert_eq!(*seen.entry('m').or_insert(m), m); + assert_eq!(*seen.entry('b').or_insert(b), b); + } + (b, c, 2) => { + assert_eq!(*seen.entry('b').or_insert(b), b); + assert_eq!(*seen.entry('c').or_insert(c), c); + } + (c, m, 3) => { + assert_eq!(*seen.entry('c').or_insert(c), c); + assert_eq!(*seen.entry('m').or_insert(m), m); + } + (b, e, 4) => { + assert_eq!(*seen.entry('b').or_insert(b), b); + assert_eq!(*seen.entry('e').or_insert(e), e); + } + (e, m, 5) => { + assert_eq!(*seen.entry('e').or_insert(e), e); + assert_eq!(*seen.entry('m').or_insert(m), m); + } + (_, _, w) => panic!("Unexpected edge weight: {}", w), + } + } + + assert_eq!(seen.len(), 4, "Missing expected node!"); +} + +fn test_replace_node_no_neighbors() +where + G: Default + + Data + + Build + + ContractNodesDirected, + G::NodeId: Debug, +{ + let mut dag = G::default(); + let a = dag.add_node('a'); + dag.contract_nodes([a], 'm', true).unwrap(); + assert_eq!(dag.node_count(), 1); +} + +/// ┌─┐ ┌─┐ +/// ┌─┤a│◄┐ ┌─┤a│◄┐ +/// │ └─┘ │ │ └─┘ │ +/// 1 2 ──► 1 2 +/// ┌▼┐ ┌┴┐ │ ┌─┐ │ +/// │b│ │c│ └►│m├─┘ +/// └─┘ └─┘ └─┘ +fn test_keep_edges_multigraph() +where + G: Default + + Data + + Build + + ContractNodesDirected, + for<'b> &'b G: GraphBase + + Data + + IntoEdgeReferences + + IntoNodeIdentifiers, + G::NodeId: Eq + Hash + Debug + Copy, +{ + let mut dag = G::default(); + let a = dag.add_node('a'); + let b = dag.add_node('b'); + let c = dag.add_node('c'); + + dag.add_edge(a, b, 1); + dag.add_edge(c, a, 2); + + let result = dag.contract_nodes([b, c], 'm', true); + match result.expect_err("Cycle should cause return error.") { + ContractError::DAGWouldCycle => (), + } + + // Proceed, ignoring cycles. + let m = dag + .contract_nodes([b, c], 'm', false) + .expect("Contraction should be allowed without cycle check."); + + assert_eq!(dag.node_count(), 2); + + let edges: HashSet<_> = dag + .edge_references() + .map(|e| (e.source(), e.target(), *e.weight())) + .collect(); + let expected = HashSet::from_iter([(a, m, 1), (m, a, 2)]); + assert_eq!(edges, expected); +} + +/// Parallel edges are collapsed using weight_combo_fn. +/// ┌─┐ ┌─┐ +/// │a│ │a│ +/// ┌──┴┬┴──┐ └┬┘ +/// 1 2 3 6 +/// ┌▼┐ ┌▼┐ ┌▼┐ ┌▼┐ +/// │b│ │c│ │d│ ──► │m│ +/// └┬┘ └┬┘ └┬┘ └┬┘ +/// 4 5 6 15 +/// └──►▼◄──┘ ┌▼┐ +/// │e│ │e│ +/// └─┘ └─┘ +fn test_collapse_parallel_edges() +where + G: Default + Data + Build + ContractNodesSimpleDirected, + for<'b> &'b G: GraphBase + + Data + + IntoEdgeReferences + + IntoNodeIdentifiers, + G::NodeId: Eq + Hash + Debug + Copy, +{ + let mut dag = G::default(); + let a = dag.add_node('a'); + let b = dag.add_node('b'); + let c = dag.add_node('c'); + let d = dag.add_node('d'); + let e = dag.add_node('e'); + + dag.add_edge(a, b, 1); + dag.add_edge(a, c, 2); + dag.add_edge(a, d, 3); + dag.add_edge(b, e, 4); + dag.add_edge(c, e, 5); + dag.add_edge(d, e, 6); + + let m = dag + .contract_nodes_simple([b, c, d], 'm', true, |w1, w2| { + Ok::(w1 + w2) + }) + .unwrap(); + + assert_eq!(dag.node_count(), 3); + + let edges: HashSet<_> = dag + .edge_references() + .map(|e| (e.source(), e.target(), *e.weight())) + .collect(); + let expected = HashSet::from_iter([(a, m, 6), (m, e, 15)]); + assert_eq!(edges, expected); +} + +fn test_replace_all_nodes() +where + G: Default + + Data + + Build + + ContractNodesDirected + + EdgeCount, + for<'b> &'b G: GraphBase + + Data + + IntoEdgeReferences + + IntoNodeIdentifiers, + G::NodeId: Eq + Hash + Debug + Copy, +{ + let mut dag = G::default(); + let a = dag.add_node('a'); + let b = dag.add_node('b'); + let c = dag.add_node('c'); + let d = dag.add_node('d'); + let e = dag.add_node('e'); + + dag.add_edge(a, b, 1); + dag.add_edge(a, c, 2); + dag.add_edge(a, d, 3); + dag.add_edge(b, e, 4); + dag.add_edge(c, e, 5); + dag.add_edge(d, e, 6); + + dag.contract_nodes(dag.node_identifiers().collect::>(), 'm', true) + .unwrap(); + + assert_eq!(dag.node_count(), 1); + assert_eq!(dag.edge_count(), 0); +} diff --git a/rustworkx-core/tests/graph_ext/main.rs b/rustworkx-core/tests/graph_ext/main.rs new file mode 100644 index 000000000..c0d79f1fc --- /dev/null +++ b/rustworkx-core/tests/graph_ext/main.rs @@ -0,0 +1,22 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +macro_rules! common_test { + ($func:ident, $graph_ty:ty) => { + #[test] + fn $func() { + super::$func::<$graph_ty>(); + } + }; +} + +mod contraction; diff --git a/src/digraph.rs b/src/digraph.rs index d3384af5d..c31638188 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -22,9 +22,9 @@ use std::io::{BufReader, BufWriter}; use std::str; use hashbrown::{HashMap, HashSet}; -use indexmap::IndexSet; use rustworkx_core::dictmap::*; +use rustworkx_core::graph_ext::*; use smallvec::SmallVec; @@ -44,6 +44,7 @@ use petgraph::algo; use petgraph::graph::{EdgeIndex, NodeIndex}; use petgraph::prelude::*; +use crate::RxPyResult; use petgraph::visit::{ EdgeIndexable, GraphBase, IntoEdgeReferences, IntoNodeReferences, NodeCount, NodeFiltered, NodeIndexable, Visitable, @@ -54,8 +55,8 @@ use super::iterators::{ EdgeIndexMap, EdgeIndices, EdgeList, NodeIndices, NodeMap, WeightedEdgeList, }; use super::{ - find_node_by_weight, merge_duplicates, weight_callable, DAGHasCycle, DAGWouldCycle, IsNan, - NoEdgeBetweenNodes, NoSuitableNeighbors, NodesRemoved, StablePyGraph, + find_node_by_weight, weight_callable, DAGHasCycle, DAGWouldCycle, IsNan, NoEdgeBetweenNodes, + NoSuitableNeighbors, NodesRemoved, StablePyGraph, }; use super::dag_algo::is_directed_acyclic_graph; @@ -480,15 +481,7 @@ impl PyDiGraph { if !self.multigraph { return false; } - let mut edges: HashSet<[NodeIndex; 2]> = HashSet::with_capacity(self.graph.edge_count()); - for edge in self.graph.edge_references() { - let endpoints = [edge.source(), edge.target()]; - if edges.contains(&endpoints) { - return true; - } - edges.insert(endpoints); - } - false + self.graph.has_parallel_edges() } /// Clear all nodes and edges @@ -2665,98 +2658,26 @@ impl PyDiGraph { obj: PyObject, check_cycle: Option, weight_combo_fn: Option, - ) -> PyResult { - let can_contract = |nodes: &IndexSet| { - // Start with successors of `nodes` that aren't in `nodes` itself. - let visit_next: Vec = nodes - .iter() - .flat_map(|n| self.graph.edges(*n)) - .filter_map(|edge| { - let target_node = edge.target(); - if !nodes.contains(&target_node) { - Some(target_node) - } else { - None - } - }) - .collect(); - - // Now, if we can reach any of `nodes`, there exists a path from `nodes` - // back to `nodes` of length > 1, meaning contraction is disallowed. - let mut dfs = Dfs::from_parts(visit_next, self.graph.visit_map()); - while let Some(node) = dfs.next(&self.graph) { - if nodes.contains(&node) { - // we found a path back to `nodes` - return false; - } + ) -> RxPyResult { + let nodes = nodes.into_iter().map(|i| NodeIndex::new(i)); + let check_cycle = check_cycle.unwrap_or(self.check_cycle); + let res = match (weight_combo_fn, &self.multigraph) { + (Some(user_callback), _) => { + self.graph + .contract_nodes_simple(nodes, obj, check_cycle, |w1, w2| { + user_callback.call1(py, (w1, w2)) + })? } - true + (None, false) => { + // By default, just take first edge. + self.graph + .contract_nodes_simple(nodes, obj, check_cycle, move |w1, _| { + Ok::<_, PyErr>(w1.clone_ref(py)) + })? + } + (None, true) => self.graph.contract_nodes(nodes, obj, check_cycle)?, }; - - let mut indices_to_remove: IndexSet = - nodes.into_iter().map(NodeIndex::new).collect(); - - if check_cycle.unwrap_or(self.check_cycle) && !can_contract(&indices_to_remove) { - return Err(DAGWouldCycle::new_err("Contraction would create cycle(s)")); - } - - // Create new node. - let node_index = self.graph.add_node(obj); - - // Sanitize new node index from user input. - indices_to_remove.swap_remove(&node_index); - - // Determine edges for new node. - let mut incoming_edges: Vec<_> = indices_to_remove - .iter() - .flat_map(|&i| self.graph.edges_directed(i, Direction::Incoming)) - .filter_map(|edge| { - let pred = edge.source(); - if !indices_to_remove.contains(&pred) { - Some((pred, edge.weight().clone_ref(py))) - } else { - None - } - }) - .collect(); - - let mut outgoing_edges: Vec<_> = indices_to_remove - .iter() - .flat_map(|&i| self.graph.edges_directed(i, Direction::Outgoing)) - .filter_map(|edge| { - let succ = edge.target(); - if !indices_to_remove.contains(&succ) { - Some((succ, edge.weight().clone_ref(py))) - } else { - None - } - }) - .collect(); - - // Remove nodes that will be replaced. - for index in indices_to_remove { - self.remove_node(index.index())?; - } - - // If `weight_combo_fn` was specified, merge edges according - // to that function, even if this is a multigraph. If unspecified, - // defer parallel edge handling to `add_edge_no_cycle_check`. - if let Some(merge_fn) = weight_combo_fn { - let f = |w1: &Py<_>, w2: &Py<_>| merge_fn.call1(py, (w1, w2)); - - incoming_edges = merge_duplicates(incoming_edges, f)?; - outgoing_edges = merge_duplicates(outgoing_edges, f)?; - } - - for (source, weight) in incoming_edges { - self.add_edge_no_cycle_check(source, node_index, weight); - } - - for (target, weight) in outgoing_edges { - self.add_edge_no_cycle_check(node_index, target, weight); - } - - Ok(node_index.index()) + Ok(res.index()) } /// Return a new PyDiGraph object for a subgraph of this graph diff --git a/src/graph.rs b/src/graph.rs index 9f49e3f2a..06ba8cc4a 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -20,8 +20,8 @@ use std::io::{BufReader, BufWriter}; use std::str; use hashbrown::{HashMap, HashSet}; -use indexmap::IndexSet; use rustworkx_core::dictmap::*; +use rustworkx_core::graph_ext::*; use pyo3::exceptions::PyIndexError; use pyo3::gc::PyVisit; @@ -40,10 +40,10 @@ use crate::iterators::NodeMap; use super::dot_utils::build_dot; use super::iterators::{EdgeIndexMap, EdgeIndices, EdgeList, NodeIndices, WeightedEdgeList}; use super::{ - find_node_by_weight, merge_duplicates, weight_callable, IsNan, NoEdgeBetweenNodes, - NodesRemoved, StablePyGraph, + find_node_by_weight, weight_callable, IsNan, NoEdgeBetweenNodes, NodesRemoved, StablePyGraph, }; +use crate::RxPyResult; use petgraph::algo; use petgraph::graph::{EdgeIndex, NodeIndex}; use petgraph::prelude::*; @@ -350,18 +350,7 @@ impl PyGraph { if !self.multigraph { return false; } - let mut edges: HashSet<[NodeIndex; 2]> = - HashSet::with_capacity(2 * self.graph.edge_count()); - for edge in self.graph.edge_references() { - let endpoints = [edge.source(), edge.target()]; - let endpoints_rev = [edge.target(), edge.source()]; - if edges.contains(&endpoints) || edges.contains(&endpoints_rev) { - return true; - } - edges.insert(endpoints); - edges.insert(endpoints_rev); - } - false + self.graph.has_parallel_edges() } /// Clears all nodes and edges @@ -1819,50 +1808,22 @@ impl PyGraph { nodes: Vec, obj: PyObject, weight_combo_fn: Option, - ) -> PyResult { - let mut indices_to_remove: IndexSet = - nodes.into_iter().map(NodeIndex::new).collect(); - - // Create new node. - let node_index = self.graph.add_node(obj); - - // Sanitize new node index from user input. - indices_to_remove.swap_remove(&node_index); - - // Determine edges for new node. - // note: `edges_directed` returns all edges with `i` as - // an endpoint. `Direction::Incoming` configures `edge.target()` - // to return `i` and `edge.source()` to return the other node. - let mut edges: Vec<_> = indices_to_remove - .iter() - .flat_map(|&i| self.graph.edges_directed(i, Direction::Incoming)) - .filter_map(|edge| { - let pred = edge.source(); - if !indices_to_remove.contains(&pred) { - Some((pred, edge.weight().clone_ref(py))) - } else { - None - } - }) - .collect(); - - // Remove nodes that will be replaced. - for index in indices_to_remove { - self.remove_node(index.index())?; - } - - // If `weight_combo_fn` was specified, merge edges according - // to that function, even if this is a multigraph. If unspecified, - // defer parallel edge handling to `add_edge`. - if let Some(merge_fn) = weight_combo_fn { - edges = merge_duplicates(edges, |w1, w2| merge_fn.call1(py, (w1, w2)))?; - } - - for (source, weight) in edges { - self.add_edge(source.index(), node_index.index(), weight)?; - } - - Ok(node_index.index()) + ) -> RxPyResult { + let nodes = nodes.into_iter().map(|i| NodeIndex::new(i)); + let res = match (weight_combo_fn, &self.multigraph) { + (Some(user_callback), _) => { + self.graph + .contract_nodes_simple(nodes, obj, |w1, w2| user_callback.call1(py, (w1, w2)))? + } + (None, false) => { + // By default, just take first edge. + self.graph.contract_nodes_simple(nodes, obj, move |w1, _| { + Ok::<_, PyErr>(w1.clone_ref(py)) + })? + } + (None, true) => self.graph.contract_nodes(nodes, obj), + }; + Ok(res.index()) } /// Return a new PyGraph object for a subgraph of this graph diff --git a/src/lib.rs b/src/lib.rs index cde6034fc..ce0843b8e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -65,7 +65,6 @@ use tree::*; use union::*; use hashbrown::HashMap; -use indexmap::map::Entry::{Occupied, Vacant}; use numpy::Complex64; use pyo3::create_exception; @@ -86,9 +85,73 @@ use petgraph::visit::{ use petgraph::EdgeType; use std::convert::TryFrom; -use std::hash::Hash; use rustworkx_core::dictmap::*; +use rustworkx_core::err::{ContractError, ContractSimpleError}; + +/// An ergonomic error type used to map Rustworkx core errors to +/// [PyErr] automatically, via [From::from]. +/// +/// It is constructable from both [PyErr] and core errors and implements +/// [IntoPy], so it can be returned directly from PyO3 methods and +/// functions. Additionally, a [PyErr] can be constructed from this +/// type, since it's just a wrapper around one, so you can even go +/// from a core error => [RxPyErr] => [PyErr]. +/// +/// # Usage +/// When calling Rustworkx core functions from PyO3 code, use +/// [RxPyResult] as the return type of the calling function and use +/// the `?` operator to unwrap the result with error propagation. +/// Since Rust automatically applies [From::from] to unwrapped error +/// values, a core error will be automatically converted to a +/// Python-friendly error and stored in [RxPyErr], assuming you've +/// added an implementation of [From] for it below. The standard +/// [PyErr] type will be converted to [RxPyErr] using the same +/// mechanism, allowing Rustworkx core and PyO3 API usage to be +/// intermixed within the same calling function. +pub struct RxPyErr { + pyerr: PyErr, +} + +/// Type alias for a [Result] with error type [RxPyErr]. +pub type RxPyResult = Result; + +fn map_dag_would_cycle(value: E) -> PyErr { + DAGWouldCycle::new_err(format!("{:?}", value)) +} + +impl From for RxPyErr { + fn from(value: ContractError) -> Self { + RxPyErr { + pyerr: match value { + ContractError::DAGWouldCycle => map_dag_would_cycle(value), + }, + } + } +} + +impl From> for RxPyErr { + fn from(value: ContractSimpleError) -> Self { + RxPyErr { + pyerr: match value { + ContractSimpleError::DAGWouldCycle => map_dag_would_cycle(value), + ContractSimpleError::MergeError(e) => e, + }, + } + } +} + +impl IntoPy for RxPyErr { + fn into_py(self, py: Python<'_>) -> PyObject { + self.pyerr.into_value(py).into() + } +} + +impl From for PyErr { + fn from(value: RxPyErr) -> Self { + value.pyerr + } +} trait IsNan { fn is_nan(&self) -> bool; @@ -289,25 +352,6 @@ fn find_node_by_weight( Ok(index) } -fn merge_duplicates(xs: Vec<(K, V)>, mut merge_fn: F) -> Result, E> -where - K: Hash + Eq, - F: FnMut(&V, &V) -> Result, -{ - let mut kvs = DictMap::with_capacity(xs.len()); - for (k, v) in xs { - match kvs.entry(k) { - Occupied(entry) => { - *entry.into_mut() = merge_fn(&v, entry.get())?; - } - Vacant(entry) => { - entry.insert(v); - } - } - } - Ok(kvs.into_iter().collect::>()) -} - // The provided node is invalid. create_exception!(rustworkx, InvalidNode, PyException); // Performing this operation would result in trying to add a cycle to a DAG.