diff --git a/Cargo.lock b/Cargo.lock index 959c32034..25806f054 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -633,6 +633,7 @@ dependencies = [ "fixedbitset", "hashbrown 0.14.5", "indexmap 2.2.6", + "ndarray", "num-traits", "petgraph", "priority-queue", diff --git a/Cargo.toml b/Cargo.toml index 1e50fea81..c3bb08693 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ ahash = "0.8.6" fixedbitset = "0.4.2" hashbrown = { version = ">=0.13, <0.15", features = ["rayon"] } indexmap = { version = ">=1.9, <3", features = ["rayon"] } +ndarray = { version = "0.15.6", features = ["rayon"] } num-traits = "0.2" numpy = "0.21.0" petgraph = "0.6.5" @@ -44,6 +45,7 @@ ahash.workspace = true fixedbitset.workspace = true hashbrown.workspace = true indexmap.workspace = true +ndarray.workspace = true ndarray-stats = "0.5.1" num-bigint = "0.4" num-complex = "0.4" @@ -63,10 +65,6 @@ rustworkx-core = { path = "rustworkx-core", version = "=0.15.0" } version = "0.21.2" features = ["abi3-py38", "extension-module", "hashbrown", "num-bigint", "num-complex", "indexmap"] -[dependencies.ndarray] -version = "^0.15.6" -features = ["rayon"] - [dependencies.sprs] version = "^0.11" features = ["multi_thread"] diff --git a/docs/source/api/random_graph_generator_functions.rst b/docs/source/api/random_graph_generator_functions.rst index 4bc52096e..4c0a33c5f 100644 --- a/docs/source/api/random_graph_generator_functions.rst +++ b/docs/source/api/random_graph_generator_functions.rst @@ -10,6 +10,8 @@ Random Graph Generator Functions rustworkx.undirected_gnp_random_graph rustworkx.directed_gnm_random_graph rustworkx.undirected_gnm_random_graph + rustworkx.directed_sbm_random_graph + rustworkx.undirected_sbm_random_graph rustworkx.random_geometric_graph rustworkx.hyperbolic_random_graph rustworkx.barabasi_albert_graph diff --git a/releasenotes/notes/sbm-random-graph-bf7ccd8e938f4218.yaml b/releasenotes/notes/sbm-random-graph-bf7ccd8e938f4218.yaml new file mode 100644 index 000000000..8ec9490a4 --- /dev/null +++ b/releasenotes/notes/sbm-random-graph-bf7ccd8e938f4218.yaml @@ -0,0 +1,9 @@ +features: + - | + Adds new random graph generator in rustworkx for the stochastic block model. + There is a generator for directed :func:`.directed_sbm_random_graph` and + undirected graphs :func:`.undirected_sbm_random_graph`. + - | + Adds new function ``sbm_random_graph`` to the rustworkx-core module + ``rustworkx_core::generators`` that samples a graph from the stochastic + block model. diff --git a/rustworkx-core/Cargo.toml b/rustworkx-core/Cargo.toml index 781a9fbf5..c8d292627 100644 --- a/rustworkx-core/Cargo.toml +++ b/rustworkx-core/Cargo.toml @@ -16,6 +16,7 @@ ahash.workspace = true fixedbitset.workspace = true hashbrown.workspace = true indexmap.workspace = true +ndarray.workspace = true num-traits.workspace = true petgraph.workspace = true priority-queue = "2.0" diff --git a/rustworkx-core/src/generators/mod.rs b/rustworkx-core/src/generators/mod.rs index 04d672749..3034baab1 100644 --- a/rustworkx-core/src/generators/mod.rs +++ b/rustworkx-core/src/generators/mod.rs @@ -62,4 +62,5 @@ pub use random_graph::gnp_random_graph; pub use random_graph::hyperbolic_random_graph; pub use random_graph::random_bipartite_graph; pub use random_graph::random_geometric_graph; +pub use random_graph::sbm_random_graph; pub use star_graph::star_graph; diff --git a/rustworkx-core/src/generators/random_graph.rs b/rustworkx-core/src/generators/random_graph.rs index 1768619d5..edea398fb 100644 --- a/rustworkx-core/src/generators/random_graph.rs +++ b/rustworkx-core/src/generators/random_graph.rs @@ -14,6 +14,7 @@ use std::hash::Hash; +use ndarray::ArrayView2; use petgraph::data::{Build, Create}; use petgraph::visit::{ Data, EdgeRef, GraphBase, GraphProp, IntoEdgeReferences, IntoEdgesDirected, @@ -305,6 +306,131 @@ where Ok(graph) } +/// Generate a graph from the stochastic block model. +/// +/// The stochastic block model is a generalization of the Gnp random graph +/// (see [gnp_random_graph] ). The connection probability of +/// nodes `u` and `v` depends on their block and is given by +/// `probabilities[blocks[u]][blocks[v]]`, where `blocks[u]` is the block membership +/// of vertex `u`. The number of nodes and the number of blocks are inferred from +/// `sizes`. +/// +/// Arguments: +/// +/// * `sizes` - Number of nodes in each block. +/// * `probabilities` - B x B array that contains the connection probability between +/// nodes of different blocks. Must be symmetric for undirected graphs. +/// * `loops` - Determines whether the graph can have loops or not. +/// * `seed` - An optional seed to use for the random number generator. +/// * `default_node_weight` - A callable that will return the weight to use +/// for newly created nodes. +/// * `default_edge_weight` - A callable that will return the weight object +/// to use for newly created edges. +/// +/// # Example +/// ```rust +/// use ndarray::arr2; +/// use rustworkx_core::petgraph; +/// use rustworkx_core::generators::sbm_random_graph; +/// +/// let g = sbm_random_graph::, (), _, _, ()>( +/// &vec![1, 2], +/// &ndarray::arr2(&[[0., 1.], [0., 1.]]).view(), +/// true, +/// Some(10), +/// || (), +/// || (), +/// ) +/// .unwrap(); +/// assert_eq!(g.node_count(), 3); +/// assert_eq!(g.edge_count(), 6); +/// ``` +pub fn sbm_random_graph( + sizes: &[usize], + probabilities: &ndarray::ArrayView2, + loops: bool, + seed: Option, + mut default_node_weight: F, + mut default_edge_weight: H, +) -> Result +where + G: Build + Create + Data + NodeIndexable + GraphProp, + F: FnMut() -> T, + H: FnMut() -> M, + G::NodeId: Eq + Hash, +{ + let num_nodes: usize = sizes.iter().sum(); + if num_nodes == 0 { + return Err(InvalidInputError {}); + } + let num_communities = sizes.len(); + if probabilities.nrows() != num_communities + || probabilities.ncols() != num_communities + || probabilities.iter().any(|&x| !(0. ..=1.).contains(&x)) + { + return Err(InvalidInputError {}); + } + + let mut graph = G::with_capacity(num_nodes, num_nodes); + let directed = graph.is_directed(); + if !directed && !symmetric_array(probabilities) { + return Err(InvalidInputError {}); + } + + for _ in 0..num_nodes { + graph.add_node(default_node_weight()); + } + let mut rng: Pcg64 = match seed { + Some(seed) => Pcg64::seed_from_u64(seed), + None => Pcg64::from_entropy(), + }; + let mut blocks = Vec::new(); + { + let mut block = 0; + let mut vertices_left = sizes[0]; + for _ in 0..num_nodes { + while vertices_left == 0 { + block += 1; + vertices_left = sizes[block]; + } + blocks.push(block); + vertices_left -= 1; + } + } + + let between = Uniform::new(0.0, 1.0); + for v in 0..(if directed || loops { + num_nodes + } else { + num_nodes - 1 + }) { + for w in ((if directed { 0 } else { v })..num_nodes).filter(|&w| w != v || loops) { + if &between.sample(&mut rng) + < probabilities.get((blocks[v], blocks[w])).unwrap_or(&0_f64) + { + graph.add_edge( + graph.from_index(v), + graph.from_index(w), + default_edge_weight(), + ); + } + } + } + Ok(graph) +} + +fn symmetric_array(mat: &ArrayView2) -> bool { + let n = mat.nrows(); + for (i, row) in mat.rows().into_iter().enumerate().take(n - 1) { + for (j, m_ij) in row.iter().enumerate().skip(i + 1) { + if m_ij != mat.get((j, i)).unwrap() { + return false; + } + } + } + true +} + #[inline] fn pnorm(x: f64, p: f64) -> f64 { if p == 1.0 || p == std::f64::INFINITY { @@ -749,7 +875,7 @@ mod tests { use crate::generators::InvalidInputError; use crate::generators::{ barabasi_albert_graph, gnm_random_graph, gnp_random_graph, hyperbolic_random_graph, - path_graph, random_bipartite_graph, random_geometric_graph, + path_graph, random_bipartite_graph, random_geometric_graph, sbm_random_graph, }; use crate::petgraph; @@ -879,6 +1005,165 @@ mod tests { }; } + // Test sbm_random_graph + #[test] + fn test_sbm_directed_complete_blocks_loops() { + let g = sbm_random_graph::, (), _, _, ()>( + &vec![1, 2], + &ndarray::arr2(&[[0., 1.], [0., 1.]]).view(), + true, + Some(10), + || (), + || (), + ) + .unwrap(); + assert_eq!(g.node_count(), 3); + assert_eq!(g.edge_count(), 6); + for (u, v) in [(1, 1), (1, 2), (2, 1), (2, 2), (0, 1), (0, 2)] { + assert_eq!(g.contains_edge(u.into(), v.into()), true); + } + assert_eq!(g.contains_edge(1.into(), 0.into()), false); + assert_eq!(g.contains_edge(2.into(), 0.into()), false); + } + + #[test] + fn test_sbm_undirected_complete_blocks_loops() { + let g = sbm_random_graph::, (), _, _, ()>( + &vec![1, 2], + &ndarray::arr2(&[[0., 1.], [1., 1.]]).view(), + true, + Some(10), + || (), + || (), + ) + .unwrap(); + assert_eq!(g.node_count(), 3); + assert_eq!(g.edge_count(), 5); + for (u, v) in [(1, 1), (1, 2), (2, 2), (0, 1), (0, 2)] { + assert_eq!(g.contains_edge(u.into(), v.into()), true); + } + assert_eq!(g.contains_edge(0.into(), 0.into()), false); + } + + #[test] + fn test_sbm_directed_complete_blocks_noloops() { + let g = sbm_random_graph::, (), _, _, ()>( + &vec![1, 2], + &ndarray::arr2(&[[0., 1.], [0., 1.]]).view(), + false, + Some(10), + || (), + || (), + ) + .unwrap(); + assert_eq!(g.node_count(), 3); + assert_eq!(g.edge_count(), 4); + for (u, v) in [(1, 2), (2, 1), (0, 1), (0, 2)] { + assert_eq!(g.contains_edge(u.into(), v.into()), true); + } + assert_eq!(g.contains_edge(1.into(), 0.into()), false); + assert_eq!(g.contains_edge(2.into(), 0.into()), false); + for u in 0..2 { + assert_eq!(g.contains_edge(u.into(), u.into()), false); + } + } + + #[test] + fn test_sbm_undirected_complete_blocks_noloops() { + let g = sbm_random_graph::, (), _, _, ()>( + &vec![1, 2], + &ndarray::arr2(&[[0., 1.], [1., 1.]]).view(), + false, + Some(10), + || (), + || (), + ) + .unwrap(); + assert_eq!(g.node_count(), 3); + assert_eq!(g.edge_count(), 3); + for (u, v) in [(1, 2), (0, 1), (0, 2)] { + assert_eq!(g.contains_edge(u.into(), v.into()), true); + } + for u in 0..2 { + assert_eq!(g.contains_edge(u.into(), u.into()), false); + } + } + + #[test] + fn test_sbm_bad_array_rows_error() { + match sbm_random_graph::, (), _, _, ()>( + &vec![1, 2], + &ndarray::arr2(&[[0., 1.], [1., 1.], [1., 1.]]).view(), + true, + Some(10), + || (), + || (), + ) { + Ok(_) => panic!("Returned a non-error"), + Err(e) => assert_eq!(e, InvalidInputError), + }; + } + #[test] + + fn test_sbm_bad_array_cols_error() { + match sbm_random_graph::, (), _, _, ()>( + &vec![1, 2], + &ndarray::arr2(&[[0., 1., 1.], [1., 1., 1.]]).view(), + true, + Some(10), + || (), + || (), + ) { + Ok(_) => panic!("Returned a non-error"), + Err(e) => assert_eq!(e, InvalidInputError), + }; + } + + #[test] + fn test_sbm_asymmetric_array_error() { + match sbm_random_graph::, (), _, _, ()>( + &vec![1, 2], + &ndarray::arr2(&[[0., 1.], [0., 1.]]).view(), + true, + Some(10), + || (), + || (), + ) { + Ok(_) => panic!("Returned a non-error"), + Err(e) => assert_eq!(e, InvalidInputError), + }; + } + + #[test] + fn test_sbm_invalid_probability_error() { + match sbm_random_graph::, (), _, _, ()>( + &vec![1, 2], + &ndarray::arr2(&[[0., 1.], [0., -1.]]).view(), + true, + Some(10), + || (), + || (), + ) { + Ok(_) => panic!("Returned a non-error"), + Err(e) => assert_eq!(e, InvalidInputError), + }; + } + + #[test] + fn test_sbm_empty_error() { + match sbm_random_graph::, (), _, _, ()>( + &vec![], + &ndarray::arr2(&[[]]).view(), + true, + Some(10), + || (), + || (), + ) { + Ok(_) => panic!("Returned a non-error"), + Err(e) => assert_eq!(e, InvalidInputError), + }; + } + // Test random_geometric_graph #[test] diff --git a/rustworkx/__init__.pyi b/rustworkx/__init__.pyi index f25387975..f499b9813 100644 --- a/rustworkx/__init__.pyi +++ b/rustworkx/__init__.pyi @@ -127,6 +127,8 @@ from .rustworkx import directed_gnm_random_graph as directed_gnm_random_graph from .rustworkx import undirected_gnm_random_graph as undirected_gnm_random_graph from .rustworkx import directed_gnp_random_graph as directed_gnp_random_graph from .rustworkx import undirected_gnp_random_graph as undirected_gnp_random_graph +from .rustworkx import directed_sbm_random_graph as directed_sbm_random_graph +from .rustworkx import undirected_sbm_random_graph as undirected_sbm_random_graph from .rustworkx import random_geometric_graph as random_geometric_graph from .rustworkx import hyperbolic_random_graph as hyperbolic_random_graph from .rustworkx import barabasi_albert_graph as barabasi_albert_graph diff --git a/rustworkx/rustworkx.pyi b/rustworkx/rustworkx.pyi index bebe0520e..522962994 100644 --- a/rustworkx/rustworkx.pyi +++ b/rustworkx/rustworkx.pyi @@ -549,6 +549,20 @@ def undirected_gnp_random_graph( /, seed: int | None = ..., ) -> PyGraph: ... +def directed_sbm_random_graph( + sizes: list[int], + probabilities: np.ndarray, + loops: bool, + /, + seed: int | None = ..., +) -> PyDiGraph: ... +def undirected_sbm_random_graph( + sizes: list[int], + probabilities: np.ndarray, + loops: bool, + /, + seed: int | None = ..., +) -> PyGraph: ... def random_geometric_graph( num_nodes: int, radius: float, diff --git a/src/lib.rs b/src/lib.rs index ce0843b8e..164b713c5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -520,6 +520,8 @@ fn rustworkx(py: Python<'_>, m: &Bound) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(undirected_gnp_random_graph))?; m.add_wrapped(wrap_pyfunction!(directed_gnm_random_graph))?; m.add_wrapped(wrap_pyfunction!(undirected_gnm_random_graph))?; + m.add_wrapped(wrap_pyfunction!(undirected_sbm_random_graph))?; + m.add_wrapped(wrap_pyfunction!(directed_sbm_random_graph))?; m.add_wrapped(wrap_pyfunction!(random_geometric_graph))?; m.add_wrapped(wrap_pyfunction!(hyperbolic_random_graph))?; m.add_wrapped(wrap_pyfunction!(barabasi_albert_graph))?; diff --git a/src/random_graph.rs b/src/random_graph.rs index f0e6ee679..8360c0ff9 100644 --- a/src/random_graph.rs +++ b/src/random_graph.rs @@ -23,6 +23,8 @@ use petgraph::algo; use petgraph::graph::NodeIndex; use petgraph::prelude::*; +use numpy::PyReadonlyArray2; + use rand::distributions::{Distribution, Uniform}; use rand::prelude::*; use rand_pcg::Pcg64; @@ -273,6 +275,116 @@ pub fn undirected_gnm_random_graph( }) } +/// Return a directed graph from the stochastic block model. +/// +/// The stochastic block model is a generalization of the :math:`G(n,p)` random graph +/// (see :func:`~rustworkx.directed_gnp_random_graph`). The connection probability of +/// nodes ``u`` and ``v`` depends on their block (or community) and is given by +/// ``probabilities[blocks[u]][blocks[v]]``, where ``blocks[u]`` is the block +/// membership of node ``u``. The number of nodes and the number of blocks are +/// inferred from ``sizes``. +/// +/// This algorithm has a time complexity of :math:`O(n^2)` for :math:`n` nodes. +/// +/// Arguments: +/// +/// :param list[int] sizes: Number of nodes in each block. +/// :param np.ndarray probabilities: B x B array that contains the connection +/// probability between nodes of different blocks. +/// :param bool loops: Determines whether the graph can have loops or not. +/// :param int seed: An optional seed to use for the random number generator. +/// +/// :return: A PyDiGraph object +/// :rtype: PyDiGraph +#[pyfunction] +#[pyo3(text_signature = "(sizes, probabilities, loops, /, seed=None)")] +pub fn directed_sbm_random_graph<'p>( + py: Python<'p>, + sizes: Vec, + probabilities: PyReadonlyArray2<'p, f64>, + loops: bool, + seed: Option, +) -> PyResult { + let default_fn = || py.None(); + let graph: StablePyGraph = match core_generators::sbm_random_graph( + &sizes, + &probabilities.as_array(), + loops, + seed, + default_fn, + default_fn, + ) { + Ok(graph) => graph, + Err(_) => { + return Err(PyValueError::new_err( + "invalid blocks or probabilities input", + )) + } + }; + Ok(digraph::PyDiGraph { + graph, + node_removed: false, + check_cycle: false, + cycle_state: algo::DfsSpace::default(), + multigraph: false, + attrs: py.None(), + }) +} + +/// Return an undirected graph from the stochastic block model. +/// +/// The stochastic block model is a generalization of the :math:`G(n,p)` random graph +/// (see :func:`~rustworkx.undirected_gnp_random_graph`). The connection probability of +/// nodes ``u`` and ``v`` depends on their block (or community) and is given by +/// ``probabilities[blocks[u]][blocks[v]]``, where ``blocks[u]`` is the block membership +/// of node ``u``. The number of nodes and the number of blocks are inferred from +/// ``sizes``. +/// +/// This algorithm has a time complexity of :math:`O(n^2)` for :math:`n` nodes. +/// +/// Arguments: +/// +/// :param list[int] sizes: Number of nodes in each block. +/// :param np.ndarray probabilities: Symmetric B x B array that contains the +/// connection probability between nodes of different blocks. +/// :param bool loops: Determines whether the graph can have loops or not. +/// :param int seed: An optional seed to use for the random number generator. +/// +/// :return: A PyGraph object +/// :rtype: PyGraph +#[pyfunction] +#[pyo3(text_signature = "(sizes, probabilities, loops, /, seed=None)")] +pub fn undirected_sbm_random_graph<'p>( + py: Python<'p>, + sizes: Vec, + probabilities: PyReadonlyArray2<'p, f64>, + loops: bool, + seed: Option, +) -> PyResult { + let default_fn = || py.None(); + let graph: StablePyGraph = match core_generators::sbm_random_graph( + &sizes, + &probabilities.as_array(), + loops, + seed, + default_fn, + default_fn, + ) { + Ok(graph) => graph, + Err(_) => { + return Err(PyValueError::new_err( + "invalid blocks or probabilities input", + )) + } + }; + Ok(graph::PyGraph { + graph, + node_removed: false, + multigraph: false, + attrs: py.None(), + }) +} + #[inline] fn pnorm(x: f64, p: f64) -> f64 { if p == 1.0 || p == std::f64::INFINITY { diff --git a/tests/test_random.py b/tests/test_random.py index 02cfcd36a..74f7668bb 100644 --- a/tests/test_random.py +++ b/tests/test_random.py @@ -14,6 +14,7 @@ import random import math +import numpy as np import rustworkx @@ -177,6 +178,87 @@ def test_random_gnm_undirected_payload(self): self.assertEqual(graph.nodes(), [0, 1, 2]) +class TestRandomSBM(unittest.TestCase): + def test_undirected_sbm_complete_blocks_loops(self): + graph = rustworkx.undirected_sbm_random_graph( + [2, 1], np.array([[1, 1], [1, 0]], dtype=float), True + ) + self.assertEqual(len(graph), 3) + self.assertEqual(len(graph.edges()), 5) + for i in range(2): + for j in range(i, 2): + if (i, j) != (2, 2): + self.assertTrue(graph.has_edge(i, j)) + self.assertFalse(graph.has_edge(2, 2)) + + def test_directed_sbm_complete_blocks_loops(self): + graph = rustworkx.directed_sbm_random_graph( + [2, 1], np.array([[0, 0], [1, 1]], dtype=float), True + ) + self.assertEqual(len(graph), 3) + self.assertEqual(len(graph.edges()), 3) + self.assertEqual(set(graph.edge_list()), set([(2, 2), (2, 0), (2, 1)])) + + def test_undirected_sbm_complete_blocks_noloops(self): + graph = rustworkx.undirected_sbm_random_graph( + [2, 1], np.array([[1, 1], [1, 0]], dtype=float), False + ) + self.assertEqual(len(graph), 3) + self.assertEqual(len(graph.edges()), 3) + for i in range(2): + for j in range(i, 2): + if i != j: + self.assertTrue(graph.has_edge(i, j)) + + def test_directed_sbm_complete_blocks_noloops(self): + graph = rustworkx.directed_sbm_random_graph( + [2, 1], np.array([[0, 0], [1, 1]], dtype=float), False + ) + self.assertEqual(len(graph), 3) + self.assertEqual(len(graph.edges()), 2) + self.assertEqual(set(graph.edge_list()), set([(2, 0), (2, 1)])) + + def test_undirected_sbm_asymmetric_probabilities_error(self): + with self.assertRaises(ValueError): + rustworkx.undirected_sbm_random_graph( + [2, 1], np.array([[0, 0], [1, 1]], dtype=float), True + ) + + def test_sbm_invalid_matrix_dim(self): + with self.assertRaises(ValueError): + rustworkx.undirected_sbm_random_graph( + [2, 1], np.array([[1, 0], [0, 1], [0, 1]], dtype=float), True + ) + with self.assertRaises(ValueError): + rustworkx.directed_sbm_random_graph( + [2, 1], np.array([[1, 0, 1], [0, 1, 0]], dtype=float), True + ) + + def test_sbm_invalid_probabilities(self): + with self.assertRaises(ValueError): + rustworkx.undirected_sbm_random_graph( + [2, 1], np.array([[1, 0], [0, 1.5]], dtype=float), True + ) + with self.assertRaises(ValueError): + rustworkx.undirected_sbm_random_graph( + [2, 1], np.array([[-1, 0], [0, 1]], dtype=float), True + ) + with self.assertRaises(ValueError): + rustworkx.directed_sbm_random_graph( + [2, 1], np.array([[1, 0], [0, 1.5]], dtype=float), True + ) + with self.assertRaises(ValueError): + rustworkx.directed_sbm_random_graph( + [2, 1], np.array([[-1, 0], [0, 1]], dtype=float), True + ) + + def test_sbm_empty(self): + with self.assertRaises(ValueError): + rustworkx.undirected_sbm_random_graph([], np.array([[]]), True) + with self.assertRaises(ValueError): + rustworkx.directed_sbm_random_graph([], np.array([[]]), True) + + class TestGeometricRandomGraph(unittest.TestCase): def test_random_geometric_empty(self): graph = rustworkx.random_geometric_graph(20, 0)