diff --git a/rustworkx-core/src/generators/grid_graph.rs b/rustworkx-core/src/generators/grid_graph.rs new file mode 100644 index 000000000..2fdeb3ee7 --- /dev/null +++ b/rustworkx-core/src/generators/grid_graph.rs @@ -0,0 +1,316 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +use petgraph::data::{Build, Create}; +use petgraph::visit::{Data, NodeIndexable}; + +use super::InvalidInputError; + +/// Generate a grid graph +/// +/// Arguments: +/// +/// * `rows` - The number of rows to generate the graph with. +/// If specified, cols also need to be specified. +/// * `cols`: The number of rows to generate the graph with. +/// If specified, rows also need to be specified. rows*cols +/// defines the number of nodes in the graph. +/// * `weights`: A `Vec` of node weights. Nodes are filled row wise. +/// If rows and cols are not specified, then a linear graph containing +/// all the values in weights list is created. +/// If number of nodes(rows*cols) is less than length of +/// weights list, the trailing weights are ignored. +/// If number of nodes(rows*cols) is greater than length of +/// weights list, extra nodes with None weight are appended. +/// * `default_node_weight` - A callable that will return the weight to use +/// for newly created nodes. This is ignored if `weights` is specified, +/// as the weights from that argument will be used instead. +/// * `default_edge_weight` - A callable that will return the weight object +/// to use for newly created edges. +/// * `bidirectional` - Whether edges are added bidirectionally, if set to +/// `true` then for any edge `(u, v)` an edge `(v, u)` will also be added. +/// If the graph is undirected this will result in a parallel edge. +/// +/// # Example +/// ```rust +/// use rustworkx_core::petgraph; +/// use rustworkx_core::generators::grid_graph; +/// use rustworkx_core::petgraph::visit::EdgeRef; +/// +/// let g: petgraph::graph::UnGraph<(), ()> = grid_graph( +/// Some(3), +/// Some(3), +/// None, +/// || {()}, +/// || {()}, +/// false +/// ).unwrap(); +/// assert_eq!( +/// vec![(0, 3), (0, 1), (1, 4), (1, 2), (2, 5), +/// (3, 6), (3, 4), (4, 7), (4, 5), (5, 8), (6, 7), (7, 8)], +/// g.edge_references() +/// .map(|edge| (edge.source().index(), edge.target().index())) +/// .collect::>(), +/// ) +/// ``` +pub fn grid_graph( + rows: Option, + cols: Option, + weights: Option>, + mut default_node_weight: F, + mut default_edge_weight: H, + bidirectional: bool, +) -> Result +where + G: Build + Create + Data + NodeIndexable, + F: FnMut() -> T, + H: FnMut() -> M, +{ + if weights.is_none() && (rows.is_none() || cols.is_none()) { + return Err(InvalidInputError {}); + } + let mut rowlen = rows.unwrap_or(0); + let mut collen = cols.unwrap_or(0); + let mut num_nodes = rowlen * collen; + let mut num_edges = 0; + if num_nodes != 0 { + num_edges = (rowlen - 1) * collen + (collen - 1) * rowlen; + } + if bidirectional { + num_edges *= 2; + } + let mut graph = G::with_capacity(num_nodes, num_edges); + if num_nodes == 0 && weights.is_none() { + return Ok(graph); + } + match weights { + Some(weights) => { + if num_nodes < weights.len() && rowlen == 0 { + collen = weights.len(); + rowlen = 1; + num_nodes = collen; + } + + let mut node_cnt = num_nodes; + + for weight in weights { + if node_cnt == 0 { + break; + } + graph.add_node(weight); + node_cnt -= 1; + } + for _i in 0..node_cnt { + graph.add_node(default_node_weight()); + } + } + None => { + (0..num_nodes).for_each(|_| { + graph.add_node(default_node_weight()); + }); + } + }; + + for i in 0..rowlen { + for j in 0..collen { + if i + 1 < rowlen { + let node_a = graph.from_index(i * collen + j); + let node_b = graph.from_index((i + 1) * collen + j); + graph.add_edge(node_a, node_b, default_edge_weight()); + if bidirectional { + let node_a = graph.from_index((i + 1) * collen + j); + let node_b = graph.from_index(i * collen + j); + graph.add_edge(node_a, node_b, default_edge_weight()); + } + } + + if j + 1 < collen { + let node_a = graph.from_index(i * collen + j); + let node_b = graph.from_index(i * collen + j + 1); + graph.add_edge(node_a, node_b, default_edge_weight()); + if bidirectional { + let node_a = graph.from_index(i * collen + j + 1); + let node_b = graph.from_index(i * collen + j); + graph.add_edge(node_a, node_b, default_edge_weight()); + } + } + } + } + Ok(graph) +} + +#[cfg(test)] +mod tests { + use crate::generators::grid_graph; + use crate::generators::InvalidInputError; + use crate::petgraph::visit::EdgeRef; + + #[test] + fn test_directed_grid_simple_row_col() { + let g: petgraph::graph::DiGraph<(), ()> = + grid_graph(Some(3), Some(3), None, || (), || (), false).unwrap(); + assert_eq!( + vec![ + (0, 3), + (0, 1), + (1, 4), + (1, 2), + (2, 5), + (3, 6), + (3, 4), + (4, 7), + (4, 5), + (5, 8), + (6, 7), + (7, 8) + ], + g.edge_references() + .map(|edge| (edge.source().index(), edge.target().index())) + .collect::>(), + ); + assert_eq!(g.edge_count(), 12); + } + + #[test] + fn test_grid_simple_row_col() { + let g: petgraph::graph::UnGraph<(), ()> = + grid_graph(Some(3), Some(3), None, || (), || (), false).unwrap(); + assert_eq!( + vec![ + (0, 3), + (0, 1), + (1, 4), + (1, 2), + (2, 5), + (3, 6), + (3, 4), + (4, 7), + (4, 5), + (5, 8), + (6, 7), + (7, 8) + ], + g.edge_references() + .map(|edge| (edge.source().index(), edge.target().index())) + .collect::>(), + ); + assert_eq!(g.edge_count(), 12); + } + + #[test] + fn test_directed_grid_weights() { + let g: petgraph::graph::DiGraph = grid_graph( + Some(2), + Some(3), + Some(vec![0, 1, 2, 3, 4, 5]), + || 4, + || (), + false, + ) + .unwrap(); + assert_eq!( + vec![(0, 3), (0, 1), (1, 4), (1, 2), (2, 5), (3, 4), (4, 5),], + g.edge_references() + .map(|edge| (edge.source().index(), edge.target().index())) + .collect::>(), + ); + assert_eq!(g.edge_count(), 7); + assert_eq!( + vec![0, 1, 2, 3, 4, 5], + g.node_weights().copied().collect::>(), + ); + } + + #[test] + fn test_directed_grid_more_weights() { + let g: petgraph::graph::DiGraph = grid_graph( + Some(2), + Some(3), + Some(vec![0, 1, 2, 3, 4, 5, 6, 7]), + || 4, + || (), + false, + ) + .unwrap(); + assert_eq!( + vec![(0, 3), (0, 1), (1, 4), (1, 2), (2, 5), (3, 4), (4, 5),], + g.edge_references() + .map(|edge| (edge.source().index(), edge.target().index())) + .collect::>(), + ); + assert_eq!(g.edge_count(), 7); + assert_eq!( + vec![0, 1, 2, 3, 4, 5], + g.node_weights().copied().collect::>(), + ); + } + + #[test] + fn test_directed_grid_less_weights() { + let g: petgraph::graph::DiGraph = + grid_graph(Some(2), Some(3), Some(vec![0, 1, 2, 3]), || 6, || (), false).unwrap(); + assert_eq!( + vec![(0, 3), (0, 1), (1, 4), (1, 2), (2, 5), (3, 4), (4, 5),], + g.edge_references() + .map(|edge| (edge.source().index(), edge.target().index())) + .collect::>(), + ); + assert_eq!(g.edge_count(), 7); + assert_eq!( + vec![0, 1, 2, 3, 6, 6], + g.node_weights().copied().collect::>(), + ); + } + + #[test] + fn test_directed_grid_bidirectional() { + let g: petgraph::graph::DiGraph<(), ()> = + grid_graph(Some(2), Some(3), None, || (), || (), true).unwrap(); + assert_eq!( + vec![ + (0, 3), + (3, 0), + (0, 1), + (1, 0), + (1, 4), + (4, 1), + (1, 2), + (2, 1), + (2, 5), + (5, 2), + (3, 4), + (4, 3), + (4, 5), + (5, 4), + ], + g.edge_references() + .map(|edge| (edge.source().index(), edge.target().index())) + .collect::>(), + ); + assert_eq!(g.edge_count(), 14); + } + + #[test] + fn test_grid_error() { + match grid_graph::, (), _, _, ()>( + None, + None, + None, + || (), + || (), + false, + ) { + Ok(_) => panic!("Returned a non-error"), + Err(e) => assert_eq!(e, InvalidInputError), + }; + } +} diff --git a/rustworkx-core/src/generators/mod.rs b/rustworkx-core/src/generators/mod.rs index f297491e5..3360ca6a9 100644 --- a/rustworkx-core/src/generators/mod.rs +++ b/rustworkx-core/src/generators/mod.rs @@ -13,6 +13,7 @@ //! This module contains generator functions for building graphs mod cycle_graph; +mod grid_graph; mod star_graph; mod utils; @@ -33,4 +34,5 @@ impl fmt::Display for InvalidInputError { } pub use cycle_graph::cycle_graph; +pub use grid_graph::grid_graph; pub use star_graph::star_graph; diff --git a/src/generators.rs b/src/generators.rs index f62425531..c6fa47a18 100644 --- a/src/generators.rs +++ b/src/generators.rs @@ -602,7 +602,7 @@ pub fn directed_mesh_graph( /// /// :param int rows: The number of rows to generate the graph with. /// If specified, cols also need to be specified -/// :param list cols: The number of rows to generate the graph with. +/// :param int cols: The number of cols to generate the graph with. /// If specified, rows also need to be specified. rows*cols /// defines the number of nodes in the graph /// :param list weights: A list of node weights. Nodes are filled row wise. @@ -639,76 +639,16 @@ pub fn grid_graph( weights: Option>, multigraph: bool, ) -> PyResult { - if weights.is_none() && (rows.is_none() || cols.is_none()) { - return Err(PyIndexError::new_err( - "dimensions and weights list not specified", - )); - } - - let mut rowlen = rows.unwrap_or(0); - let mut collen = cols.unwrap_or(0); - let mut num_nodes = rowlen * collen; - let mut num_edges = 0; - if num_nodes == 0 { - if weights.is_none() { - return Ok(graph::PyGraph { - graph: StablePyGraph::::default(), - node_removed: false, - multigraph, - attrs: py.None(), - }); - } - } else { - num_edges = (rowlen - 1) * collen + (collen - 1) * rowlen; - } - let mut graph = StablePyGraph::::with_capacity(num_nodes, num_edges); - - match weights { - Some(weights) => { - if num_nodes < weights.len() && rowlen == 0 { - collen = weights.len(); - rowlen = 1; - num_nodes = collen; - } - - let mut node_cnt = num_nodes; - - for weight in weights { - if node_cnt == 0 { - break; - } - graph.add_node(weight); - node_cnt -= 1; - } - for _i in 0..node_cnt { - graph.add_node(py.None()); - } - } - None => { - (0..num_nodes).for_each(|_| { - graph.add_node(py.None()); - }); - } - }; - - for i in 0..rowlen { - for j in 0..collen { - if i + 1 < rowlen { - graph.add_edge( - NodeIndex::new(i * collen + j), - NodeIndex::new((i + 1) * collen + j), - py.None(), - ); - } - if j + 1 < collen { - graph.add_edge( - NodeIndex::new(i * collen + j), - NodeIndex::new(i * collen + j + 1), - py.None(), - ); + let default_fn = || py.None(); + let graph: StablePyGraph = + match core_generators::grid_graph(rows, cols, weights, default_fn, default_fn, false) { + Ok(graph) => graph, + Err(_) => { + return Err(PyIndexError::new_err( + "num_nodes and weights list not specified", + )) } - } - } + }; Ok(graph::PyGraph { graph, node_removed: false, @@ -722,7 +662,7 @@ pub fn grid_graph( /// /// :param int rows: The number of rows to generate the graph with. /// If specified, cols also need to be specified. -/// :param list cols: The number of rows to generate the graph with. +/// :param int cols: The number of cols to generate the graph with. /// If specified, rows also need to be specified. rows*cols /// defines the number of nodes in the graph. /// :param list weights: A list of node weights. Nodes are filled row wise. @@ -764,96 +704,22 @@ pub fn directed_grid_graph( bidirectional: bool, multigraph: bool, ) -> PyResult { - if weights.is_none() && (rows.is_none() || cols.is_none()) { - return Err(PyIndexError::new_err( - "dimensions and weights list not specified", - )); - } - - let mut rowlen = rows.unwrap_or(0); - let mut collen = cols.unwrap_or(0); - let mut num_nodes = rowlen * collen; - let mut num_edges = 0; - if num_nodes == 0 { - if weights.is_none() { - return Ok(digraph::PyDiGraph { - graph: StablePyGraph::::default(), - node_removed: false, - check_cycle: false, - cycle_state: algo::DfsSpace::default(), - multigraph, - attrs: py.None(), - }); - } - } else { - num_edges = (rowlen - 1) * collen + (collen - 1) * rowlen; - } - if bidirectional { - num_edges *= 2; - } - let mut graph = StablePyGraph::::with_capacity(num_nodes, num_edges); - - match weights { - Some(weights) => { - if num_nodes < weights.len() && rowlen == 0 { - collen = weights.len(); - rowlen = 1; - num_nodes = collen; - } - - let mut node_cnt = num_nodes; - - for weight in weights { - if node_cnt == 0 { - break; - } - graph.add_node(weight); - node_cnt -= 1; - } - for _i in 0..node_cnt { - graph.add_node(py.None()); - } - } - None => { - (0..num_nodes).for_each(|_| { - graph.add_node(py.None()); - }); + let default_fn = || py.None(); + let graph: StablePyGraph = match core_generators::grid_graph( + rows, + cols, + weights, + default_fn, + default_fn, + bidirectional, + ) { + Ok(graph) => graph, + Err(_) => { + return Err(PyIndexError::new_err( + "num_nodes and weights list not specified", + )) } }; - - for i in 0..rowlen { - for j in 0..collen { - if i + 1 < rowlen { - graph.add_edge( - NodeIndex::new(i * collen + j), - NodeIndex::new((i + 1) * collen + j), - py.None(), - ); - if bidirectional { - graph.add_edge( - NodeIndex::new((i + 1) * collen + j), - NodeIndex::new(i * collen + j), - py.None(), - ); - } - } - - if j + 1 < collen { - graph.add_edge( - NodeIndex::new(i * collen + j), - NodeIndex::new(i * collen + j + 1), - py.None(), - ); - if bidirectional { - graph.add_edge( - NodeIndex::new(i * collen + j + 1), - NodeIndex::new(i * collen + j), - py.None(), - ); - } - } - } - } Ok(digraph::PyDiGraph { graph, node_removed: false,