Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache alignment for serial and parallel FFT and IFFT #245

Merged
merged 25 commits into from
Apr 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### Features

- [\#230](https://github.com/arkworks-rs/algebra/pull/230) (ark-ec) Add `wnaf_mul` implementation for `ProjectiveCurve`.
- [\#245](https://github.com/arkworks-rs/algebra/pull/245) (ark-poly) Speed up the sequential and parallel radix-2 FFT and IFFT significantly by making the way they access the roots more cache-friendly.
- [\#258](https://github.com/arkworks-rs/algebra/pull/258) (ark-poly) Add `Mul<F>` implementation for `DensePolynomial`

### Improvements
Expand Down
201 changes: 124 additions & 77 deletions poly/src/domain/radix2/fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
use crate::domain::utils::compute_powers_serial;
use crate::domain::{radix2::*, DomainCoeff};
use ark_ff::FftField;
use ark_std::{cfg_iter_mut, vec::Vec};
use ark_std::{cfg_chunks_mut, vec::Vec};
#[cfg(feature = "parallel")]
use rayon::prelude::*;

Expand Down Expand Up @@ -137,109 +137,156 @@ impl<F: FftField> Radix2EvaluationDomain<F> {
});
}

/// Butterfly passed to `apply_butterfly` by `io_helper`.
///
/// Given the pair `(lo, hi)` and a root, this stores `lo + hi` in `lo` and
/// `(lo - hi) * root` in `hi`, where both right-hand sides use the values
/// the pair held on entry.
#[inline(always)]
fn butterfly_fn_io<T: DomainCoeff<F>>(args: ((&mut T, &mut T), &F)) {
    let ((lo, hi), root) = args;
    // Take the difference before `lo` is overwritten by the sum.
    let mut diff = *lo - *hi;
    *lo += *hi;
    // The multiplication by the root happens after the subtraction.
    diff *= *root;
    *hi = diff;
}

/// Butterfly passed to `apply_butterfly` by `oi_helper`.
///
/// Given the pair `(lo, hi)` and a root, `hi` is first scaled by the root;
/// afterwards `lo` receives `lo + hi` and `hi` receives `lo - hi`, both
/// computed from the entry value of `lo` and the already-scaled `hi`.
#[inline(always)]
fn butterfly_fn_oi<T: DomainCoeff<F>>(args: ((&mut T, &mut T), &F)) {
    let ((lo, hi), root) = args;
    // The multiplication by the root happens before the add/subtract pair.
    *hi *= *root;
    // Take the difference before `lo` is overwritten by the sum.
    let diff = *lo - *hi;
    *lo += *hi;
    *hi = diff;
}

/// Applies the butterfly `g` to every `chunk_size`-sized chunk of `xi`.
///
/// Each chunk is split at `gap` into a low and a high part. The pairs
/// `(lo[i], hi[i])` are zipped with every `step`-th element of `roots`
/// (starting from `roots[0]`) and handed to `g`.
///
/// Inside a chunk the `cfg_iter_mut!`/`cfg_iter!` path is taken only when
/// `gap > MIN_GAP_SIZE_FOR_PARALLELISATION` and `num_chunks < max_threads`;
/// otherwise plain sequential slice iterators are used.
fn apply_butterfly<T, G>(
    g: G,
    xi: &mut [T],
    roots: &[F],
    step: usize,
    chunk_size: usize,
    num_chunks: usize,
    max_threads: usize,
    gap: usize,
) where
    T: DomainCoeff<F>,
    G: Fn(((&mut T, &mut T), &F)) + Copy + Sync + Send,
{
    // This decision does not depend on the individual chunk, so it is
    // evaluated once instead of once per chunk.
    let parallelise_within_chunk =
        gap > MIN_GAP_SIZE_FOR_PARALLELISATION && num_chunks < max_threads;

    cfg_chunks_mut!(xi, chunk_size).for_each(|chunk| {
        let (lo, hi) = chunk.split_at_mut(gap);

        if !parallelise_within_chunk {
            // Small gap, or already enough chunks to keep every thread busy.
            lo.iter_mut()
                .zip(hi)
                .zip(roots.iter().step_by(step))
                .for_each(g);
            return;
        }

        // The chunk is sufficiently big that parallelism helps, so the
        // butterfly operation is parallelized within the chunk.
        cfg_iter_mut!(lo)
            .zip(hi)
            .zip(cfg_iter!(roots).step_by(step))
            .for_each(g);
    });
}

fn io_helper<T: DomainCoeff<F>>(&self, xi: &mut [T], root: F) {
// In the sequential case, we will keep on making the roots cache-aligned,
// according to the access pattern that the FFT uses.
// It is left as a TODO to implement this for the parallel case
let roots = &mut self.roots_of_unity(root);
let mut roots = self.roots_of_unity(root);
let mut step = 1;
let mut first = true;

#[cfg(feature = "parallel")]
let max_threads = rayon::current_num_threads();
#[cfg(not(feature = "parallel"))]
let mut root_len = roots.len();
let max_threads = 1;

let mut gap = xi.len() / 2;
while gap > 0 {
// each butterfly cluster uses 2*gap positions
let chunk_size = 2 * gap;
#[cfg(feature = "parallel")]
let nchunks = xi.len() / chunk_size;

let butterfly_fn = |(chunk_index, (lo, hi)): (usize, (&mut T, &mut T))| {
let neg = *lo - *hi;
*lo += *hi;

*hi = neg;

#[cfg(feature = "parallel")]
let index = nchunks * chunk_index;
// Due to cache aligning, the index into roots is the chunk index
#[cfg(not(feature = "parallel"))]
let index = chunk_index;

*hi *= roots[index];
};

ark_std::cfg_chunks_mut!(xi, chunk_size).for_each(|cxi| {
let (lo, hi) = cxi.split_at_mut(gap);
// If the chunk is sufficiently big that parallelism helps,
// we parallelize the butterfly operation within the chunk.
//
// if chunk_size > MIN_CHUNK_SIZE_FOR_PARALLELIZATION
if gap > MIN_CHUNK_SIZE_FOR_PARALLELIZATION / 2 {
cfg_iter_mut!(lo).zip(hi).enumerate().for_each(butterfly_fn);
} else {
lo.iter_mut().zip(hi).enumerate().for_each(butterfly_fn);
let num_chunks = xi.len() / chunk_size;

// Only compact roots to achieve cache locality/compactness if
// the roots lookup is done a significant number of times,
// which also implies a large lookup stride.
if num_chunks >= MIN_NUM_CHUNKS_FOR_COMPACTION {
if !first {
roots = cfg_into_iter!(roots).step_by(step * 2).collect()
}
});

// Cache align the FFT roots in the sequential case.
// In this case, we are aiming to make every root that is accessed one after another,
// appear one after another in the list of roots.
// if the roots are cache aligned in iteration i, then in iteration i+1,
// cache alignment requires selecting every other root.
// (The even powers relative to the current iterations generator)
//
//
// (Roots are already aligned in the first iteration,
// so we only need to do realignment after the first iteration.)
#[cfg(not(feature = "parallel"))]
{
for i in 1..(root_len / 2) {
roots[i] = roots[i * 2];
}
root_len /= 2;
step = 1;
roots.shrink_to_fit();
} else {
step = num_chunks;
}
first = false;

Self::apply_butterfly(
Self::butterfly_fn_io,
xi,
&roots[..],
step,
chunk_size,
num_chunks,
max_threads,
gap,
);

gap /= 2;
}
}

fn oi_helper<T: DomainCoeff<F>>(&self, xi: &mut [T], root: F) {
let roots = self.roots_of_unity(root);
let roots_cache = self.roots_of_unity(root);

// The `cmp::min` is only necessary for the case where
// `MIN_NUM_CHUNKS_FOR_COMPACTION = 1`. Else, notice that we compact
// the roots cache by a stride of at least `MIN_NUM_CHUNKS_FOR_COMPACTION`.

let compaction_max_size = core::cmp::min(
roots_cache.len() / 2,
roots_cache.len() / MIN_NUM_CHUNKS_FOR_COMPACTION,
);
let mut compacted_roots = vec![F::default(); compaction_max_size];

#[cfg(feature = "parallel")]
let max_threads = rayon::current_num_threads();
#[cfg(not(feature = "parallel"))]
let max_threads = 1;

let mut gap = 1;
while gap < xi.len() {
// each butterfly cluster uses 2*gap positions
let chunk_size = 2 * gap;
let nchunks = xi.len() / chunk_size;
let num_chunks = xi.len() / chunk_size;

let butterfly_fn = |(idx, (lo, hi)): (usize, (&mut T, &mut T))| {
*hi *= roots[nchunks * idx];
let neg = *lo - *hi;
*lo += *hi;
*hi = neg;
// Only compact roots to achieve cache locality/compactness if
// the roots lookup is done a significant number of times,
// which also implies a large lookup stride.
let (roots, step) = if num_chunks >= MIN_NUM_CHUNKS_FOR_COMPACTION && gap < xi.len() / 2
{
cfg_iter_mut!(compacted_roots[..gap])
.zip(cfg_iter!(roots_cache[..(gap * num_chunks)]).step_by(num_chunks))
.for_each(|(a, b)| *a = *b);
(&compacted_roots[..gap], 1)
} else {
(&roots_cache[..], num_chunks)
};

ark_std::cfg_chunks_mut!(xi, chunk_size).for_each(|cxi| {
let (lo, hi) = cxi.split_at_mut(gap);
// If the chunk is sufficiently big that parallelism helps,
// we parallelize the butterfly operation within the chunk.
//
// if chunk_size > MIN_CHUNK_SIZE_FOR_PARALLELIZATION
if gap > MIN_CHUNK_SIZE_FOR_PARALLELIZATION / 2 {
cfg_iter_mut!(lo).zip(hi).enumerate().for_each(butterfly_fn);
} else {
lo.iter_mut().zip(hi).enumerate().for_each(butterfly_fn);
}
});
Self::apply_butterfly(
Self::butterfly_fn_oi,
xi,
roots,
step,
chunk_size,
num_chunks,
max_threads,
gap,
);

gap *= 2;
}
}
}

// Threshold that decides whether the butterfly operations inside a single
// chunk of size c are parallelized.
// If c > MIN_CHUNK_SIZE_FOR_PARALLELIZATION,
// then parallelize, else be sequential.
// This value was chosen empirically.
const MIN_CHUNK_SIZE_FOR_PARALLELIZATION: usize = 2048;
/// The minimum number of chunks (`xi.len() / chunk_size`) in a butterfly
/// round at which root compaction is considered beneficial.
/// Rounds are only candidates for compaction when their chunk count is
/// greater than or equal to this value.
const MIN_NUM_CHUNKS_FOR_COMPACTION: usize = 1 << 7;

/// Gap-size threshold for parallelising the `butterfly`s inside one chunk.
/// `apply_butterfly` only takes the `cfg_iter_mut!` path within a chunk when
/// `gap` (half of the chunk size) is strictly greater than this value and
/// the number of chunks is smaller than the number of threads.
/// This value was chosen empirically.
const MIN_GAP_SIZE_FOR_PARALLELISATION: usize = 1 << 10;
jon-chuang marked this conversation as resolved.
Show resolved Hide resolved

// minimum size at which to parallelize.
#[cfg(feature = "parallel")]
Expand Down