Skip to content

Commit

Permalink
Simplifying betweenness_centrality (for vertices) (#815)
Browse files Browse the repository at this point in the history
* Simplifying betweenness_centrality (for vertices)

* Run cargo fmt

* Remove unused import

---------

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
jlapeyre and mergify[bot] authored Mar 10, 2023
1 parent 9dd7a8a commit 9f56767
Showing 1 changed file with 38 additions and 109 deletions.
147 changes: 38 additions & 109 deletions rustworkx-core/src/centrality.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ use petgraph::visit::{
Reversed,
Visitable,
};
use rayon::prelude::*;
use rayon_cond::CondIterator;

/// Compute the betweenness centrality of all nodes in a graph.
Expand All @@ -50,7 +49,7 @@ use rayon_cond::CondIterator;
/// Arguments:
///
/// * `graph` - The graph object to run the algorithm on
/// * `endpoints` - Whether to include the endpoints of paths in the path
/// * `include_endpoints` - Whether to include the endpoints of paths in the path
/// lengths used to compute the betweenness
/// * `normalized` - Whether to normalize the betweenness scores by the number
/// of distinct paths between all pairs of nodes
Expand Down Expand Up @@ -79,7 +78,7 @@ use rayon_cond::CondIterator;
/// [`edge_betweenness_centrality`]
pub fn betweenness_centrality<G>(
graph: G,
endpoints: bool,
include_endpoints: bool,
normalized: bool,
parallel_threshold: usize,
) -> Vec<Option<f64>>
Expand Down Expand Up @@ -115,73 +114,27 @@ where
betweenness[is] = Some(0.0);
}
let locked_betweenness = RwLock::new(&mut betweenness);
let node_indices: Vec<usize> = graph
.node_identifiers()
.map(|i| graph.to_index(i))
.collect();
if graph.node_count() < parallel_threshold {
node_indices
.iter()
.map(|node_s| {
(
shortest_path_for_centrality(&graph, &graph.from_index(*node_s)),
*node_s,
)
})
.for_each(|(mut shortest_path_calc, is)| {
if endpoints {
_accumulate_endpoints(
&locked_betweenness,
max_index,
&mut shortest_path_calc,
is,
&graph,
);
} else {
_accumulate_basic(
&locked_betweenness,
max_index,
&mut shortest_path_calc,
is,
&graph,
);
}
});
} else {
node_indices
.par_iter()
.map(|node_s| {
(
shortest_path_for_centrality(&graph, &graph.from_index(*node_s)),
node_s,
)
})
.for_each(|(mut shortest_path_calc, is)| {
if endpoints {
_accumulate_endpoints(
&locked_betweenness,
max_index,
&mut shortest_path_calc,
*is,
&graph,
);
} else {
_accumulate_basic(
&locked_betweenness,
max_index,
&mut shortest_path_calc,
*is,
&graph,
);
}
});
}
let node_indices: Vec<G::NodeId> = graph.node_identifiers().collect();

CondIterator::new(node_indices, graph.node_count() >= parallel_threshold)
.map(|node_s| (shortest_path_for_centrality(&graph, &node_s), node_s))
.for_each(|(mut shortest_path_calc, node_s)| {
_accumulate_vertices(
&locked_betweenness,
max_index,
&mut shortest_path_calc,
node_s,
&graph,
include_endpoints,
);
});

_rescale(
&mut betweenness,
graph.node_count(),
normalized,
graph.is_directed(),
endpoints,
include_endpoints,
);

betweenness
Expand Down Expand Up @@ -279,12 +232,12 @@ fn _rescale(
node_count: usize,
normalized: bool,
directed: bool,
endpoints: bool,
include_endpoints: bool,
) {
let mut do_scale = true;
let mut scale = 1.0;
if normalized {
if endpoints {
if include_endpoints {
if node_count < 2 {
do_scale = false;
} else {
Expand All @@ -307,12 +260,13 @@ fn _rescale(
}
}

fn _accumulate_basic<G>(
fn _accumulate_vertices<G>(
locked_betweenness: &RwLock<&mut Vec<Option<f64>>>,
max_index: usize,
path_calc: &mut ShortestPathData<G>,
is: usize,
node_s: <G as GraphBase>::NodeId,
graph: G,
include_endpoints: bool,
) where
G: NodeIndexable
+ IntoNodeIdentifiers
Expand All @@ -334,47 +288,22 @@ fn _accumulate_basic<G>(
}
}
let mut betweenness = locked_betweenness.write().unwrap();
for w in &path_calc.verts_sorted_by_distance {
let iw = graph.to_index(*w);
if iw != is {
betweenness[iw] = betweenness[iw].map(|x| x + delta[iw]);
}
}
}

fn _accumulate_endpoints<G>(
locked_betweenness: &RwLock<&mut Vec<Option<f64>>>,
max_index: usize,
path_calc: &mut ShortestPathData<G>,
is: usize,
graph: G,
) where
G: NodeIndexable
+ IntoNodeIdentifiers
+ IntoNeighborsDirected
+ NodeCount
+ GraphProp
+ GraphBase
+ std::marker::Sync,
<G as GraphBase>::NodeId: std::cmp::Eq + Hash,
{
let mut delta = vec![0.0; max_index];
for w in &path_calc.verts_sorted_by_distance {
let iw = graph.to_index(*w);
let coeff = (1.0 + delta[iw]) / path_calc.sigma[w];
let p_w = path_calc.predecessors.get(w).unwrap();
for v in p_w {
let iv = graph.to_index(*v);
delta[iv] += path_calc.sigma[v] * coeff;
if include_endpoints {
let i_node_s = graph.to_index(node_s);
betweenness[i_node_s] = betweenness[i_node_s]
.map(|x| x + ((path_calc.verts_sorted_by_distance.len() - 1) as f64));
for w in &path_calc.verts_sorted_by_distance {
if *w != node_s {
let iw = graph.to_index(*w);
betweenness[iw] = betweenness[iw].map(|x| x + delta[iw] + 1.0);
}
}
}
let mut betweenness = locked_betweenness.write().unwrap();
betweenness[is] =
betweenness[is].map(|x| x + ((path_calc.verts_sorted_by_distance.len() - 1) as f64));
for w in &path_calc.verts_sorted_by_distance {
let iw = graph.to_index(*w);
if iw != is {
betweenness[iw] = betweenness[iw].map(|x| x + delta[iw] + 1.0);
} else {
for w in &path_calc.verts_sorted_by_distance {
if *w != node_s {
let iw = graph.to_index(*w);
betweenness[iw] = betweenness[iw].map(|x| x + delta[iw]);
}
}
}
}
Expand Down

0 comments on commit 9f56767

Please sign in to comment.