Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve radix-2 FFTs #169

Merged
merged 27 commits into from
Jan 8, 2021
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ The main features of this release are:
- Multi-variate polynomial support
- Many speedups to operations involving polynomials
- Some speedups to `sqrt`
- Small speedups to `MSM`s
- Small speedups to MSMs
- Big speedups to radix-2 FFTs

### Breaking changes
- #20 (ark-poly) Move univariate DensePolynomial and SparsePolynomial into a
Expand Down Expand Up @@ -71,6 +72,7 @@ The main features of this release are:
- #157 (ark-ec) Speed up `variable_base_msm` by not relying on unnecessary normalization.
- #158 (ark-serialize) Add an impl of `CanonicalSerialize/Deserialize` for `()`.
- #166 (ark-ff) Add `to_bytes_be()` and `to_bytes_le()` methods to `BigInt`.
- #169 (ark-poly) Improve radix-2 FFTs by moving to a faster algorithm by Riad Wahby.

### Bug fixes
- #36 (ark-ec) In Short-Weierstrass curves, include an infinity bit in `ToConstraintField`.
Expand Down
3 changes: 2 additions & 1 deletion poly-benches/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ publish = false

[dependencies]
ark-ff = { path = "../ff" }
ark-ec = { path = "../ec" }
ark-poly = { path = "../poly" }
ark-std = { git = "https://github.com/arkworks-rs/utils", default-features = false }
ark-test-curves = { path = "../test-curves", default-features = false, features = [ "bls12_381_scalar_field", "mnt4_753_curve" ] }
Expand All @@ -19,7 +20,7 @@ rayon = { version = "1", optional = true }

[features]
default = []
parallel = ["ark-ff/parallel", "rayon", "ark-poly/parallel", "ark-std/parallel" ]
parallel = ["ark-ff/parallel", "rayon", "ark-poly/parallel", "ark-std/parallel", "ark-ec/parallel" ]

[[bench]]
name = "fft"
Expand Down
1 change: 1 addition & 0 deletions poly/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ derivative = { version = "2", default-features = false, features = [ "use_core"

[dev-dependencies]
ark-test-curves = { path = "../test-curves", default-features = false, features = [ "bls12_381_curve", "mnt4_753_curve"] }
ark-ec = { path = "../ec", default-features = false, features = [ "parallel" ] }

[features]
default = []
Expand Down
4 changes: 4 additions & 0 deletions poly/src/domain/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ pub trait DomainCoeff<F: FftField>:
Copy
+ Send
+ Sync
+ core::ops::Add<Output = Self>
+ core::ops::Sub<Output = Self>
+ core::ops::AddAssign
+ core::ops::SubAssign
+ ark_ff::Zero
Expand All @@ -244,6 +246,8 @@ where
T: Copy
+ Send
+ Sync
+ core::ops::Add<Output = Self>
+ core::ops::Sub<Output = Self>
+ core::ops::AddAssign
+ core::ops::SubAssign
+ ark_ff::Zero
Expand Down
200 changes: 200 additions & 0 deletions poly/src/domain/radix2/fft.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
// The code below is a port of the excellent library of https://github.com/kwantam/fffft by Riad Wahby
// to the arkworks APIs

use crate::domain::{radix2::*, DomainCoeff};
use ark_ff::FftField;
use ark_std::vec::Vec;
#[cfg(feature = "parallel")]
use rayon::prelude::*;

/// Ordering contract between the caller and the FFT helpers: states whether
/// the input and/or the output of a transform has to be in natural order.
///
/// The enum is a plain tag that is passed by value and compared with `==`,
/// so it is `Copy`.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum FFTOrder {
    /// Both the input and the output of the FFT must be in-order.
    II,
    /// The input of the FFT must be in-order, but the output does not have to be.
    IO,
    /// The input of the FFT can be out of order, but the output must be in-order.
    OI,
}

/// Base-2 logarithm of the minimum size at which to parallelize.
///
/// Only compiled with the `parallel` feature. `roots_of_unity` falls back to
/// the serial computation when `log_size_of_group <= LOG_PARALLEL_SIZE`, and
/// `roots_of_unity_recursive` stops splitting once `log_roots` has at most
/// this many entries.
#[cfg(feature = "parallel")]
const LOG_PARALLEL_SIZE: u32 = 7;

impl<F: FftField> Radix2EvaluationDomain<F> {
/// Runs the forward transform on `x_s` in place with both the input and the
/// output in natural order, by delegating to `fft_helper_in_place` with
/// `FFTOrder::II`.
pub(crate) fn in_order_fft_in_place<T: DomainCoeff<F>>(&self, x_s: &mut [T]) {
self.fft_helper_in_place(x_s, FFTOrder::II)
}

/// Runs the inverse transform on `x_s` in place with both the input and the
/// output in natural order, by delegating to `ifft_helper_in_place` with
/// `FFTOrder::II`.
pub(crate) fn in_order_ifft_in_place<T: DomainCoeff<F>>(&self, x_s: &mut [T]) {
self.ifft_helper_in_place(x_s, FFTOrder::II)
}

fn fft_helper_in_place<T: DomainCoeff<F>>(&self, x_s: &mut [T], ord: FFTOrder) {
use FFTOrder::*;

let log_len = ark_std::log2(x_s.len());

if ord == OI {
self.oi_helper(x_s, self.group_gen);
} else {
self.io_helper(x_s, self.group_gen);
}

if ord == II {
derange(x_s, log_len);
}
}

fn ifft_helper_in_place<T: DomainCoeff<F>>(&self, x_s: &mut [T], ord: FFTOrder) {
use FFTOrder::*;

let log_len = ark_std::log2(x_s.len());

if ord == II {
derange(x_s, log_len);
}

if ord == IO {
self.io_helper(x_s, self.group_gen_inv);
} else {
self.oi_helper(x_s, self.group_gen_inv);
}
ark_std::cfg_iter_mut!(x_s).for_each(|val| *val *= self.size_inv);
}

/// Sequentially computes the powers `root^0, root^1, ..., root^(size/2 - 1)`.
///
/// Only the first half of the powers is produced: the butterfly passes in
/// `io_helper` / `oi_helper` index the table at `nchunks * idx`, which is
/// always below half of the slice length, and the parallel `roots_of_unity`
/// likewise returns `1 << (log_size - 1)` entries. Computing all `size`
/// powers would spend `size / 2` field multiplications on values that are
/// never read.
fn roots_of_unity_serial(&self, root: F) -> Vec<F> {
    let mut value = F::one();
    (0..self.size / 2)
        .map(|_| {
            // Yield the current power, then advance to the next one.
            let old_value = value;
            value *= root;
            old_value
        })
        .collect()
}

/// Variant compiled when the `parallel` feature is disabled: returns the
/// table of powers of `root` produced by `roots_of_unity_serial`.
#[cfg(not(feature = "parallel"))]
fn roots_of_unity(&self, root: F) -> Vec<F> {
self.roots_of_unity_serial(root)
}

/// Variant compiled when the `parallel` feature is enabled.
///
/// Domains with `log_size_of_group <= LOG_PARALLEL_SIZE` are handled by
/// `roots_of_unity_serial`. Larger domains get a table of
/// `1 << (log_size_of_group - 1)` powers of `root`, filled in by
/// `roots_of_unity_recursive`.
#[cfg(feature = "parallel")]
fn roots_of_unity(&self, root: F) -> Vec<F> {
    let log_size = self.log_size_of_group;

    // Short inputs are not worth splitting up.
    if log_size <= LOG_PARALLEL_SIZE {
        return self.roots_of_unity_serial(root);
    }

    // Repeated squarings of the root: w, w^2, w^4, ..., one per halving step.
    let num_squarings = log_size - 1;
    let mut log_roots: Vec<F> = Vec::with_capacity(num_squarings as usize);
    let mut power = root;
    for _ in 0..num_squarings {
        log_roots.push(power);
        power.square_in_place();
    }

    // Allocate the output table and let the recursion fill it.
    let mut roots = vec![F::zero(); 1 << num_squarings];
    Self::roots_of_unity_recursive(&mut roots, &log_roots);
    roots
}

/// Fills `out` with consecutive powers of `log_roots[0]`, starting at one.
///
/// `log_roots` holds the repeated squarings of the base root, and `out` must
/// have exactly `1 << log_roots.len()` entries (asserted). Inputs with at
/// most `LOG_PARALLEL_SIZE` squarings are filled sequentially. Larger ones
/// are split into a low and a high half of `log_roots`, each half is solved
/// on its own via `rayon::join`, and the table is rebuilt from the pairwise
/// products of the two partial tables.
#[cfg(feature = "parallel")]
fn roots_of_unity_recursive(out: &mut [F], log_roots: &[F]) {
    assert_eq!(out.len(), 1 << log_roots.len());

    // Base case: small enough to compute the powers one after another.
    if log_roots.len() <= LOG_PARALLEL_SIZE as usize {
        out[0] = F::one();
        if let Some(&step) = log_roots.first() {
            let mut running = out[0];
            for slot in &mut out[1..] {
                running = running * step;
                *slot = running;
            }
        }
        return;
    }

    // Recursive case. The low half receives the extra entry when the count is odd.
    let split = (1 + log_roots.len()) / 2;
    let (low_roots, high_roots) = log_roots.split_at(split);
    let mut low_table = vec![F::default(); 1 << low_roots.len()];
    let mut high_table = vec![F::default(); 1 << high_roots.len()];

    // Solve both halves independently.
    rayon::join(
        || Self::roots_of_unity_recursive(&mut low_table, low_roots),
        || Self::roots_of_unity_recursive(&mut high_table, high_roots),
    );

    // Recombine: chunk `i` of `out` holds `high_table[i] * low_table[j]` for every `j`.
    let low_len = low_table.len();
    out.par_chunks_mut(low_len)
        .zip(&high_table)
        .for_each(|(chunk, high_power)| {
            chunk
                .iter_mut()
                .zip(&low_table)
                .for_each(|(dst, low_power)| *dst = *high_power * low_power);
        });
}

/// Butterfly passes that start with the widest gap (`xi.len() / 2`) and
/// halve it after every pass until it reaches zero.
///
/// In each pass the slice is cut into clusters of `2 * gap` elements. Within
/// a cluster, element `k` of the lower half is paired with element `k` of
/// the upper half: the lower slot receives their sum, and the upper slot
/// receives their difference multiplied by the table entry at `stride * k`,
/// where `stride` is `xi.len() / (2 * gap)`.
fn io_helper<T: DomainCoeff<F>>(&self, xi: &mut [T], root: F) {
    let twiddles = self.roots_of_unity(root);
    let len = xi.len();

    let mut gap = len / 2;
    while gap > 0 {
        // Each butterfly cluster spans `2 * gap` positions.
        let span = 2 * gap;
        let stride = len / span;

        ark_std::cfg_chunks_mut!(xi, span).for_each(|cluster| {
            let (top, bottom) = cluster.split_at_mut(gap);
            ark_std::cfg_iter_mut!(top)
                .zip(bottom)
                .enumerate()
                .for_each(|(k, (a, b))| {
                    let diff = *a - *b;
                    *a += *b;

                    *b = diff;
                    *b *= twiddles[stride * k];
                });
        });

        gap /= 2;
    }
}

/// Butterfly passes that start with a gap of one and double it after every
/// pass until it reaches `xi.len()`.
///
/// In each pass the slice is cut into clusters of `2 * gap` elements. Within
/// a cluster, element `k` of the upper half is first multiplied by the table
/// entry at `stride * k` (with `stride` being `xi.len() / (2 * gap)`); then
/// the lower slot receives the sum of the pair and the upper slot receives
/// the difference.
fn oi_helper<T: DomainCoeff<F>>(&self, xi: &mut [T], root: F) {
    let twiddles = self.roots_of_unity(root);
    let len = xi.len();

    let mut gap = 1;
    while gap < len {
        let span = 2 * gap;
        let stride = len / span;

        ark_std::cfg_chunks_mut!(xi, span).for_each(|cluster| {
            let (top, bottom) = cluster.split_at_mut(gap);
            ark_std::cfg_iter_mut!(top)
                .zip(bottom)
                .enumerate()
                .for_each(|(k, (a, b))| {
                    *b *= twiddles[stride * k];
                    let diff = *a - *b;
                    *a += *b;
                    *b = diff;
                });
        });

        // The next pass works on clusters twice as wide.
        gap = span;
    }
}
}

/// Reverses the lowest `log_len` bits of `a`.
///
/// The full 64-bit word is reversed and then shifted down so that only the
/// `log_len` formerly-lowest bits remain. `log_len` must not exceed 64.
///
/// A width of zero yields `0`: the required shift of 64 bits is not a valid
/// `>>` amount for `u64` (it panics in debug builds), so `checked_shr` is
/// used and its `None` case is mapped to the mathematically expected result.
#[inline]
fn bitrev(a: u64, log_len: u32) -> u64 {
    a.reverse_bits().checked_shr(64 - log_len).unwrap_or(0)
}

/// Permutes `xi` in place by swapping every element with the element at the
/// bit-reversed index, where indices are treated as `log_len`-bit numbers.
///
/// The caller is expected to pass a slice of `1 << log_len` elements; a
/// reversed index that falls outside the slice makes `swap` panic.
///
/// Slices with fewer than three elements are left untouched. In particular,
/// an empty slice no longer underflows the loop bound.
fn derange<T>(xi: &mut [T], log_len: u32) {
    // Index 0 always maps to itself, and so does the last index when the
    // slice has `1 << log_len` elements, so both ends are skipped.
    // `saturating_sub` keeps the range empty for an empty slice instead of
    // wrapping around.
    let last = (xi.len() as u64).saturating_sub(1);
    for idx in 1..last {
        let ridx = bitrev(idx, log_len);
        // Swap each pair only once.
        if idx < ridx {
            xi.swap(idx as usize, ridx as usize);
        }
    }
}
Loading