Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve radix-2 FFTs #169

Merged
merged 27 commits into from
Jan 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ The main features of this release are:
- Multi-variate polynomial support
- Many speedups to operations involving polynomials
- Some speedups to `sqrt`
- Small speedups to `MSM`s
- Small speedups to MSMs
- Big speedups to radix-2 FFTs

### Breaking changes
- #20 (ark-poly) Move univariate DensePolynomial and SparsePolynomial into a
Expand Down Expand Up @@ -71,6 +72,7 @@ The main features of this release are:
- #157 (ark-ec) Speed up `variable_base_msm` by not relying on unnecessary normalization.
- #158 (ark-serialize) Add an impl of `CanonicalSerialize/Deserialize` for `()`.
- #166 (ark-ff) Add `to_bytes_be()` and `to_bytes_le()` methods to `BigInt`.
- #169 (ark-poly) Improve radix-2 FFTs by moving to a faster algorithm by Riad S. Wahby.

### Bug fixes
- #36 (ark-ec) In Short-Weierstrass curves, include an infinity bit in `ToConstraintField`.
Expand Down
30 changes: 8 additions & 22 deletions poly-benches/benches/fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,6 @@ const BENCHMARK_LOG_INTERVAL_DEGREE: usize = 1;
const ENABLE_RADIX2_BENCHES: bool = true;
const ENABLE_MIXED_RADIX_BENCHES: bool = true;

const ENABLE_SUBGROUP_FFT_BENCH: bool = true;
Pratyush marked this conversation as resolved.
Show resolved Hide resolved
const ENABLE_COSET_FFT_BENCH: bool = true;

const ENABLE_SUBGROUP_IFFT_BENCH: bool = false;
const ENABLE_COSET_IFFT_BENCH: bool = false;

// returns vec![2^{min}, 2^{min + interval}, ..., 2^{max}], where:
// interval = BENCHMARK_LOG_INTERVAL_DEGREE
// min = ceil(log_2(BENCHMARK_MIN_DEGREE))
Expand Down Expand Up @@ -92,22 +86,14 @@ fn bench_coset_ifft_in_place<F: FftField, D: EvaluationDomain<F>>(b: &mut Benche
}

fn fft_benches<F: FftField, D: EvaluationDomain<F>>(c: &mut Criterion, name: &'static str) {
if ENABLE_SUBGROUP_FFT_BENCH {
let cur_name = format!("{:?} - subgroup_fft_in_place", name.clone());
setup_bench(c, &cur_name, bench_fft_in_place::<F, D>);
}
if ENABLE_SUBGROUP_IFFT_BENCH {
let cur_name = format!("{:?} - subgroup_ifft_in_place", name.clone());
setup_bench(c, &cur_name, bench_ifft_in_place::<F, D>);
}
if ENABLE_COSET_FFT_BENCH {
let cur_name = format!("{:?} - coset_fft_in_place", name.clone());
setup_bench(c, &cur_name, bench_coset_fft_in_place::<F, D>);
}
if ENABLE_COSET_IFFT_BENCH {
let cur_name = format!("{:?} - coset_ifft_in_place", name.clone());
setup_bench(c, &cur_name, bench_coset_ifft_in_place::<F, D>);
}
let cur_name = format!("{:?} - subgroup_fft_in_place", name.clone());
setup_bench(c, &cur_name, bench_fft_in_place::<F, D>);
let cur_name = format!("{:?} - subgroup_ifft_in_place", name.clone());
setup_bench(c, &cur_name, bench_ifft_in_place::<F, D>);
let cur_name = format!("{:?} - coset_fft_in_place", name.clone());
setup_bench(c, &cur_name, bench_coset_fft_in_place::<F, D>);
let cur_name = format!("{:?} - coset_ifft_in_place", name.clone());
setup_bench(c, &cur_name, bench_coset_ifft_in_place::<F, D>);
}

fn bench_bls12_381(c: &mut Criterion) {
Expand Down
26 changes: 8 additions & 18 deletions poly/src/domain/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use ark_std::{fmt, hash, vec::Vec};
use rand::Rng;

#[cfg(feature = "parallel")]
use ark_std::cmp::max;
#[cfg(feature = "parallel")]
use rayon::prelude::*;

Expand Down Expand Up @@ -96,23 +94,11 @@ pub trait EvaluationDomain<F: FftField>:
/// Multiply the `i`-th element of `coeffs` with the `i`-th power of `g`.
#[cfg(feature = "parallel")]
fn distribute_powers<T: DomainCoeff<F>>(coeffs: &mut [T], g: F) {
// compute the number of threads we will be using.
let num_cpus_available = rayon::current_num_threads();
let min_elements_per_thread = 256;
let num_elem_per_thread = max(coeffs.len() / num_cpus_available, min_elements_per_thread);

// Split up the coeffs to coset-shift across each thread evenly,
// and then apply coset shift.
let powers_of_g = crate::domain::utils::compute_powers(coeffs.len(), g);
coeffs
.par_chunks_mut(num_elem_per_thread)
.enumerate()
.for_each(|(i, chunk)| {
let mut pow = g.pow(&[(i * num_elem_per_thread) as u64]);
chunk.iter_mut().for_each(|c| {
*c *= pow;
pow *= &g
});
});
.par_iter_mut()
.zip(powers_of_g)
.for_each(|(coeff, power)| *coeff *= power);
}

/// Compute a FFT over a coset of the domain.
Expand Down Expand Up @@ -231,6 +217,8 @@ pub trait DomainCoeff<F: FftField>:
Copy
+ Send
+ Sync
+ core::ops::Add<Output = Self>
+ core::ops::Sub<Output = Self>
+ core::ops::AddAssign
+ core::ops::SubAssign
+ ark_ff::Zero
Expand All @@ -244,6 +232,8 @@ where
T: Copy
+ Send
+ Sync
+ core::ops::Add<Output = Self>
+ core::ops::Sub<Output = Self>
+ core::ops::AddAssign
+ core::ops::SubAssign
+ ark_ff::Zero
Expand Down
132 changes: 132 additions & 0 deletions poly/src/domain/radix2/fft.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
// The code below is a port of Riad S. Wahby's excellent `fffft` library
// (https://github.com/kwantam/fffft) to the arkworks APIs.

use crate::domain::{radix2::*, DomainCoeff};
use ark_ff::FftField;
use ark_std::vec::Vec;
#[cfg(feature = "parallel")]
use rayon::prelude::*;

/// Ordering contract for the FFT helpers: states whether the input and the
/// output of a transform are required to be in order.
///
/// The enum is a fieldless flag that is handed to helpers by value, so it is
/// `Copy` in addition to being comparable and printable.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum FFTOrder {
    /// Both the input and the output of the FFT must be in-order.
    II,
    /// The input of the FFT must be in-order, but the output does not have to be.
    IO,
    /// The input of the FFT can be out of order, but the output must be in-order.
    OI,
}

impl<F: FftField> Radix2EvaluationDomain<F> {
pub(crate) fn in_order_fft_in_place<T: DomainCoeff<F>>(&self, x_s: &mut [T]) {
self.fft_helper_in_place(x_s, FFTOrder::II)
}

pub(crate) fn in_order_ifft_in_place<T: DomainCoeff<F>>(&self, x_s: &mut [T]) {
self.ifft_helper_in_place(x_s, FFTOrder::II)
}

fn fft_helper_in_place<T: DomainCoeff<F>>(&self, x_s: &mut [T], ord: FFTOrder) {
use FFTOrder::*;

let log_len = ark_std::log2(x_s.len());

if ord == OI {
self.oi_helper(x_s, self.group_gen);
} else {
self.io_helper(x_s, self.group_gen);
}

if ord == II {
derange(x_s, log_len);
}
}

fn ifft_helper_in_place<T: DomainCoeff<F>>(&self, x_s: &mut [T], ord: FFTOrder) {
use FFTOrder::*;

let log_len = ark_std::log2(x_s.len());

if ord == II {
derange(x_s, log_len);
}

if ord == IO {
self.io_helper(x_s, self.group_gen_inv);
} else {
self.oi_helper(x_s, self.group_gen_inv);
}
ark_std::cfg_iter_mut!(x_s).for_each(|val| *val *= self.size_inv);
}

#[cfg(not(feature = "parallel"))]
fn roots_of_unity(&self, root: F) -> Vec<F> {
crate::domain::utils::compute_powers_serial(self.size as usize, root)
}

#[cfg(feature = "parallel")]
fn roots_of_unity(&self, root: F) -> Vec<F> {
crate::domain::utils::compute_powers(self.size as usize, root)
}

fn io_helper<T: DomainCoeff<F>>(&self, xi: &mut [T], root: F) {
let roots = self.roots_of_unity(root);

let mut gap = xi.len() / 2;
while gap > 0 {
// each butterfly cluster uses 2*gap positions
let nchunks = xi.len() / (2 * gap);
ark_std::cfg_chunks_mut!(xi, 2 * gap).for_each(|cxi| {
let (lo, hi) = cxi.split_at_mut(gap);
ark_std::cfg_iter_mut!(lo)
.zip(hi)
.enumerate()
.for_each(|(idx, (lo, hi))| {
let neg = *lo - *hi;
*lo += *hi;

*hi = neg;
*hi *= roots[nchunks * idx];
});
});
gap /= 2;
}
}

fn oi_helper<T: DomainCoeff<F>>(&self, xi: &mut [T], root: F) {
let roots = self.roots_of_unity(root);

let mut gap = 1;
while gap < xi.len() {
let nchunks = xi.len() / (2 * gap);

ark_std::cfg_chunks_mut!(xi, 2 * gap).for_each(|cxi| {
let (lo, hi) = cxi.split_at_mut(gap);
ark_std::cfg_iter_mut!(lo)
.zip(hi)
.enumerate()
.for_each(|(idx, (lo, hi))| {
*hi *= roots[nchunks * idx];
let neg = *lo - *hi;
*lo += *hi;
*hi = neg;
});
});
gap *= 2;
}
}
}

/// Reverses the lowest `log_len` bits of `a` and returns them as the lowest
/// `log_len` bits of the result; all higher bits of `a` are discarded.
///
/// `log_len == 0` yields 0. `log_len` must not exceed 64.
#[inline]
fn bitrev(a: u64, log_len: u32) -> u64 {
    // `reverse_bits` mirrors all 64 bits, so the wanted bits end up at the top
    // and have to be shifted back down by `64 - log_len`. For `log_len == 0`
    // that amount is 64, which a plain `>>` does not accept (overflow panic in
    // debug builds); `checked_shr` reports it as `None`, which maps to the
    // only meaningful result for a zero-width reversal: 0.
    a.reverse_bits().checked_shr(64 - log_len).unwrap_or(0)
}

/// Permutes `xi` in place by swapping each element at index `i` with the
/// element at the index obtained by reversing the lowest `log_len` bits of `i`.
///
/// Each pair is swapped once (only when `i` is the smaller index of the two).
/// Slices of length 0 or 1 are left untouched.
///
/// NOTE(review): the visible callers pass `log_len = ark_std::log2(xi.len())`;
/// the reversed index is used unchecked by `swap`, so other combinations may
/// panic on an out-of-range index.
fn derange<T>(xi: &mut [T], log_len: u32) {
    // The first and the last index are not visited. `saturating_sub` keeps the
    // upper bound at 0 for an empty slice instead of underflowing `0 - 1`.
    let last = (xi.len() as u64).saturating_sub(1);
    for idx in 1..last {
        let ridx = bitrev(idx, log_len);
        if idx < ridx {
            xi.swap(idx as usize, ridx as usize);
        }
    }
}
Loading