Cache alignment for serial and parallel FFT and IFFT #245

Merged
merged 25 commits into from
Apr 8, 2021
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
CHANGELOG.md (3 changes: 2 additions & 1 deletion)
@@ -87,7 +87,7 @@ The main features of this release are:
- [\#207](https://github.com/arkworks-rs/algebra/pull/207) (ark-ff) Improve performance of extension fields when the non-residue is negative. (Improves fq2, fq12, and g2 speed on bls12 and bn curves)
- [\#211](https://github.com/arkworks-rs/algebra/pull/211) (ark-ec) Improve performance of BLS12 final exponentiation.
- [\#214](https://github.com/arkworks-rs/algebra/pull/214) (ark-poly) Utilise a more efficient way of evaluating a polynomial at a single point
- [\#242](https://github.com/arkworks-rs/algebra/pull/242), [\#244](https://github.com/arkworks-rs/algebra/pull/244) (ark-poly) Speed up the sequential radix-2 FFT significantly by making the method in which it accesses roots more cache-friendly.
- [\#242](https://github.com/arkworks-rs/algebra/pull/242), [\#244](https://github.com/arkworks-rs/algebra/pull/244), [\#245](https://github.com/arkworks-rs/algebra/pull/245) (ark-poly) Speed up the sequential and parallel radix-2 FFT and IFFT significantly by making the method in which it accesses roots more cache-friendly.

### Bug fixes
- [\#36](https://github.com/arkworks-rs/algebra/pull/36) (ark-ec) In Short-Weierstrass curves, include an infinity bit in `ToConstraintField`.
@@ -100,5 +100,6 @@ The main features of this release are:
- [\#184](https://github.com/arkworks-rs/algebra/pull/184) Compile with `panic='abort'` in release mode, for safety of the library across FFI boundaries.
- [\#192](https://github.com/arkworks-rs/algebra/pull/192) Fix a bug in the assembly backend for finite field arithmetic.
- [\#217](https://github.com/arkworks-rs/algebra/pull/217) (ark-ec) Fix the definition of `PairingFriendlyCycle` introduced in #190.
-

## v0.1.0 (Initial release of arkworks/algebra)
poly/src/domain/radix2/fft.rs (158 changes: 126 additions & 32 deletions)
@@ -4,7 +4,7 @@
use crate::domain::utils::compute_powers_serial;
use crate::domain::{radix2::*, DomainCoeff};
use ark_ff::FftField;
use ark_std::{cfg_iter_mut, vec::Vec};
use ark_std::{cfg_chunks_mut, cfg_iter_mut, vec::Vec};
#[cfg(feature = "parallel")]
use rayon::prelude::*;

@@ -141,94 +141,188 @@ impl<F: FftField> Radix2EvaluationDomain<F> {
// In the sequential case, we will keep on making the roots cache-aligned,
// according to the access pattern that the FFT uses.
// It is left as a TODO to implement this for the parallel case
let roots = &mut self.roots_of_unity(root);
let mut roots = self.roots_of_unity(root);

#[cfg(feature = "parallel")]
let max_threads = rayon::current_num_threads();

#[cfg(not(feature = "parallel"))]
let mut root_len = roots.len();
let max_threads = 1;

let mut gap = xi.len() / 2;
while gap > 0 {
// each butterfly cluster uses 2*gap positions
let chunk_size = 2 * gap;
#[cfg(feature = "parallel")]
let nchunks = xi.len() / chunk_size;
let num_chunks = xi.len() / chunk_size;

// We define the number of sub-chunks per chunk as the ceiling of the number
// of threads partitioned amongst the chunks:
let sub_chunks = (max_threads - 1) / num_chunks + 1;

// The sub-chunk size that yields that many sub-chunks is the floor of the gap
// (the per-chunk problem size) divided by the number of sub-chunks.
let sub_chunk_size = gap / sub_chunks;
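// e.g. (hypothetical numbers) with max_threads = 8 and num_chunks = 2,
// sub_chunks = (8 - 1) / 2 + 1 = 4, and with gap = 1024 each sub-chunk
// holds 1024 / 4 = 256 elements.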

let butterfly_fn = |(chunk_index, (lo, hi)): (usize, (&mut T, &mut T))| {
let butterfly_fn = |(index, (lo, hi)): (usize, (&mut T, &mut T))| {
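// Decimation-in-frequency (Gentleman-Sande) butterfly:
// (lo, hi) -> (lo + hi, (lo - hi) * root).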
let neg = *lo - *hi;
*lo += *hi;

*hi = neg;

#[cfg(feature = "parallel")]
let index = nchunks * chunk_index;
// Due to cache aligning, the index into roots is the chunk index
#[cfg(not(feature = "parallel"))]
let index = chunk_index;

*hi *= roots[index];
};

ark_std::cfg_chunks_mut!(xi, chunk_size).for_each(|cxi| {
cfg_chunks_mut!(xi, chunk_size).for_each(|cxi| {
let (lo, hi) = cxi.split_at_mut(gap);
// If the chunk is sufficiently big that parallelism helps,
// we parallelize the butterfly operation within the chunk.
//
// if chunk_size > MIN_CHUNK_SIZE_FOR_PARALLELIZATION
if gap > MIN_CHUNK_SIZE_FOR_PARALLELIZATION / 2 {
// We split the chunk into sub-chunks so that each thread can operate on
// elements in sequence.

if gap > MIN_CHUNK_SIZE_FOR_PARALLELIZATION / 2 * sub_chunks {
cfg_chunks_mut!(lo, sub_chunk_size)
.zip(cfg_chunks_mut!(hi, sub_chunk_size))
.enumerate()
.for_each(|(chunk_id, (lo_chunk, hi_chunk))| {
lo_chunk
.iter_mut()
.zip(hi_chunk)
.enumerate()
.for_each(|(idx, d)| {
butterfly_fn((chunk_id * sub_chunk_size + idx, d))
});
});
} else if gap > MIN_CHUNK_SIZE_FOR_PARALLELIZATION / 2 {
cfg_iter_mut!(lo).zip(hi).enumerate().for_each(butterfly_fn);
} else {
lo.iter_mut().zip(hi).enumerate().for_each(butterfly_fn);
}
});

// Cache-align the FFT roots in the sequential case.
// Here we aim to make every root that is accessed one after another
// appear one after another in the list of roots.
// If the roots are cache-aligned in iteration i, then in iteration i+1,
// cache alignment requires selecting every other root
// (the even powers relative to the current iteration's generator).
//
// (Roots are already aligned in the first iteration,
// so we only need to do realignment after the first iteration.)
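// e.g. if the current iteration's butterflies consume roots
// [r^0, r^1, r^2, r^3, ...], the next iteration only needs the even powers
// [r^0, r^2, r^4, ...], so compaction copies roots[2 * i] into roots[i].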

#[cfg(feature = "parallel")]
{
for i in 1..core::cmp::min(roots.len() / 2, 1 << LOG_ROOTS_OF_UNITY_PARALLEL_SIZE) {
roots[i] = roots[i * 2];
}

// If a sufficient number of roots have been compacted, we have the following situation:
// [a:compacted][b:writable][---- c:to be written ----][...rest of slice ...]
//  <----j---->  <----j----> <--------- 2j -------->
// So we divvy up b and c into chunks and write them in parallel. We proceed
// recursively, where c becomes the new writable b.
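// e.g. on the first pass, j = 1 << LOG_ROOTS_OF_UNITY_PARALLEL_SIZE:
// positions [0, j) are already compacted, positions [j, 2j) are overwritten in
// parallel with roots[2j], roots[2j + 2], ..., and j doubles on the next pass.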

for i in LOG_ROOTS_OF_UNITY_PARALLEL_SIZE..=ark_std::log2(roots.len() / 4) {
let j = 1 << i;

// We set this to be the ceil of the problem size divided by the number of threads
let chunk_size = (j - 1) / max_threads + 1;
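// e.g. with j = 4096 and max_threads = 8, each thread's chunk holds
// (4096 - 1) / 8 + 1 = 512 roots.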

let (roots_lo, roots_hi) = roots.split_at_mut(2 * j);

cfg_chunks_mut!(roots_lo[j..2 * j], chunk_size)
.zip(cfg_chunks_mut!(roots_hi[..2 * j], chunk_size * 2))
.for_each(|(chunk, chunk_2)| {
for i in 0..chunk.len() {
chunk[i] = chunk_2[2 * i];
}
});
}
roots.resize(roots.len() / 2, F::default());
}

#[cfg(not(feature = "parallel"))]
{
for i in 1..(root_len / 2) {
for i in 1..roots.len() / 2 {
roots[i] = roots[i * 2];
}
root_len /= 2;

roots.resize(roots.len() / 2, F::default());
}

gap /= 2;
}
}

fn oi_helper<T: DomainCoeff<F>>(&self, xi: &mut [T], root: F) {
let roots = self.roots_of_unity(root);
let roots_cache = self.roots_of_unity(root);
let mut compacted_roots = vec![F::default(); roots_cache.len() / 2];

#[cfg(feature = "parallel")]
let max_threads = rayon::current_num_threads();

#[cfg(not(feature = "parallel"))]
let max_threads = 1;

let mut gap = 1;
while gap < xi.len() {
// each butterfly cluster uses 2*gap positions
let chunk_size = 2 * gap;
let nchunks = xi.len() / chunk_size;
let num_chunks = xi.len() / chunk_size;

// We define the number of sub-chunks per chunk as the ceiling of the number
// of threads partitioned amongst the chunks:
let sub_chunks = (max_threads - 1) / num_chunks + 1;

// The sub-chunk size that yields that many sub-chunks is the floor of the gap
// (the per-chunk problem size) divided by the number of sub-chunks.
let sub_chunk_size = gap / sub_chunks;

let roots = if gap < xi.len() / 2 {
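// At this level each butterfly only touches every `num_chunks`-th root, so we
// gather those roots into a contiguous buffer for cache-friendly sequential access.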
if gap > MIN_CHUNK_SIZE_FOR_PARALLELIZATION {
cfg_iter_mut!(compacted_roots[..gap])
.enumerate()
.for_each(|(i, root)| {
*root = roots_cache[num_chunks * i];
});
} else {
for i in 0..gap {
compacted_roots[i] = roots_cache[num_chunks * i];
}
}
&compacted_roots[..]
} else {
&roots_cache[..]
};

let butterfly_fn = |(idx, (lo, hi)): (usize, (&mut T, &mut T))| {
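// Decimation-in-time (Cooley-Tukey) butterfly:
// (lo, hi) -> (lo + root * hi, lo - root * hi).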
*hi *= roots[nchunks * idx];
*hi *= roots[idx];
let neg = *lo - *hi;
*lo += *hi;
*hi = neg;
};

ark_std::cfg_chunks_mut!(xi, chunk_size).for_each(|cxi| {
cfg_chunks_mut!(xi, chunk_size).for_each(|cxi| {
let (lo, hi) = cxi.split_at_mut(gap);
// If the chunk is sufficiently big that parallelism helps,
// we parallelize the butterfly operation within the chunk.
//
// if chunk_size > MIN_CHUNK_SIZE_FOR_PARALLELIZATION
if gap > MIN_CHUNK_SIZE_FOR_PARALLELIZATION / 2 {
// We split the chunk into sub-chunks so that each thread can operate on
// elements in sequence, to help cache locality.

if gap > MIN_CHUNK_SIZE_FOR_PARALLELIZATION / 2 * sub_chunks {
cfg_chunks_mut!(lo, sub_chunk_size)
.zip(cfg_chunks_mut!(hi, sub_chunk_size))
.enumerate()
.for_each(|(chunk_id, (lo_chunk, hi_chunk))| {
lo_chunk
.iter_mut()
.zip(hi_chunk)
.enumerate()
.for_each(|(idx, d)| {
butterfly_fn((chunk_id * sub_chunk_size + idx, d))
});
});
} else if gap > MIN_CHUNK_SIZE_FOR_PARALLELIZATION / 2 {
cfg_iter_mut!(lo).zip(hi).enumerate().for_each(butterfly_fn);
} else {
lo.iter_mut().zip(hi).enumerate().for_each(butterfly_fn);
}
});

gap *= 2;
}
}
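
The compaction that both helpers rely on can be summarised in a minimal standalone sketch (illustrative only; the PR's version keeps the allocation via `resize` and adds a parallel variant for large root tables): each FFT level consumes only the even powers of the previous level's roots, so copying `roots[2 * i]` into `roots[i]` and halving the length keeps the butterfly loop's root accesses sequential in memory.

```rust
// Illustrative sketch only (not the PR's exact code): drop to the even powers
// of the current level's roots so the next level reads them contiguously.
fn compact_roots<F: Copy>(roots: &mut Vec<F>) {
    for i in 1..roots.len() / 2 {
        roots[i] = roots[i * 2];
    }
    roots.truncate(roots.len() / 2);
}
```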