Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move complete_graph generator to rustworkx-core #772

Merged
merged 14 commits into from
Jan 14, 2023
201 changes: 201 additions & 0 deletions rustworkx-core/src/generators/complete_graph.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.

use petgraph::data::{Build, Create};
use petgraph::visit::{Data, GraphProp, NodeIndexable};

use super::utils::get_num_nodes;
use super::InvalidInputError;

/// Generate a complete graph
///
/// Arguments:
///
/// * `num_nodes` - The number of nodes to create a complete graph for. Either this or
/// `weights must be specified. If both this and `weights are specified, weights
/// will take priorty and this argument will be ignored
/// * `weights` - A `Vec` of node weight objects.
/// * `default_node_weight` - A callable that will return the weight to use
/// for newly created nodes. This is ignored if `weights` is specified,
/// as the weights from that argument will be used instead.
/// * `default_edge_weight` - A callable that will return the weight object
/// to use for newly created edges.
///
/// # Example
/// ```rust
/// use rustworkx_core::petgraph;
/// use rustworkx_core::generators::complete_graph;
/// use rustworkx_core::petgraph::visit::EdgeRef;
///
/// let g: petgraph::graph::UnGraph<(), ()> = complete_graph(
/// Some(4),
/// None,
/// || {()},
/// || {()},
/// ).unwrap();
/// assert_eq!(
/// vec![(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)],
/// g.edge_references()
/// .map(|edge| (edge.source().index(), edge.target().index()))
/// .collect::<Vec<(usize, usize)>>(),
/// )
/// ```
pub fn complete_graph<G, T, F, H, M>(
num_nodes: Option<usize>,
weights: Option<Vec<T>>,
mut default_node_weight: F,
mut default_edge_weight: H,
) -> Result<G, InvalidInputError>
where
G: Build + Create + Data<NodeWeight = T, EdgeWeight = M> + NodeIndexable + GraphProp,
F: FnMut() -> T,
H: FnMut() -> M,
{
if weights.is_none() && num_nodes.is_none() {
return Err(InvalidInputError {});
}
let node_len = get_num_nodes(&num_nodes, &weights);
let mut graph = G::with_capacity(node_len, node_len);
if node_len == 0 {
return Ok(graph);
}

match weights {
Some(weights) => {
for weight in weights {
graph.add_node(weight);
}
}
None => {
for _ in 0..node_len {
graph.add_node(default_node_weight());
}
}
};
for i in 0..node_len - 1 {
for j in i + 1..node_len {
let node_i = graph.from_index(i);
let node_j = graph.from_index(j);
graph.add_edge(node_i, node_j, default_edge_weight());
if graph.is_directed() {
graph.add_edge(node_j, node_i, default_edge_weight());
}
}
}
Ok(graph)
}

#[cfg(test)]
mod tests {
use crate::generators::complete_graph;
use crate::generators::InvalidInputError;
use crate::petgraph::graph::{DiGraph, NodeIndex, UnGraph};
use crate::petgraph::visit::EdgeRef;

#[test]
fn test_directed_complete_graph() {
let g: DiGraph<(), ()> = complete_graph(Some(10), None, || (), || ()).unwrap();
assert_eq!(g.node_count(), 10);
assert_eq!(g.edge_count(), 90);
let mut elist = vec![];
for i in 0..10 {
for j in i..10 {
if i != j {
elist.push((i, j));
elist.push((j, i));
}
}
}
assert_eq!(
elist,
g.edge_references()
.map(|edge| (edge.source().index(), edge.target().index()))
.collect::<Vec<(usize, usize)>>(),
);
}

#[test]
fn test_directed_complete_graph_weights() {
let g: DiGraph<usize, ()> =
complete_graph(None, Some(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), || 4, || ()).unwrap();
assert_eq!(g.node_count(), 10);
assert_eq!(g.edge_count(), 90);
let mut elist = vec![];
for i in 0..10 {
for j in i..10 {
if i != j {
elist.push((i, j));
elist.push((j, i));
}
}
assert_eq!(*g.node_weight(NodeIndex::new(i)).unwrap(), i);
}
assert_eq!(
elist,
g.edge_references()
.map(|edge| (edge.source().index(), edge.target().index()))
.collect::<Vec<(usize, usize)>>(),
);
}

#[test]
fn test_compete_graph_error() {
match complete_graph::<DiGraph<(), ()>, (), _, _, ()>(None, None, || (), || ()) {
Ok(_) => panic!("Returned a non-error"),
Err(e) => assert_eq!(e, InvalidInputError),
};
}

#[test]
fn test_complete_graph() {
let g: UnGraph<(), ()> = complete_graph(Some(10), None, || (), || ()).unwrap();
assert_eq!(g.node_count(), 10);
assert_eq!(g.edge_count(), 45);
let mut elist = vec![];
for i in 0..10 {
for j in i..10 {
if i != j {
elist.push((i, j));
}
}
}
assert_eq!(
elist,
g.edge_references()
.map(|edge| (edge.source().index(), edge.target().index()))
.collect::<Vec<(usize, usize)>>(),
);
}

#[test]
fn test_complete_graph_weights() {
let g: UnGraph<usize, ()> =
complete_graph(None, Some(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), || 4, || ()).unwrap();
assert_eq!(g.node_count(), 10);
assert_eq!(g.edge_count(), 45);
let mut elist = vec![];
for i in 0..10 {
for j in i..10 {
if i != j {
elist.push((i, j));
}
}
assert_eq!(*g.node_weight(NodeIndex::new(i)).unwrap(), i);
}
assert_eq!(
elist,
g.edge_references()
.map(|edge| (edge.source().index(), edge.target().index()))
.collect::<Vec<(usize, usize)>>(),
);
}
}
2 changes: 2 additions & 0 deletions rustworkx-core/src/generators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

mod barbell_graph;
mod binomial_tree_graph;
mod complete_graph;
mod cycle_graph;
mod grid_graph;
mod heavy_hex_graph;
Expand Down Expand Up @@ -43,6 +44,7 @@ impl fmt::Display for InvalidInputError {

pub use barbell_graph::barbell_graph;
pub use binomial_tree_graph::binomial_tree_graph;
pub use complete_graph::complete_graph;
pub use cycle_graph::cycle_graph;
pub use grid_graph::grid_graph;
pub use heavy_hex_graph::heavy_hex_graph;
Expand Down
137 changes: 37 additions & 100 deletions src/generators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,6 @@ where
left.zip(right)
}

#[inline]
fn get_num_nodes(num_nodes: &Option<usize>, weights: &Option<Vec<PyObject>>) -> usize {
if weights.is_some() {
weights.as_ref().unwrap().len()
} else {
num_nodes.unwrap()
}
}

/// Generate a cycle graph
///
/// :param int num_node: The number of nodes to generate the graph with. Node
Expand Down Expand Up @@ -428,48 +419,7 @@ pub fn mesh_graph(
weights: Option<Vec<PyObject>>,
multigraph: bool,
) -> PyResult<graph::PyGraph> {
if weights.is_none() && num_nodes.is_none() {
return Err(PyIndexError::new_err(
"num_nodes and weights list not specified",
));
}
let node_len = get_num_nodes(&num_nodes, &weights);
if node_len == 0 {
return Ok(graph::PyGraph {
graph: StablePyGraph::<Undirected>::default(),
node_removed: false,
multigraph,
attrs: py.None(),
});
}
let num_edges = (node_len * (node_len - 1)) / 2;
let mut graph = StablePyGraph::<Undirected>::with_capacity(node_len, num_edges);
match weights {
Some(weights) => {
for weight in weights {
graph.add_node(weight);
}
}
None => {
(0..node_len).for_each(|_| {
graph.add_node(py.None());
});
}
};

for i in 0..node_len - 1 {
for j in i + 1..node_len {
let i_index = NodeIndex::new(i);
let j_index = NodeIndex::new(j);
graph.add_edge(i_index, j_index, py.None());
}
}
Ok(graph::PyGraph {
graph,
node_removed: false,
multigraph,
attrs: py.None(),
})
complete_graph(py, num_nodes, weights, multigraph)
}

/// Generate a directed mesh graph where every node is connected to every other
Expand Down Expand Up @@ -504,52 +454,7 @@ pub fn directed_mesh_graph(
weights: Option<Vec<PyObject>>,
multigraph: bool,
) -> PyResult<digraph::PyDiGraph> {
if weights.is_none() && num_nodes.is_none() {
return Err(PyIndexError::new_err(
"num_nodes and weights list not specified",
));
}
let node_len = get_num_nodes(&num_nodes, &weights);
if node_len == 0 {
return Ok(digraph::PyDiGraph {
graph: StablePyGraph::<Directed>::default(),
node_removed: false,
check_cycle: false,
cycle_state: algo::DfsSpace::default(),
multigraph,
attrs: py.None(),
});
}
let num_edges = node_len * (node_len - 1);
let mut graph = StablePyGraph::<Directed>::with_capacity(node_len, num_edges);
match weights {
Comment on lines -507 to -525
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same apllies here, I'd call directed_mesh_graph(py, num_nodes, weights, multigraph)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in af75626.

Some(weights) => {
for weight in weights {
graph.add_node(weight);
}
}
None => {
(0..node_len).for_each(|_| {
graph.add_node(py.None());
});
}
};
for i in 0..node_len - 1 {
for j in i + 1..node_len {
let i_index = NodeIndex::new(i);
let j_index = NodeIndex::new(j);
graph.add_edge(i_index, j_index, py.None());
graph.add_edge(j_index, i_index, py.None());
}
}
Ok(digraph::PyDiGraph {
graph,
node_removed: false,
check_cycle: false,
cycle_state: algo::DfsSpace::default(),
multigraph,
attrs: py.None(),
})
directed_complete_graph(py, num_nodes, weights, multigraph)
}

/// Generate an undirected grid graph.
Expand Down Expand Up @@ -1603,7 +1508,22 @@ pub fn complete_graph(
weights: Option<Vec<PyObject>>,
multigraph: bool,
) -> PyResult<graph::PyGraph> {
mesh_graph(py, num_nodes, weights, multigraph)
enavarro51 marked this conversation as resolved.
Show resolved Hide resolved
let default_fn = || py.None();
let graph: StablePyGraph<Undirected> =
match core_generators::complete_graph(num_nodes, weights, default_fn, default_fn) {
Ok(graph) => graph,
Err(_) => {
return Err(PyIndexError::new_err(
"num_nodes and weights list not specified",
))
}
};
Ok(graph::PyGraph {
graph,
node_removed: false,
multigraph,
attrs: py.None(),
})
}

/// Generate a directed complete graph with ``n`` nodes.
Expand Down Expand Up @@ -1636,15 +1556,32 @@ pub fn complete_graph(
/// graph = rustworkx.generators.directed_complete_graph(5)
/// mpl_draw(graph)
///
#[pyfunction(multigraph = true)]
#[pyfunction(multigraph = "true")]
#[pyo3(text_signature = "(/, num_nodes=None, weights=None, multigraph=True)")]
pub fn directed_complete_graph(
py: Python,
num_nodes: Option<usize>,
weights: Option<Vec<PyObject>>,
multigraph: bool,
) -> PyResult<digraph::PyDiGraph> {
directed_mesh_graph(py, num_nodes, weights, multigraph)
enavarro51 marked this conversation as resolved.
Show resolved Hide resolved
let default_fn = || py.None();
let graph: StablePyGraph<Directed> =
match core_generators::complete_graph(num_nodes, weights, default_fn, default_fn) {
Ok(graph) => graph,
Err(_) => {
return Err(PyIndexError::new_err(
"num_nodes and weights list not specified",
))
}
};
Ok(digraph::PyDiGraph {
graph,
node_removed: false,
check_cycle: false,
cycle_state: algo::DfsSpace::default(),
multigraph,
attrs: py.None(),
})
}

#[pymodule]
Expand Down