Skip to content

Commit

Permalink
Move complete_graph generator to rustworkx-core (#772)
Browse files Browse the repository at this point in the history
* Initial complete graph

* Add tests

* Finish tests

* Fmt

* Mesh calls complete

* Update tests

* Fmt

* Simplify mesh call

* Remove get_num_nodes from generators.rs
  • Loading branch information
enavarro51 committed Jan 14, 2023
1 parent 810391b commit 301f6e8
Show file tree
Hide file tree
Showing 3 changed files with 240 additions and 100 deletions.
201 changes: 201 additions & 0 deletions rustworkx-core/src/generators/complete_graph.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.

use petgraph::data::{Build, Create};
use petgraph::visit::{Data, GraphProp, NodeIndexable};

use super::utils::get_num_nodes;
use super::InvalidInputError;

/// Generate a complete graph
///
/// Arguments:
///
/// * `num_nodes` - The number of nodes to create a complete graph for. Either this or
/// `weights must be specified. If both this and `weights are specified, weights
/// will take priorty and this argument will be ignored
/// * `weights` - A `Vec` of node weight objects.
/// * `default_node_weight` - A callable that will return the weight to use
/// for newly created nodes. This is ignored if `weights` is specified,
/// as the weights from that argument will be used instead.
/// * `default_edge_weight` - A callable that will return the weight object
/// to use for newly created edges.
///
/// # Example
/// ```rust
/// use rustworkx_core::petgraph;
/// use rustworkx_core::generators::complete_graph;
/// use rustworkx_core::petgraph::visit::EdgeRef;
///
/// let g: petgraph::graph::UnGraph<(), ()> = complete_graph(
/// Some(4),
/// None,
/// || {()},
/// || {()},
/// ).unwrap();
/// assert_eq!(
/// vec![(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)],
/// g.edge_references()
/// .map(|edge| (edge.source().index(), edge.target().index()))
/// .collect::<Vec<(usize, usize)>>(),
/// )
/// ```
pub fn complete_graph<G, T, F, H, M>(
num_nodes: Option<usize>,
weights: Option<Vec<T>>,
mut default_node_weight: F,
mut default_edge_weight: H,
) -> Result<G, InvalidInputError>
where
G: Build + Create + Data<NodeWeight = T, EdgeWeight = M> + NodeIndexable + GraphProp,
F: FnMut() -> T,
H: FnMut() -> M,
{
if weights.is_none() && num_nodes.is_none() {
return Err(InvalidInputError {});
}
let node_len = get_num_nodes(&num_nodes, &weights);
let mut graph = G::with_capacity(node_len, node_len);
if node_len == 0 {
return Ok(graph);
}

match weights {
Some(weights) => {
for weight in weights {
graph.add_node(weight);
}
}
None => {
for _ in 0..node_len {
graph.add_node(default_node_weight());
}
}
};
for i in 0..node_len - 1 {
for j in i + 1..node_len {
let node_i = graph.from_index(i);
let node_j = graph.from_index(j);
graph.add_edge(node_i, node_j, default_edge_weight());
if graph.is_directed() {
graph.add_edge(node_j, node_i, default_edge_weight());
}
}
}
Ok(graph)
}

#[cfg(test)]
mod tests {
use crate::generators::complete_graph;
use crate::generators::InvalidInputError;
use crate::petgraph::graph::{DiGraph, NodeIndex, UnGraph};
use crate::petgraph::visit::EdgeRef;

#[test]
fn test_directed_complete_graph() {
let g: DiGraph<(), ()> = complete_graph(Some(10), None, || (), || ()).unwrap();
assert_eq!(g.node_count(), 10);
assert_eq!(g.edge_count(), 90);
let mut elist = vec![];
for i in 0..10 {
for j in i..10 {
if i != j {
elist.push((i, j));
elist.push((j, i));
}
}
}
assert_eq!(
elist,
g.edge_references()
.map(|edge| (edge.source().index(), edge.target().index()))
.collect::<Vec<(usize, usize)>>(),
);
}

#[test]
fn test_directed_complete_graph_weights() {
let g: DiGraph<usize, ()> =
complete_graph(None, Some(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), || 4, || ()).unwrap();
assert_eq!(g.node_count(), 10);
assert_eq!(g.edge_count(), 90);
let mut elist = vec![];
for i in 0..10 {
for j in i..10 {
if i != j {
elist.push((i, j));
elist.push((j, i));
}
}
assert_eq!(*g.node_weight(NodeIndex::new(i)).unwrap(), i);
}
assert_eq!(
elist,
g.edge_references()
.map(|edge| (edge.source().index(), edge.target().index()))
.collect::<Vec<(usize, usize)>>(),
);
}

#[test]
fn test_compete_graph_error() {
match complete_graph::<DiGraph<(), ()>, (), _, _, ()>(None, None, || (), || ()) {
Ok(_) => panic!("Returned a non-error"),
Err(e) => assert_eq!(e, InvalidInputError),
};
}

#[test]
fn test_complete_graph() {
let g: UnGraph<(), ()> = complete_graph(Some(10), None, || (), || ()).unwrap();
assert_eq!(g.node_count(), 10);
assert_eq!(g.edge_count(), 45);
let mut elist = vec![];
for i in 0..10 {
for j in i..10 {
if i != j {
elist.push((i, j));
}
}
}
assert_eq!(
elist,
g.edge_references()
.map(|edge| (edge.source().index(), edge.target().index()))
.collect::<Vec<(usize, usize)>>(),
);
}

#[test]
fn test_complete_graph_weights() {
let g: UnGraph<usize, ()> =
complete_graph(None, Some(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), || 4, || ()).unwrap();
assert_eq!(g.node_count(), 10);
assert_eq!(g.edge_count(), 45);
let mut elist = vec![];
for i in 0..10 {
for j in i..10 {
if i != j {
elist.push((i, j));
}
}
assert_eq!(*g.node_weight(NodeIndex::new(i)).unwrap(), i);
}
assert_eq!(
elist,
g.edge_references()
.map(|edge| (edge.source().index(), edge.target().index()))
.collect::<Vec<(usize, usize)>>(),
);
}
}
2 changes: 2 additions & 0 deletions rustworkx-core/src/generators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

mod barbell_graph;
mod binomial_tree_graph;
mod complete_graph;
mod cycle_graph;
mod grid_graph;
mod heavy_hex_graph;
Expand Down Expand Up @@ -43,6 +44,7 @@ impl fmt::Display for InvalidInputError {

pub use barbell_graph::barbell_graph;
pub use binomial_tree_graph::binomial_tree_graph;
pub use complete_graph::complete_graph;
pub use cycle_graph::cycle_graph;
pub use grid_graph::grid_graph;
pub use heavy_hex_graph::heavy_hex_graph;
Expand Down
137 changes: 37 additions & 100 deletions src/generators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,6 @@ where
left.zip(right)
}

#[inline]
fn get_num_nodes(num_nodes: &Option<usize>, weights: &Option<Vec<PyObject>>) -> usize {
if weights.is_some() {
weights.as_ref().unwrap().len()
} else {
num_nodes.unwrap()
}
}

/// Generate a cycle graph
///
/// :param int num_node: The number of nodes to generate the graph with. Node
Expand Down Expand Up @@ -428,48 +419,7 @@ pub fn mesh_graph(
weights: Option<Vec<PyObject>>,
multigraph: bool,
) -> PyResult<graph::PyGraph> {
if weights.is_none() && num_nodes.is_none() {
return Err(PyIndexError::new_err(
"num_nodes and weights list not specified",
));
}
let node_len = get_num_nodes(&num_nodes, &weights);
if node_len == 0 {
return Ok(graph::PyGraph {
graph: StablePyGraph::<Undirected>::default(),
node_removed: false,
multigraph,
attrs: py.None(),
});
}
let num_edges = (node_len * (node_len - 1)) / 2;
let mut graph = StablePyGraph::<Undirected>::with_capacity(node_len, num_edges);
match weights {
Some(weights) => {
for weight in weights {
graph.add_node(weight);
}
}
None => {
(0..node_len).for_each(|_| {
graph.add_node(py.None());
});
}
};

for i in 0..node_len - 1 {
for j in i + 1..node_len {
let i_index = NodeIndex::new(i);
let j_index = NodeIndex::new(j);
graph.add_edge(i_index, j_index, py.None());
}
}
Ok(graph::PyGraph {
graph,
node_removed: false,
multigraph,
attrs: py.None(),
})
complete_graph(py, num_nodes, weights, multigraph)
}

/// Generate a directed mesh graph where every node is connected to every other
Expand Down Expand Up @@ -504,52 +454,7 @@ pub fn directed_mesh_graph(
weights: Option<Vec<PyObject>>,
multigraph: bool,
) -> PyResult<digraph::PyDiGraph> {
if weights.is_none() && num_nodes.is_none() {
return Err(PyIndexError::new_err(
"num_nodes and weights list not specified",
));
}
let node_len = get_num_nodes(&num_nodes, &weights);
if node_len == 0 {
return Ok(digraph::PyDiGraph {
graph: StablePyGraph::<Directed>::default(),
node_removed: false,
check_cycle: false,
cycle_state: algo::DfsSpace::default(),
multigraph,
attrs: py.None(),
});
}
let num_edges = node_len * (node_len - 1);
let mut graph = StablePyGraph::<Directed>::with_capacity(node_len, num_edges);
match weights {
Some(weights) => {
for weight in weights {
graph.add_node(weight);
}
}
None => {
(0..node_len).for_each(|_| {
graph.add_node(py.None());
});
}
};
for i in 0..node_len - 1 {
for j in i + 1..node_len {
let i_index = NodeIndex::new(i);
let j_index = NodeIndex::new(j);
graph.add_edge(i_index, j_index, py.None());
graph.add_edge(j_index, i_index, py.None());
}
}
Ok(digraph::PyDiGraph {
graph,
node_removed: false,
check_cycle: false,
cycle_state: algo::DfsSpace::default(),
multigraph,
attrs: py.None(),
})
directed_complete_graph(py, num_nodes, weights, multigraph)
}

/// Generate an undirected grid graph.
Expand Down Expand Up @@ -1603,7 +1508,22 @@ pub fn complete_graph(
weights: Option<Vec<PyObject>>,
multigraph: bool,
) -> PyResult<graph::PyGraph> {
mesh_graph(py, num_nodes, weights, multigraph)
let default_fn = || py.None();
let graph: StablePyGraph<Undirected> =
match core_generators::complete_graph(num_nodes, weights, default_fn, default_fn) {
Ok(graph) => graph,
Err(_) => {
return Err(PyIndexError::new_err(
"num_nodes and weights list not specified",
))
}
};
Ok(graph::PyGraph {
graph,
node_removed: false,
multigraph,
attrs: py.None(),
})
}

/// Generate a directed complete graph with ``n`` nodes.
Expand Down Expand Up @@ -1636,15 +1556,32 @@ pub fn complete_graph(
/// graph = rustworkx.generators.directed_complete_graph(5)
/// mpl_draw(graph)
///
#[pyfunction(multigraph = true)]
#[pyfunction(multigraph = "true")]
#[pyo3(text_signature = "(/, num_nodes=None, weights=None, multigraph=True)")]
pub fn directed_complete_graph(
py: Python,
num_nodes: Option<usize>,
weights: Option<Vec<PyObject>>,
multigraph: bool,
) -> PyResult<digraph::PyDiGraph> {
directed_mesh_graph(py, num_nodes, weights, multigraph)
let default_fn = || py.None();
let graph: StablePyGraph<Directed> =
match core_generators::complete_graph(num_nodes, weights, default_fn, default_fn) {
Ok(graph) => graph,
Err(_) => {
return Err(PyIndexError::new_err(
"num_nodes and weights list not specified",
))
}
};
Ok(digraph::PyDiGraph {
graph,
node_removed: false,
check_cycle: false,
cycle_state: algo::DfsSpace::default(),
multigraph,
attrs: py.None(),
})
}

#[pymodule]
Expand Down

0 comments on commit 301f6e8

Please sign in to comment.