Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache alignment for serial and parallel FFT and IFFT #245

Merged
merged 25 commits into from
Apr 8, 2021
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ The main features of this release are:
- [\#207](https://github.com/arkworks-rs/algebra/pull/207) (ark-ff) Improve performance of extension fields when the non-residue is negative. (Improves fq2, fq12, and g2 speed on bls12 and bn curves)
- [\#211](https://github.com/arkworks-rs/algebra/pull/211) (ark-ec) Improve performance of BLS12 final exponentiation.
- [\#214](https://github.com/arkworks-rs/algebra/pull/214) (ark-poly) Utilise a more efficient way of evaluating a polynomial at a single point
- [\#242](https://github.com/arkworks-rs/algebra/pull/242), [\#244][https://github.com/arkworks-rs/algebra/pull/244] (ark-poly) Speedup the sequential radix-2 FFT significantly by making the method in which it accesses roots more cache-friendly.
- [\#242](https://github.com/arkworks-rs/algebra/pull/242), [\#244][https://github.com/arkworks-rs/algebra/pull/244], [\#245](https://github.com/arkworks-rs/algebra/pull/245) (ark-poly) Speedup the sequential and parallel radix-2 FFT and IFFT significantly by making the method in which it accesses roots more cache-friendly.
jon-chuang marked this conversation as resolved.
Show resolved Hide resolved

### Bug fixes
- [\#36](https://github.com/arkworks-rs/algebra/pull/36) (ark-ec) In Short-Weierstrass curves, include an infinity bit in `ToConstraintField`.
Expand All @@ -114,5 +114,6 @@ The main features of this release are:
- [\#184](https://github.com/arkworks-rs/algebra/pull/184) Compile with `panic='abort'` in release mode, for safety of the library across FFI boundaries.
- [\#192](https://github.com/arkworks-rs/algebra/pull/192) Fix a bug in the assembly backend for finite field arithmetic.
- [\#217](https://github.com/arkworks-rs/algebra/pull/217) (ark-ec) Fix the definition of `PairingFriendlyCycle` introduced in #190.
-

## v0.1.0 (Initial release of arkworks/algebra)
154 changes: 88 additions & 66 deletions poly/src/domain/radix2/fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
use crate::domain::utils::compute_powers_serial;
use crate::domain::{radix2::*, DomainCoeff};
use ark_ff::FftField;
use ark_std::{cfg_iter_mut, vec::Vec};
use ark_std::{cfg_chunks_mut, vec::Vec};
#[cfg(feature = "parallel")]
use rayon::prelude::*;

Expand Down Expand Up @@ -137,109 +137,131 @@ impl<F: FftField> Radix2EvaluationDomain<F> {
});
}

#[inline(always)]
fn butterfly_fn_io<T: DomainCoeff<F>>(((lo, hi), root): ((&mut T, &mut T), &F)) {
let neg = *lo - *hi;
*lo += *hi;
*hi = neg;
*hi *= *root;
}

#[inline(always)]
fn butterfly_fn_oi<T: DomainCoeff<F>>(((lo, hi), root): ((&mut T, &mut T), &F)) {
*hi *= *root;
let neg = *lo - *hi;
*lo += *hi;
*hi = neg;
}

fn io_helper<T: DomainCoeff<F>>(&self, xi: &mut [T], root: F) {
// In the sequential case, we will keep on making the roots cache-aligned,
// according to the access pattern that the FFT uses.
// It is left as a TODO to implement this for the parallel case
let roots = &mut self.roots_of_unity(root);
let mut roots = self.roots_of_unity(root);
let mut step = 1;
let mut first = true;

#[cfg(feature = "parallel")]
let max_threads = rayon::current_num_threads();
#[cfg(not(feature = "parallel"))]
let mut root_len = roots.len();
let max_threads = 1;

let mut gap = xi.len() / 2;
while gap > 0 {
// each butterfly cluster uses 2*gap positions
let chunk_size = 2 * gap;
#[cfg(feature = "parallel")]
let nchunks = xi.len() / chunk_size;

let butterfly_fn = |(chunk_index, (lo, hi)): (usize, (&mut T, &mut T))| {
let neg = *lo - *hi;
*lo += *hi;

*hi = neg;

#[cfg(feature = "parallel")]
let index = nchunks * chunk_index;
// Due to cache aligning, the index into roots is the chunk index
#[cfg(not(feature = "parallel"))]
let index = chunk_index;

*hi *= roots[index];
};
let num_chunks = xi.len() / chunk_size;

// Only compact roots to achieve cache locality/compactness if
// the roots lookup is done a significant amount of times
// Which also implies a large lookup stride.
if num_chunks >= MIN_COMPACTION_CHUNKS {
if !first {
roots = cfg_into_iter!(roots).step_by(step * 2).collect()
}
step = 1;
roots.shrink_to_fit();
} else {
step = num_chunks;
}
first = false;

ark_std::cfg_chunks_mut!(xi, chunk_size).for_each(|cxi| {
cfg_chunks_mut!(xi, chunk_size).for_each(|cxi| {
let (lo, hi) = cxi.split_at_mut(gap);
// If the chunk is sufficiently big that parallelism helps,
// we parallelize the butterfly operation within the chunk.
//
// if chunk_size > MIN_CHUNK_SIZE_FOR_PARALLELIZATION
if gap > MIN_CHUNK_SIZE_FOR_PARALLELIZATION / 2 {
cfg_iter_mut!(lo).zip(hi).enumerate().for_each(butterfly_fn);

if gap > MIN_PROBLEM_SIZE && num_chunks < max_threads {
cfg_iter_mut!(lo)
.zip(cfg_iter_mut!(hi))
jon-chuang marked this conversation as resolved.
Show resolved Hide resolved
.zip(cfg_iter!(roots).step_by(step))
.for_each(Self::butterfly_fn_io);
} else {
lo.iter_mut().zip(hi).enumerate().for_each(butterfly_fn);
lo.iter_mut()
.zip(hi)
.zip(roots.iter().step_by(step))
.for_each(Self::butterfly_fn_io);
}
});

// Cache align the FFT roots in the sequential case.
// In this case, we are aiming to make every root that is accessed one after another,
// appear one after another in the list of roots.
// if the roots are cache aligned in iteration i, then in iteration i+1,
// cache alignment requires selecting every other root.
// (The even powers relative to the current iterations generator)
//
//
// (Roots are already aligned in the first iteration,
// so we only to do realignment after the first iteration.)
#[cfg(not(feature = "parallel"))]
{
for i in 1..(root_len / 2) {
roots[i] = roots[i * 2];
}
root_len /= 2;
}
gap /= 2;
}
}

fn oi_helper<T: DomainCoeff<F>>(&self, xi: &mut [T], root: F) {
let roots = self.roots_of_unity(root);
let roots_cache = self.roots_of_unity(root);
let compaction_size = core::cmp::min(
roots_cache.len() / 2,
roots_cache.len() / MIN_COMPACTION_CHUNKS,
);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens when roots_cache.len() is 2 or less than MIN_COMPACTION_CHUNKS? Could you amend the tests to check that as well? Thanks!

Copy link
Contributor Author

@jon-chuang jon-chuang Mar 30, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm the compaction wouldn't happen. So we don't have to worry about it. Notice that cmp::min is only necessary for MIN_COMPACTION_CHUNKS = 1, since chunks > 0.

If roots_cache.len() < MIN_COMPACTION_CHUNKS, then chunks <= xi.len() / 2 = roots_cache.len() < MIN_COMPACTION_CHUNKS

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok sounds good, a comment to that effect would be great.

let mut compacted_roots = vec![F::default(); compaction_size];

#[cfg(feature = "parallel")]
let max_threads = rayon::current_num_threads();
#[cfg(not(feature = "parallel"))]
let max_threads = 1;

let mut gap = 1;
while gap < xi.len() {
// each butterfly cluster uses 2*gap positions
let chunk_size = 2 * gap;
let nchunks = xi.len() / chunk_size;

let butterfly_fn = |(idx, (lo, hi)): (usize, (&mut T, &mut T))| {
*hi *= roots[nchunks * idx];
let neg = *lo - *hi;
*lo += *hi;
*hi = neg;
let num_chunks = xi.len() / chunk_size;

// Only compact roots to achieve cache locality/compactness if
// the roots lookup is done a significant amount of times
// Which also implies a large lookup stride.
let (roots, step) = if num_chunks >= MIN_COMPACTION_CHUNKS && gap < xi.len() / 2 {
cfg_iter_mut!(compacted_roots[..gap])
.zip(cfg_iter!(roots_cache[..gap * num_chunks]).step_by(num_chunks))
jon-chuang marked this conversation as resolved.
Show resolved Hide resolved
.for_each(|(a, b)| *a = *b);
(&compacted_roots[..gap], 1)
} else {
(&roots_cache[..], num_chunks)
};

ark_std::cfg_chunks_mut!(xi, chunk_size).for_each(|cxi| {
cfg_chunks_mut!(xi, chunk_size).for_each(|cxi| {
let (lo, hi) = cxi.split_at_mut(gap);
// If the chunk is sufficiently big that parallelism helps,
// we parallelize the butterfly operation within the chunk.
//
// if chunk_size > MIN_CHUNK_SIZE_FOR_PARALLELIZATION
if gap > MIN_CHUNK_SIZE_FOR_PARALLELIZATION / 2 {
cfg_iter_mut!(lo).zip(hi).enumerate().for_each(butterfly_fn);

if gap > MIN_PROBLEM_SIZE && num_chunks < max_threads {
cfg_iter_mut!(lo)
.zip(cfg_iter_mut!(hi))
jon-chuang marked this conversation as resolved.
Show resolved Hide resolved
.zip(cfg_iter!(roots).step_by(step))
.for_each(Self::butterfly_fn_oi);
} else {
lo.iter_mut().zip(hi).enumerate().for_each(butterfly_fn);
lo.iter_mut()
.zip(hi)
.zip(roots.iter().step_by(step))
.for_each(Self::butterfly_fn_oi);
}
});

gap *= 2;
}
}
}

// This value controls that when doing a butterfly on a chunk of size c,
// do you parallelize operations on the chunk.
// If c > MIN_CHUNK_SIZE_FOR_PARALLELIZATION,
// then parallelize, else be sequential.
// This value was chosen empirically.
const MIN_CHUNK_SIZE_FOR_PARALLELIZATION: usize = 2048;
const MIN_COMPACTION_CHUNKS: usize = 1 << 7;
jon-chuang marked this conversation as resolved.
Show resolved Hide resolved
const LOG_MIN_PROBLEM_SIZE: usize = 10;
const MIN_PROBLEM_SIZE: usize = 1 << LOG_MIN_PROBLEM_SIZE;
jon-chuang marked this conversation as resolved.
Show resolved Hide resolved

// minimum size at which to parallelize.
#[cfg(feature = "parallel")]
Expand Down