Skip to content

Commit

Permalink
Cache alignment for serial and parallel FFT and IFFT (#245)
Browse files Browse the repository at this point in the history
Co-authored-by: Pratyush Mishra <pratyushmishra@berkeley.edu>
  • Loading branch information
jon-chuang and Pratyush authored Apr 8, 2021
1 parent 0bd355b commit e504bda
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 77 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### Features

- [\#230](https://github.com/arkworks-rs/algebra/pull/230) (ark-ec) Add `wnaf_mul` implementation for `ProjectiveCurve`.
- [\#245](https://github.com/arkworks-rs/algebra/pull/245) (ark-poly) Speedup the sequential and parallel radix-2 FFT and IFFT significantly by making the method in which it accesses roots more cache-friendly.
- [\#258](https://github.com/arkworks-rs/algebra/pull/258) (ark-poly) Add `Mul<F>` implementation for `DensePolynomial`

### Improvements
Expand Down
201 changes: 124 additions & 77 deletions poly/src/domain/radix2/fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
use crate::domain::utils::compute_powers_serial;
use crate::domain::{radix2::*, DomainCoeff};
use ark_ff::FftField;
use ark_std::{cfg_iter_mut, vec::Vec};
use ark_std::{cfg_chunks_mut, vec::Vec};
#[cfg(feature = "parallel")]
use rayon::prelude::*;

Expand Down Expand Up @@ -137,109 +137,156 @@ impl<F: FftField> Radix2EvaluationDomain<F> {
});
}

#[inline(always)]
fn butterfly_fn_io<T: DomainCoeff<F>>(((lo, hi), root): ((&mut T, &mut T), &F)) {
let neg = *lo - *hi;
*lo += *hi;
*hi = neg;
*hi *= *root;
}

#[inline(always)]
fn butterfly_fn_oi<T: DomainCoeff<F>>(((lo, hi), root): ((&mut T, &mut T), &F)) {
*hi *= *root;
let neg = *lo - *hi;
*lo += *hi;
*hi = neg;
}

fn apply_butterfly<T: DomainCoeff<F>, G: Fn(((&mut T, &mut T), &F)) + Copy + Sync + Send>(
g: G,
xi: &mut [T],
roots: &[F],
step: usize,
chunk_size: usize,
num_chunks: usize,
max_threads: usize,
gap: usize,
) {
cfg_chunks_mut!(xi, chunk_size).for_each(|cxi| {
let (lo, hi) = cxi.split_at_mut(gap);
// If the chunk is sufficiently big that parallelism helps,
// we parallelize the butterfly operation within the chunk.

if gap > MIN_GAP_SIZE_FOR_PARALLELISATION && num_chunks < max_threads {
cfg_iter_mut!(lo)
.zip(hi)
.zip(cfg_iter!(roots).step_by(step))
.for_each(g);
} else {
lo.iter_mut()
.zip(hi)
.zip(roots.iter().step_by(step))
.for_each(g);
}
});
}

fn io_helper<T: DomainCoeff<F>>(&self, xi: &mut [T], root: F) {
// In the sequential case, we will keep on making the roots cache-aligned,
// according to the access pattern that the FFT uses.
// It is left as a TODO to implement this for the parallel case
let roots = &mut self.roots_of_unity(root);
let mut roots = self.roots_of_unity(root);
let mut step = 1;
let mut first = true;

#[cfg(feature = "parallel")]
let max_threads = rayon::current_num_threads();
#[cfg(not(feature = "parallel"))]
let mut root_len = roots.len();
let max_threads = 1;

let mut gap = xi.len() / 2;
while gap > 0 {
// each butterfly cluster uses 2*gap positions
let chunk_size = 2 * gap;
#[cfg(feature = "parallel")]
let nchunks = xi.len() / chunk_size;

let butterfly_fn = |(chunk_index, (lo, hi)): (usize, (&mut T, &mut T))| {
let neg = *lo - *hi;
*lo += *hi;

*hi = neg;

#[cfg(feature = "parallel")]
let index = nchunks * chunk_index;
// Due to cache aligning, the index into roots is the chunk index
#[cfg(not(feature = "parallel"))]
let index = chunk_index;

*hi *= roots[index];
};

ark_std::cfg_chunks_mut!(xi, chunk_size).for_each(|cxi| {
let (lo, hi) = cxi.split_at_mut(gap);
// If the chunk is sufficiently big that parallelism helps,
// we parallelize the butterfly operation within the chunk.
//
// if chunk_size > MIN_CHUNK_SIZE_FOR_PARALLELIZATION
if gap > MIN_CHUNK_SIZE_FOR_PARALLELIZATION / 2 {
cfg_iter_mut!(lo).zip(hi).enumerate().for_each(butterfly_fn);
} else {
lo.iter_mut().zip(hi).enumerate().for_each(butterfly_fn);
let num_chunks = xi.len() / chunk_size;

// Only compact roots to achieve cache locality/compactness if
// the roots lookup is done a significant amount of times
// Which also implies a large lookup stride.
if num_chunks >= MIN_NUM_CHUNKS_FOR_COMPACTION {
if !first {
roots = cfg_into_iter!(roots).step_by(step * 2).collect()
}
});

// Cache align the FFT roots in the sequential case.
// In this case, we are aiming to make every root that is accessed one after another,
// appear one after another in the list of roots.
// if the roots are cache aligned in iteration i, then in iteration i+1,
// cache alignment requires selecting every other root.
// (The even powers relative to the current iterations generator)
//
//
// (Roots are already aligned in the first iteration,
// so we only to do realignment after the first iteration.)
#[cfg(not(feature = "parallel"))]
{
for i in 1..(root_len / 2) {
roots[i] = roots[i * 2];
}
root_len /= 2;
step = 1;
roots.shrink_to_fit();
} else {
step = num_chunks;
}
first = false;

Self::apply_butterfly(
Self::butterfly_fn_io,
xi,
&roots[..],
step,
chunk_size,
num_chunks,
max_threads,
gap,
);

gap /= 2;
}
}

fn oi_helper<T: DomainCoeff<F>>(&self, xi: &mut [T], root: F) {
let roots = self.roots_of_unity(root);
let roots_cache = self.roots_of_unity(root);

// The `cmp::min` is only necessary for the case where
// `MIN_NUM_CHUNKS_FOR_COMPACTION = 1`. Else, notice that we compact
// the roots cache by a stride of at least `MIN_NUM_CHUNKS_FOR_COMPACTION`.

let compaction_max_size = core::cmp::min(
roots_cache.len() / 2,
roots_cache.len() / MIN_NUM_CHUNKS_FOR_COMPACTION,
);
let mut compacted_roots = vec![F::default(); compaction_max_size];

#[cfg(feature = "parallel")]
let max_threads = rayon::current_num_threads();
#[cfg(not(feature = "parallel"))]
let max_threads = 1;

let mut gap = 1;
while gap < xi.len() {
// each butterfly cluster uses 2*gap positions
let chunk_size = 2 * gap;
let nchunks = xi.len() / chunk_size;
let num_chunks = xi.len() / chunk_size;

let butterfly_fn = |(idx, (lo, hi)): (usize, (&mut T, &mut T))| {
*hi *= roots[nchunks * idx];
let neg = *lo - *hi;
*lo += *hi;
*hi = neg;
// Only compact roots to achieve cache locality/compactness if
// the roots lookup is done a significant amount of times
// Which also implies a large lookup stride.
let (roots, step) = if num_chunks >= MIN_NUM_CHUNKS_FOR_COMPACTION && gap < xi.len() / 2
{
cfg_iter_mut!(compacted_roots[..gap])
.zip(cfg_iter!(roots_cache[..(gap * num_chunks)]).step_by(num_chunks))
.for_each(|(a, b)| *a = *b);
(&compacted_roots[..gap], 1)
} else {
(&roots_cache[..], num_chunks)
};

ark_std::cfg_chunks_mut!(xi, chunk_size).for_each(|cxi| {
let (lo, hi) = cxi.split_at_mut(gap);
// If the chunk is sufficiently big that parallelism helps,
// we parallelize the butterfly operation within the chunk.
//
// if chunk_size > MIN_CHUNK_SIZE_FOR_PARALLELIZATION
if gap > MIN_CHUNK_SIZE_FOR_PARALLELIZATION / 2 {
cfg_iter_mut!(lo).zip(hi).enumerate().for_each(butterfly_fn);
} else {
lo.iter_mut().zip(hi).enumerate().for_each(butterfly_fn);
}
});
Self::apply_butterfly(
Self::butterfly_fn_oi,
xi,
roots,
step,
chunk_size,
num_chunks,
max_threads,
gap,
);

gap *= 2;
}
}
}

// This value controls that when doing a butterfly on a chunk of size c,
// do you parallelize operations on the chunk.
// If c > MIN_CHUNK_SIZE_FOR_PARALLELIZATION,
// then parallelize, else be sequential.
// This value was chosen empirically.
const MIN_CHUNK_SIZE_FOR_PARALLELIZATION: usize = 2048;
/// The minimum number of chunks at which root compaction
/// is beneficial.
const MIN_NUM_CHUNKS_FOR_COMPACTION: usize = 1 << 7;

/// The minimum size of a chunk at which parallelization of `butterfly`s is beneficial.
/// This value was chosen empirically.
const MIN_GAP_SIZE_FOR_PARALLELISATION: usize = 1 << 10;

// minimum size at which to parallelize.
#[cfg(feature = "parallel")]
Expand Down

0 comments on commit e504bda

Please sign in to comment.