diff --git a/rustworkx-core/src/centrality.rs b/rustworkx-core/src/centrality.rs index a816af427..d454741a9 100644 --- a/rustworkx-core/src/centrality.rs +++ b/rustworkx-core/src/centrality.rs @@ -32,7 +32,6 @@ use petgraph::visit::{ Reversed, Visitable, }; -use rayon::prelude::*; use rayon_cond::CondIterator; /// Compute the betweenness centrality of all nodes in a graph. @@ -50,7 +49,7 @@ use rayon_cond::CondIterator; /// Arguments: /// /// * `graph` - The graph object to run the algorithm on -/// * `endpoints` - Whether to include the endpoints of paths in the path +/// * `include_endpoints` - Whether to include the endpoints of paths in the path /// lengths used to compute the betweenness /// * `normalized` - Whether to normalize the betweenness scores by the number /// of distinct paths between all pairs of nodes @@ -79,7 +78,7 @@ use rayon_cond::CondIterator; /// [`edge_betweenness_centrality`] pub fn betweenness_centrality( graph: G, - endpoints: bool, + include_endpoints: bool, normalized: bool, parallel_threshold: usize, ) -> Vec> @@ -115,73 +114,27 @@ where betweenness[is] = Some(0.0); } let locked_betweenness = RwLock::new(&mut betweenness); - let node_indices: Vec = graph - .node_identifiers() - .map(|i| graph.to_index(i)) - .collect(); - if graph.node_count() < parallel_threshold { - node_indices - .iter() - .map(|node_s| { - ( - shortest_path_for_centrality(&graph, &graph.from_index(*node_s)), - *node_s, - ) - }) - .for_each(|(mut shortest_path_calc, is)| { - if endpoints { - _accumulate_endpoints( - &locked_betweenness, - max_index, - &mut shortest_path_calc, - is, - &graph, - ); - } else { - _accumulate_basic( - &locked_betweenness, - max_index, - &mut shortest_path_calc, - is, - &graph, - ); - } - }); - } else { - node_indices - .par_iter() - .map(|node_s| { - ( - shortest_path_for_centrality(&graph, &graph.from_index(*node_s)), - node_s, - ) - }) - .for_each(|(mut shortest_path_calc, is)| { - if endpoints { - _accumulate_endpoints( - &locked_betweenness, - max_index, - &mut shortest_path_calc, - *is, - &graph, - ); - } else { - _accumulate_basic( - &locked_betweenness, - max_index, - &mut shortest_path_calc, - *is, - &graph, - ); - } - }); - } + let node_indices: Vec = graph.node_identifiers().collect(); + + CondIterator::new(node_indices, graph.node_count() >= parallel_threshold) + .map(|node_s| (shortest_path_for_centrality(&graph, &node_s), node_s)) + .for_each(|(mut shortest_path_calc, node_s)| { + _accumulate_vertices( + &locked_betweenness, + max_index, + &mut shortest_path_calc, + node_s, + &graph, + include_endpoints, + ); + }); + _rescale( &mut betweenness, graph.node_count(), normalized, graph.is_directed(), - endpoints, + include_endpoints, ); betweenness @@ -279,12 +232,12 @@ fn _rescale( node_count: usize, normalized: bool, directed: bool, - endpoints: bool, + include_endpoints: bool, ) { let mut do_scale = true; let mut scale = 1.0; if normalized { - if endpoints { + if include_endpoints { if node_count < 2 { do_scale = false; } else { @@ -307,12 +260,13 @@ fn _rescale( } } -fn _accumulate_basic( +fn _accumulate_vertices( locked_betweenness: &RwLock<&mut Vec>>, max_index: usize, path_calc: &mut ShortestPathData, - is: usize, + node_s: ::NodeId, graph: G, + include_endpoints: bool, ) where G: NodeIndexable + IntoNodeIdentifiers @@ -334,47 +288,22 @@ fn _accumulate_basic( } } let mut betweenness = locked_betweenness.write().unwrap(); - for w in &path_calc.verts_sorted_by_distance { - let iw = graph.to_index(*w); - if iw != is { - betweenness[iw] = betweenness[iw].map(|x| x + delta[iw]); - } - } -} - -fn _accumulate_endpoints( - locked_betweenness: &RwLock<&mut Vec>>, - max_index: usize, - path_calc: &mut ShortestPathData, - is: usize, - graph: G, -) where - G: NodeIndexable - + IntoNodeIdentifiers - + IntoNeighborsDirected - + NodeCount - + GraphProp - + GraphBase - + std::marker::Sync, - ::NodeId: std::cmp::Eq + Hash, -{ - let mut delta = vec![0.0; max_index]; - for w in &path_calc.verts_sorted_by_distance { - let iw = graph.to_index(*w); - let coeff = (1.0 + delta[iw]) / path_calc.sigma[w]; - let p_w = path_calc.predecessors.get(w).unwrap(); - for v in p_w { - let iv = graph.to_index(*v); - delta[iv] += path_calc.sigma[v] * coeff; + if include_endpoints { + let i_node_s = graph.to_index(node_s); + betweenness[i_node_s] = betweenness[i_node_s] + .map(|x| x + ((path_calc.verts_sorted_by_distance.len() - 1) as f64)); + for w in &path_calc.verts_sorted_by_distance { + if *w != node_s { + let iw = graph.to_index(*w); + betweenness[iw] = betweenness[iw].map(|x| x + delta[iw] + 1.0); + } } - } - let mut betweenness = locked_betweenness.write().unwrap(); - betweenness[is] = - betweenness[is].map(|x| x + ((path_calc.verts_sorted_by_distance.len() - 1) as f64)); - for w in &path_calc.verts_sorted_by_distance { - let iw = graph.to_index(*w); - if iw != is { - betweenness[iw] = betweenness[iw].map(|x| x + delta[iw] + 1.0); + } else { + for w in &path_calc.verts_sorted_by_distance { + if *w != node_s { + let iw = graph.to_index(*w); + betweenness[iw] = betweenness[iw].map(|x| x + delta[iw]); + } } } }