Skip to content

Commit

Permalink
FFT opt
Browse files Browse the repository at this point in the history
  • Loading branch information
Brechtpd authored and einar-taiko committed May 24, 2023
1 parent 73b66d0 commit 0b9ba5e
Show file tree
Hide file tree
Showing 14 changed files with 1,432 additions and 141 deletions.
1 change: 1 addition & 0 deletions halo2_proofs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ tracing = "0.1"
blake2b_simd = "1"
sha3 = "0.9.1"
rand_chacha = "0.3"
ark-std = { version = "0.3", features = ["print-trace"] }

# Developer tooling dependencies
plotters = { version = "0.3.0", optional = true }
Expand Down
17 changes: 11 additions & 6 deletions halo2_proofs/benches/fft.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
#[macro_use]
extern crate criterion;

use crate::arithmetic::best_fft;
use halo2_proofs::{arithmetic::best_fft, poly::EvaluationDomain};
use group::ff::Field;
use halo2_proofs::*;
use halo2curves::pasta::Fp;
use halo2curves::bn256::Fr as Scalar;

use criterion::{BenchmarkId, Criterion};
use rand_core::OsRng;

fn criterion_benchmark(c: &mut Criterion) {
let j = 5;
let mut group = c.benchmark_group("fft");
for k in 3..19 {
let domain = EvaluationDomain::new(j,k);
let omega = domain.get_omega();
let l = 1<<k;
let data = domain.get_fft_data(l);

group.bench_function(BenchmarkId::new("k", k), |b| {
let mut a = (0..(1 << k)).map(|_| Fp::random(OsRng)).collect::<Vec<_>>();
let omega = Fp::random(OsRng); // would be weird if this mattered
let mut a = (0..(1 << k)).map(|_| Scalar::random(OsRng)).collect::<Vec<_>>();

b.iter(|| {
best_fft(&mut a, omega, k as u32);
best_fft(&mut a, omega, k as u32, data, false);
});
});
}
Expand Down
166 changes: 54 additions & 112 deletions halo2_proofs/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@ use group::{

pub use halo2curves::{CurveAffine, CurveExt};

use crate::{
fft::{
self, parallel,
recursive::{self, FFTData},
},
plonk::{get_duration, get_time, log_info},
poly::EvaluationDomain,
};
use std::{env, mem};

/// This represents an element of a group with basic operations that can be
/// performed. This allows an FFT implementation (for example) to operate
/// generically over either a field or elliptic curve group.
Expand All @@ -25,6 +35,9 @@ where
{
}

/// TEMP
pub static mut MULTIEXP_TOTAL_TIME: usize = 0;

fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) {
let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();

Expand Down Expand Up @@ -147,8 +160,11 @@ pub fn small_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::C
pub fn best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
assert_eq!(coeffs.len(), bases.len());

log_info(format!("msm: {}", coeffs.len()));

let start = get_time();
let num_threads = multicore::current_num_threads();
if coeffs.len() > num_threads {
let res = if coeffs.len() > num_threads {
let chunk = coeffs.len() / num_threads;
let num_chunks = coeffs.chunks(chunk).len();
let mut results = vec![C::Curve::identity(); num_chunks];
Expand All @@ -170,134 +186,48 @@ pub fn best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Cu
let mut acc = C::Curve::identity();
multiexp_serial(coeffs, bases, &mut acc);
acc
}
}

/// Performs a radix-$2$ Fast-Fourier Transformation (FFT) on a vector of size
/// $n = 2^k$, when provided `log_n` = $k$ and an element of multiplicative
/// order $n$ called `omega` ($\omega$). The result is that the vector `a`, when
/// interpreted as the coefficients of a polynomial of degree $n - 1$, is
/// transformed into the evaluations of this polynomial at each of the $n$
/// distinct powers of $\omega$. This transformation is invertible by providing
/// $\omega^{-1}$ in place of $\omega$ and dividing each resulting field element
/// by $n$.
///
/// This will use multithreading if beneficial.
pub fn best_fft<Scalar: Field, G: FftGroup<Scalar>>(a: &mut [G], omega: Scalar, log_n: u32) {
fn bitreverse(mut n: usize, l: usize) -> usize {
let mut r = 0;
for _ in 0..l {
r = (r << 1) | (n & 1);
n >>= 1;
}
r
}

let threads = multicore::current_num_threads();
let log_threads = log2_floor(threads);
let n = a.len() as usize;
assert_eq!(n, 1 << log_n);
};

for k in 0..n {
let rk = bitreverse(k, log_n as usize);
if k < rk {
a.swap(rk, k);
}
let duration = get_duration(start);
#[allow(unsafe_code)]
unsafe {
crate::arithmetic::MULTIEXP_TOTAL_TIME += duration;
}

// precompute twiddle factors
let twiddles: Vec<_> = (0..(n / 2) as usize)
.scan(Scalar::ONE, |w, _| {
let tw = *w;
*w *= &omega;
Some(tw)
})
.collect();

if log_n <= log_threads {
let mut chunk = 2_usize;
let mut twiddle_chunk = (n / 2) as usize;
for _ in 0..log_n {
a.chunks_mut(chunk).for_each(|coeffs| {
let (left, right) = coeffs.split_at_mut(chunk / 2);

// case when twiddle factor is one
let (a, left) = left.split_at_mut(1);
let (b, right) = right.split_at_mut(1);
let t = b[0];
b[0] = a[0];
a[0] += &t;
b[0] -= &t;

left.iter_mut()
.zip(right.iter_mut())
.enumerate()
.for_each(|(i, (a, b))| {
let mut t = *b;
t *= &twiddles[(i + 1) * twiddle_chunk];
*b = *a;
*a += &t;
*b -= &t;
});
});
chunk *= 2;
twiddle_chunk /= 2;
}
} else {
recursive_butterfly_arithmetic(a, n, 1, &twiddles)
}
res
}

/// This perform recursive butterfly arithmetic
pub fn recursive_butterfly_arithmetic<Scalar: Field, G: FftGroup<Scalar>>(
/// Dispatcher
pub fn best_fft<Scalar: Field, G: FftGroup<Scalar>>(
a: &mut [G],
n: usize,
twiddle_chunk: usize,
twiddles: &[Scalar],
omega: Scalar,
log_n: u32,
data: &FFTData<Scalar>,
inverse: bool,
) {
if n == 2 {
let t = a[1];
a[1] = a[0];
a[0] += &t;
a[1] -= &t;
} else {
let (left, right) = a.split_at_mut(n / 2);
rayon::join(
|| recursive_butterfly_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles),
|| recursive_butterfly_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles),
);

// case when twiddle factor is one
let (a, left) = left.split_at_mut(1);
let (b, right) = right.split_at_mut(1);
let t = b[0];
b[0] = a[0];
a[0] += &t;
b[0] -= &t;

left.iter_mut()
.zip(right.iter_mut())
.enumerate()
.for_each(|(i, (a, b))| {
let mut t = *b;
t *= &twiddles[(i + 1) * twiddle_chunk];
*b = *a;
*a += &t;
*b -= &t;
});
}
fft::fft(a, omega, log_n, data, inverse);
}

/// Convert coefficient bases group elements to lagrange basis by inverse FFT.
pub fn g_to_lagrange<C: CurveAffine>(g_projective: Vec<C::Curve>, k: u32) -> Vec<C> {
let n_inv = C::Scalar::TWO_INV.pow_vartime(&[k as u64, 0, 0, 0]);
let omega = C::Scalar::ROOT_OF_UNITY;
let mut omega_inv = C::Scalar::ROOT_OF_UNITY_INV;
for _ in k..C::Scalar::S {
omega_inv = omega_inv.square();
}

let mut g_lagrange_projective = g_projective;
best_fft(&mut g_lagrange_projective, omega_inv, k);
let n = g_lagrange_projective.len();
let fft_data = FFTData::new(n, omega, omega_inv);

best_fft(
&mut g_lagrange_projective,
omega_inv,
k,
&fft_data,
false,
);
parallelize(&mut g_lagrange_projective, |g, _| {
for g in g.iter_mut() {
*g *= n_inv;
Expand Down Expand Up @@ -402,7 +332,8 @@ pub fn parallelize<T: Send, F: Fn(&mut [T], usize) + Send + Sync + Clone>(v: &mu
});
}

fn log2_floor(num: usize) -> u32 {
/// Compute the binary logarithm floored.
pub fn log2_floor(num: usize) -> u32 {
assert!(num > 0);

let mut pow = 0;
Expand Down Expand Up @@ -496,7 +427,18 @@ pub(crate) fn powers<F: Field>(base: F) -> impl Iterator<Item = F> {
std::iter::successors(Some(F::ONE), move |power| Some(base * power))
}

/// Reverse `l` LSBs of bitvector `n`
pub fn bitreverse(mut n: usize, l: usize) -> usize {
let mut r = 0;
for _ in 0..l {
r = (r << 1) | (n & 1);
n >>= 1;
}
r
}

#[cfg(test)]
use crate::plonk::{start_measure, stop_measure};
use rand_core::OsRng;

#[cfg(test)]
Expand Down
Loading

0 comments on commit 0b9ba5e

Please sign in to comment.