Skip to content

Commit

Permalink
Cleaned up index layouts
Browse files Browse the repository at this point in the history
  • Loading branch information
tbetcke committed Dec 15, 2024
1 parent 500cd48 commit d12b694
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 75 deletions.
61 changes: 61 additions & 0 deletions examples/map_index_layout.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
//! Map between two index layouts
use bempp_distributed_tools::{
index_layout::{IndexLayout, IndexLayoutFromLocalCounts},
EquiDistributedIndexLayout,
};
use itertools::{izip, Itertools};
use mpi::traits::Communicator;

fn main() {
    let universe = mpi::initialize().unwrap();
    let world = universe.world();

    // First layout: 30 indices in chunks of 1, distributed equally, i.e.
    // ten indices per rank on three processes.
    let layout1 = EquiDistributedIndexLayout::new(30, 1, &world);

    // Second layout: 5 indices on rank 0, 17 on rank 1 and 8 on rank 2.
    let counts = match world.rank() {
        0 => 5,
        1 => 17,
        2 => 8,
        _ => panic!("This example only works with three processes."),
    };

    let layout2 = IndexLayoutFromLocalCounts::new(counts, &world);

    // Each rank holds its contiguous ten-element slice of the global
    // array [0, 30).
    let data = match world.rank() {
        0 => (0..10).collect_vec(),
        1 => (10..20).collect_vec(),
        _ => (20..30).collect_vec(),
    };

    // Redistribute the data from layout1's distribution to layout2's.
    let mapped_data = layout1.remap(&layout2, &data);

    // After remapping, each rank owns the slice of [0, 30) dictated by
    // layout2's cumulative counts: [0, 5), [5, 22) and [22, 30).
    let expected = match world.rank() {
        0 => 0..5,
        1 => 5..22,
        _ => 22..30,
    };
    assert_eq!(mapped_data.len(), expected.len());
    for (want, &got) in izip!(expected, mapped_data.iter()) {
        assert_eq!(want, got);
    }

    // Mapping back must reproduce the original distribution exactly.
    let remapped_data = layout2.remap(&layout1, &mapped_data);

    assert_eq!(data, remapped_data);
}
93 changes: 86 additions & 7 deletions src/index_layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@
/// index `last` is not contained on the process. If `first == last` then there is no index on
/// the local process.
mod equidistributed_index_layout;
mod index_layout_from_local_counts;
pub use equidistributed_index_layout::EquiDistributedIndexLayout;
pub use index_layout_from_local_counts::IndexLayoutFromLocalCounts;
use itertools::Itertools;
use mpi::traits::{Communicator, Equivalence};

use crate::array_tools::redistribute;

// An index layout specifying index ranges on each rank.
//
Expand All @@ -19,33 +25,106 @@ pub use equidistributed_index_layout::EquiDistributedIndexLayout;
pub trait IndexLayout {
/// MPI Communicator;
type Comm: mpi::topology::Communicator;

/// The cumulative sum of indices over the ranks.
///
/// The number of indices on rank i is counts[1 + i] - counts[i].
/// The last entry is the total number of indices.
fn counts(&self) -> &[usize];

/// The local index range. If there is no local index
/// the left and right bound are identical.
fn local_range(&self) -> (usize, usize);
fn local_range(&self) -> (usize, usize) {
let counts = self.counts();
(
counts[self.comm().rank() as usize],
counts[1 + self.comm().rank() as usize],
)
}

/// The number of global indices.
fn number_of_global_indices(&self) -> usize;
fn number_of_global_indices(&self) -> usize {
*self.counts().last().unwrap()
}

/// The number of local indicies, that is the amount of indicies
/// on my process.
fn number_of_local_indices(&self) -> usize;
fn number_of_local_indices(&self) -> usize {
let counts = self.counts();
counts[1 + self.comm().rank() as usize] - counts[self.comm().rank() as usize]
}

/// Index range on a given process.
fn index_range(&self, rank: usize) -> Option<(usize, usize)>;
fn index_range(&self, rank: usize) -> Option<(usize, usize)> {
let counts = self.counts();
if rank < self.comm().size() as usize {
Some((counts[rank], counts[1 + rank]))
} else {
None
}
}

/// Convert continuous (0, n) indices to actual indices.
///
/// Assume that the local range is (30, 40). Then this method
/// will map (0,10) -> (30, 40).
/// It returns ```None``` if ```index``` is out of bounds.
fn local2global(&self, index: usize) -> Option<usize>;
fn local2global(&self, index: usize) -> Option<usize> {
let rank = self.comm().rank() as usize;
if index < self.number_of_local_indices() {
Some(self.counts()[rank] + index)
} else {
None
}
}

/// Convert global index to local index on a given rank.
/// Returns ```None``` if index does not exist on rank.
fn global2local(&self, rank: usize, index: usize) -> Option<usize>;
fn global2local(&self, rank: usize, index: usize) -> Option<usize> {
if let Some(index_range) = self.index_range(rank) {
if index >= index_range.1 {
return None;
}

Some(index - index_range.0)
} else {
None
}
}

/// Get the rank of a given index.
fn rank_from_index(&self, index: usize) -> Option<usize>;
fn rank_from_index(&self, index: usize) -> Option<usize> {
for (count_index, &count) in self.counts()[1..].iter().enumerate() {
if index < count {
return Some(count_index);
}
}
None
}

/// Remap indices from one layout to another.
///
/// `data` is this layout's local slice of a globally distributed array; the
/// returned vector is the slice of the same global array that `other` assigns
/// to this rank. Both layouts must describe the same global index count.
fn remap<L: IndexLayout, T: Equivalence>(&self, other: &L, data: &[T]) -> Vec<T> {
    assert_eq!(data.len(), self.number_of_local_indices());
    assert_eq!(
        self.number_of_global_indices(),
        other.number_of_global_indices()
    );

    let my_range = self.local_range();

    // First global index of each rank in the target layout; these act as
    // bin boundaries when assigning our indices to destination ranks.
    let other_bins = (0..other.comm().size() as usize)
        .map(|rank| other.index_range(rank).unwrap().0)
        .collect_vec();

    // Our global indices form a contiguous range and are hence sorted.
    let sorted_keys = (my_range.0..my_range.1).collect_vec();

    // Per-destination-rank element counts, converted to i32 MPI send counts.
    // NOTE(review): assumes `sort_to_bins` returns one count per bin —
    // confirm against `crate::array_tools`.
    let counts = crate::array_tools::sort_to_bins(&sorted_keys, &other_bins)
        .iter()
        .map(|&key| key as i32)
        .collect_vec();

    // All-to-all exchange of the data according to the computed counts.
    redistribute(data, &counts, other.comm())
}

/// Return the communicator.
fn comm(&self) -> &Self::Comm;
Expand Down
65 changes: 6 additions & 59 deletions src/index_layout/equidistributed_index_layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,20 @@ use mpi::traits::Communicator;

/// Default index layout
///
/// Distributes a contiguous index range across the ranks of an MPI
/// communicator and stores the result as cumulative counts.
pub struct EquiDistributedIndexLayout<'a, C: Communicator> {
    /// Cumulative counts: `counts[r]` is the first global index on rank `r`,
    /// and the last entry is the total number of indices.
    counts: Vec<usize>,
    /// Communicator over which the indices are distributed.
    comm: &'a C,
}

impl<'a, C: Communicator> EquiDistributedIndexLayout<'a, C> {
/// Create new
pub fn new(nchunks: usize, chunk_size: usize, comm: &'a C) -> Self {
let size = nchunks * chunk_size;
let nindices = nchunks * chunk_size;
let comm_size = comm.size() as usize;

assert!(
comm_size > 0,
"Group size is zero. At least one process needs to be in the group."
);
let my_rank = comm.rank() as usize;
let mut counts = vec![0; 1 + comm_size];

// The following code computes what index is on what rank. No MPI operation necessary.
Expand All @@ -37,10 +34,10 @@ impl<'a, C: Communicator> EquiDistributedIndexLayout<'a, C> {
}

for item in counts.iter_mut().take(comm_size).skip(nchunks) {
*item = size;
*item = nindices;
}

counts[comm_size] = size;
counts[comm_size] = nindices;
} else {
// We want to equally distribute the range
// among the ranks. Assume that we have 12
Expand Down Expand Up @@ -75,65 +72,15 @@ impl<'a, C: Communicator> EquiDistributedIndexLayout<'a, C> {
}
}

Self {
size,
my_rank,
counts,
comm,
}
Self { counts, comm }
}
}

impl<C: Communicator> IndexLayout for EquiDistributedIndexLayout<'_, C> {
type Comm = C;

fn index_range(&self, rank: usize) -> Option<(usize, usize)> {
if rank < self.comm.size() as usize {
Some((self.counts[rank], self.counts[1 + rank]))
} else {
None
}
}

fn local_range(&self) -> (usize, usize) {
self.index_range(self.my_rank).unwrap()
}

fn number_of_local_indices(&self) -> usize {
self.counts[1 + self.my_rank] - self.counts[self.my_rank]
}

fn number_of_global_indices(&self) -> usize {
self.size
}

fn local2global(&self, index: usize) -> Option<usize> {
if index < self.number_of_local_indices() {
Some(self.counts[self.my_rank] + index)
} else {
None
}
}

fn global2local(&self, rank: usize, index: usize) -> Option<usize> {
if let Some(index_range) = self.index_range(rank) {
if index >= index_range.1 {
return None;
}

Some(index - index_range.0)
} else {
None
}
}

fn rank_from_index(&self, index: usize) -> Option<usize> {
for (count_index, &count) in self.counts[1..].iter().enumerate() {
if index < count {
return Some(count_index);
}
}
None
/// Return the cumulative index counts over all ranks.
fn counts(&self) -> &[usize] {
    self.counts.as_slice()
}

fn comm(&self) -> &Self::Comm {
Expand Down
34 changes: 34 additions & 0 deletions src/index_layout/index_layout_from_local_counts.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//! Default distributed index layout
use crate::index_layout::IndexLayout;
use mpi::traits::{Communicator, CommunicatorCollectives};

/// Specify an index layout from local variable counts
pub struct IndexLayoutFromLocalCounts<'a, C: Communicator> {
    /// Cumulative sum of the gathered per-rank counts; the first entry is 0
    /// and the last entry is the total number of indices.
    counts: Vec<usize>,
    /// Communicator over which the layout is defined.
    comm: &'a C,
}

impl<'a, C: Communicator + CommunicatorCollectives> IndexLayoutFromLocalCounts<'a, C> {
    /// Create a new layout from the number of indices owned by each rank.
    ///
    /// Every rank passes its own `local_count`; the counts are gathered across
    /// the communicator and converted into a cumulative sum, so `counts[r]` is
    /// the first global index on rank `r` and the last entry is the global total.
    pub fn new(local_count: usize, comm: &'a C) -> Self {
        let size = comm.size() as usize;
        // counts[0] stays 0; the gathered per-rank counts land in counts[1..].
        let mut counts = vec![0; size + 1];
        comm.all_gather_into(&local_count, &mut counts[1..]);
        // In-place prefix sum turns per-rank counts into cumulative offsets.
        for i in 1..=size {
            counts[i] += counts[i - 1];
        }
        Self { counts, comm }
    }
}

impl<C: Communicator> IndexLayout for IndexLayoutFromLocalCounts<'_, C> {
type Comm = C;

/// Return the cumulative index counts over all ranks.
fn counts(&self) -> &[usize] {
    self.counts.as_slice()
}

/// Return the communicator the layout is defined on.
fn comm(&self) -> &Self::Comm {
    // `self.comm` is already a `&'a C`; returning it directly fixes the
    // clippy::needless_borrow failure flagged by CI ("this expression
    // creates a reference which is immediately dereferenced").
    self.comm
}
}
14 changes: 5 additions & 9 deletions src/permutation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::index_layout::IndexLayout;
/// Permutation of data.
pub struct DataPermutation<'a, L: IndexLayout> {
index_layout: &'a L,
custom_indices: &'a [usize],
nindices: usize,
my_rank: usize,
custom_local_indices: Vec<usize>,
local_to_custom_map: Vec<usize>,
Expand All @@ -21,11 +21,7 @@ pub struct DataPermutation<'a, L: IndexLayout> {

impl<'a, L: IndexLayout> DataPermutation<'a, L> {
/// Create a new permutation object.
pub fn new<C: Communicator>(
index_layout: &'a L,
custom_indices: &'a [usize],
comm: &C,
) -> Self {
pub fn new<C: Communicator>(index_layout: &'a L, custom_indices: &[usize], comm: &C) -> Self {
// We first need to identify which custom indices are local and which are global.

let my_rank = comm.rank() as usize;
Expand Down Expand Up @@ -72,7 +68,7 @@ impl<'a, L: IndexLayout> DataPermutation<'a, L> {

Self {
index_layout,
custom_indices,
nindices: custom_indices.len(),
my_rank,
custom_local_indices,
local_to_custom_map,
Expand All @@ -88,7 +84,7 @@ impl<'a, L: IndexLayout> DataPermutation<'a, L> {
permuted_data: &mut [T],
) {
assert_eq!(data.len(), self.index_layout.number_of_local_indices());
assert_eq!(permuted_data.len(), self.custom_indices.len());
assert_eq!(permuted_data.len(), self.nindices);

// We first need to get the send data. This is quite easy. We can just
// use the global2local method from the index layout.
Expand Down Expand Up @@ -127,7 +123,7 @@ impl<'a, L: IndexLayout> DataPermutation<'a, L> {
data: &[T],
permuted_data: &mut [T],
) {
assert_eq!(data.len(), self.custom_indices.len());
assert_eq!(data.len(), self.nindices);
assert_eq!(
permuted_data.len(),
self.index_layout.number_of_local_indices()
Expand Down

0 comments on commit d12b694

Please sign in to comment.