diff --git a/rustworkx-core/src/generators/complete_graph.rs b/rustworkx-core/src/generators/complete_graph.rs new file mode 100644 index 000000000..0aa7753c1 --- /dev/null +++ b/rustworkx-core/src/generators/complete_graph.rs @@ -0,0 +1,201 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +use petgraph::data::{Build, Create}; +use petgraph::visit::{Data, GraphProp, NodeIndexable}; + +use super::utils::get_num_nodes; +use super::InvalidInputError; + +/// Generate a complete graph +/// +/// Arguments: +/// +/// * `num_nodes` - The number of nodes to create a complete graph for. Either this or +/// `weights must be specified. If both this and `weights are specified, weights +/// will take priorty and this argument will be ignored +/// * `weights` - A `Vec` of node weight objects. +/// * `default_node_weight` - A callable that will return the weight to use +/// for newly created nodes. This is ignored if `weights` is specified, +/// as the weights from that argument will be used instead. +/// * `default_edge_weight` - A callable that will return the weight object +/// to use for newly created edges. +/// +/// # Example +/// ```rust +/// use rustworkx_core::petgraph; +/// use rustworkx_core::generators::complete_graph; +/// use rustworkx_core::petgraph::visit::EdgeRef; +/// +/// let g: petgraph::graph::UnGraph<(), ()> = complete_graph( +/// Some(4), +/// None, +/// || {()}, +/// || {()}, +/// ).unwrap(); +/// assert_eq!( +/// vec![(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)], +/// g.edge_references() +/// .map(|edge| (edge.source().index(), edge.target().index())) +/// .collect::>(), +/// ) +/// ``` +pub fn complete_graph( + num_nodes: Option, + weights: Option>, + mut default_node_weight: F, + mut default_edge_weight: H, +) -> Result +where + G: Build + Create + Data + NodeIndexable + GraphProp, + F: FnMut() -> T, + H: FnMut() -> M, +{ + if weights.is_none() && num_nodes.is_none() { + return Err(InvalidInputError {}); + } + let node_len = get_num_nodes(&num_nodes, &weights); + let mut graph = G::with_capacity(node_len, node_len); + if node_len == 0 { + return Ok(graph); + } + + match weights { + Some(weights) => { + for weight in weights { + graph.add_node(weight); + } + } + None => { + for _ in 0..node_len { + graph.add_node(default_node_weight()); + } + } + }; + for i in 0..node_len - 1 { + for j in i + 1..node_len { + let node_i = graph.from_index(i); + let node_j = graph.from_index(j); + graph.add_edge(node_i, node_j, default_edge_weight()); + if graph.is_directed() { + graph.add_edge(node_j, node_i, default_edge_weight()); + } + } + } + Ok(graph) +} + +#[cfg(test)] +mod tests { + use crate::generators::complete_graph; + use crate::generators::InvalidInputError; + use crate::petgraph::graph::{DiGraph, NodeIndex, UnGraph}; + use crate::petgraph::visit::EdgeRef; + + #[test] + fn test_directed_complete_graph() { + let g: DiGraph<(), ()> = complete_graph(Some(10), None, || (), || ()).unwrap(); + assert_eq!(g.node_count(), 10); + assert_eq!(g.edge_count(), 90); + let mut elist = vec![]; + for i in 0..10 { + for j in i..10 { + if i != j { + elist.push((i, j)); + elist.push((j, i)); + } + } + } + assert_eq!( + elist, + g.edge_references() + .map(|edge| (edge.source().index(), edge.target().index())) + .collect::>(), + ); + } + + #[test] + fn test_directed_complete_graph_weights() { + let g: DiGraph = + complete_graph(None, Some(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), || 4, || ()).unwrap(); + assert_eq!(g.node_count(), 10); + assert_eq!(g.edge_count(), 90); + let mut elist = vec![]; + for i in 0..10 { + for j in i..10 { + if i != j { + elist.push((i, j)); + elist.push((j, i)); + } + } + assert_eq!(*g.node_weight(NodeIndex::new(i)).unwrap(), i); + } + assert_eq!( + elist, + g.edge_references() + .map(|edge| (edge.source().index(), edge.target().index())) + .collect::>(), + ); + } + + #[test] + fn test_compete_graph_error() { + match complete_graph::, (), _, _, ()>(None, None, || (), || ()) { + Ok(_) => panic!("Returned a non-error"), + Err(e) => assert_eq!(e, InvalidInputError), + }; + } + + #[test] + fn test_complete_graph() { + let g: UnGraph<(), ()> = complete_graph(Some(10), None, || (), || ()).unwrap(); + assert_eq!(g.node_count(), 10); + assert_eq!(g.edge_count(), 45); + let mut elist = vec![]; + for i in 0..10 { + for j in i..10 { + if i != j { + elist.push((i, j)); + } + } + } + assert_eq!( + elist, + g.edge_references() + .map(|edge| (edge.source().index(), edge.target().index())) + .collect::>(), + ); + } + + #[test] + fn test_complete_graph_weights() { + let g: UnGraph = + complete_graph(None, Some(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), || 4, || ()).unwrap(); + assert_eq!(g.node_count(), 10); + assert_eq!(g.edge_count(), 45); + let mut elist = vec![]; + for i in 0..10 { + for j in i..10 { + if i != j { + elist.push((i, j)); + } + } + assert_eq!(*g.node_weight(NodeIndex::new(i)).unwrap(), i); + } + assert_eq!( + elist, + g.edge_references() + .map(|edge| (edge.source().index(), edge.target().index())) + .collect::>(), + ); + } +} diff --git a/rustworkx-core/src/generators/mod.rs b/rustworkx-core/src/generators/mod.rs index 9ecc58f19..d6991c22a 100644 --- a/rustworkx-core/src/generators/mod.rs +++ b/rustworkx-core/src/generators/mod.rs @@ -14,6 +14,7 @@ mod barbell_graph; mod binomial_tree_graph; +mod complete_graph; mod cycle_graph; mod grid_graph; mod heavy_hex_graph; @@ -43,6 +44,7 @@ impl fmt::Display for InvalidInputError { pub use barbell_graph::barbell_graph; pub use binomial_tree_graph::binomial_tree_graph; +pub use complete_graph::complete_graph; pub use cycle_graph::cycle_graph; pub use grid_graph::grid_graph; pub use heavy_hex_graph::heavy_hex_graph; diff --git a/src/generators.rs b/src/generators.rs index f181d59f5..944256677 100644 --- a/src/generators.rs +++ b/src/generators.rs @@ -34,15 +34,6 @@ where left.zip(right) } -#[inline] -fn get_num_nodes(num_nodes: &Option, weights: &Option>) -> usize { - if weights.is_some() { - weights.as_ref().unwrap().len() - } else { - num_nodes.unwrap() - } -} - /// Generate a cycle graph /// /// :param int num_node: The number of nodes to generate the graph with. Node @@ -428,48 +419,7 @@ pub fn mesh_graph( weights: Option>, multigraph: bool, ) -> PyResult { - if weights.is_none() && num_nodes.is_none() { - return Err(PyIndexError::new_err( - "num_nodes and weights list not specified", - )); - } - let node_len = get_num_nodes(&num_nodes, &weights); - if node_len == 0 { - return Ok(graph::PyGraph { - graph: StablePyGraph::::default(), - node_removed: false, - multigraph, - attrs: py.None(), - }); - } - let num_edges = (node_len * (node_len - 1)) / 2; - let mut graph = StablePyGraph::::with_capacity(node_len, num_edges); - match weights { - Some(weights) => { - for weight in weights { - graph.add_node(weight); - } - } - None => { - (0..node_len).for_each(|_| { - graph.add_node(py.None()); - }); - } - }; - - for i in 0..node_len - 1 { - for j in i + 1..node_len { - let i_index = NodeIndex::new(i); - let j_index = NodeIndex::new(j); - graph.add_edge(i_index, j_index, py.None()); - } - } - Ok(graph::PyGraph { - graph, - node_removed: false, - multigraph, - attrs: py.None(), - }) + complete_graph(py, num_nodes, weights, multigraph) } /// Generate a directed mesh graph where every node is connected to every other @@ -504,52 +454,7 @@ pub fn directed_mesh_graph( weights: Option>, multigraph: bool, ) -> PyResult { - if weights.is_none() && num_nodes.is_none() { - return Err(PyIndexError::new_err( - "num_nodes and weights list not specified", - )); - } - let node_len = get_num_nodes(&num_nodes, &weights); - if node_len == 0 { - return Ok(digraph::PyDiGraph { - graph: StablePyGraph::::default(), - node_removed: false, - check_cycle: false, - cycle_state: algo::DfsSpace::default(), - multigraph, - attrs: py.None(), - }); - } - let num_edges = node_len * (node_len - 1); - let mut graph = StablePyGraph::::with_capacity(node_len, num_edges); - match weights { - Some(weights) => { - for weight in weights { - graph.add_node(weight); - } - } - None => { - (0..node_len).for_each(|_| { - graph.add_node(py.None()); - }); - } - }; - for i in 0..node_len - 1 { - for j in i + 1..node_len { - let i_index = NodeIndex::new(i); - let j_index = NodeIndex::new(j); - graph.add_edge(i_index, j_index, py.None()); - graph.add_edge(j_index, i_index, py.None()); - } - } - Ok(digraph::PyDiGraph { - graph, - node_removed: false, - check_cycle: false, - cycle_state: algo::DfsSpace::default(), - multigraph, - attrs: py.None(), - }) + directed_complete_graph(py, num_nodes, weights, multigraph) } /// Generate an undirected grid graph. @@ -1603,7 +1508,22 @@ pub fn complete_graph( weights: Option>, multigraph: bool, ) -> PyResult { - mesh_graph(py, num_nodes, weights, multigraph) + let default_fn = || py.None(); + let graph: StablePyGraph = + match core_generators::complete_graph(num_nodes, weights, default_fn, default_fn) { + Ok(graph) => graph, + Err(_) => { + return Err(PyIndexError::new_err( + "num_nodes and weights list not specified", + )) + } + }; + Ok(graph::PyGraph { + graph, + node_removed: false, + multigraph, + attrs: py.None(), + }) } /// Generate a directed complete graph with ``n`` nodes. @@ -1636,7 +1556,7 @@ pub fn complete_graph( /// graph = rustworkx.generators.directed_complete_graph(5) /// mpl_draw(graph) /// -#[pyfunction(multigraph = true)] +#[pyfunction(multigraph = "true")] #[pyo3(text_signature = "(/, num_nodes=None, weights=None, multigraph=True)")] pub fn directed_complete_graph( py: Python, @@ -1644,7 +1564,24 @@ pub fn directed_complete_graph( weights: Option>, multigraph: bool, ) -> PyResult { - directed_mesh_graph(py, num_nodes, weights, multigraph) + let default_fn = || py.None(); + let graph: StablePyGraph = + match core_generators::complete_graph(num_nodes, weights, default_fn, default_fn) { + Ok(graph) => graph, + Err(_) => { + return Err(PyIndexError::new_err( + "num_nodes and weights list not specified", + )) + } + }; + Ok(digraph::PyDiGraph { + graph, + node_removed: false, + check_cycle: false, + cycle_state: algo::DfsSpace::default(), + multigraph, + attrs: py.None(), + }) } #[pymodule]