From af2026b427842571b294c65eb57a824486ccbbe4 Mon Sep 17 00:00:00 2001 From: Christopher Patton Date: Fri, 26 Mar 2021 17:57:08 -0700 Subject: [PATCH] Implement an iterative FFT algorithm This change adds an alternative algorithm for computing the discrete Fourier transform. It also adds a new module, benchmarked, for components of the crate that we want to benchmark, but don't want to expose in the public API. Finally, it adds a benchmark for comparing the speed of iterative FFT and recursive FFT on various input lengths. --- Cargo.toml | 5 + benches/fft.rs | 33 +++++++ src/benchmarked.rs | 26 +++++ src/client.rs | 2 +- src/fft.rs | 172 +++++++++++++++++++++++++++++++++ src/finite_field.rs | 67 ++++++++++--- src/fp.rs | 226 +++++++++++++++++++++++++++++++------------- src/lib.rs | 2 + src/polynomial.rs | 2 +- 9 files changed, 452 insertions(+), 83 deletions(-) create mode 100644 benches/fft.rs create mode 100644 src/benchmarked.rs create mode 100644 src/fft.rs diff --git a/Cargo.toml b/Cargo.toml index 73b486dbb..52abf8147 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,8 +19,13 @@ thiserror = "1.0" [dev-dependencies] assert_matches = "1.5.0" +criterion = "0.3" modinverse = "0.1.0" num-bigint = "0.4.0" +[[bench]] +name = "fft" +harness = false + [[example]] name = "sum" diff --git a/benches/fft.rs b/benches/fft.rs new file mode 100644 index 000000000..e51276e0a --- /dev/null +++ b/benches/fft.rs @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MPL-2.0 + +use criterion::{criterion_group, criterion_main, Criterion}; + +use prio::benchmarked::{benchmarked_iterative_fft, benchmarked_recursive_fft}; +use prio::finite_field::{Field, FieldElement}; + +pub fn fft(c: &mut Criterion) { + let test_sizes = [16, 256, 1024, 4096]; + for size in test_sizes.iter() { + let mut rng = rand::thread_rng(); + let mut inp = vec![Field::zero(); *size]; + let mut outp = vec![Field::zero(); *size]; + for i in 0..*size { + inp[i] = Field::rand(&mut rng); + } + + c.bench_function(&format!("iterative/{}", *size), |b| { + b.iter(|| { + benchmarked_iterative_fft(&mut outp, &inp); + }) + }); + + c.bench_function(&format!("recursive/{}", *size), |b| { + b.iter(|| { + benchmarked_recursive_fft(&mut outp, &inp); + }) + }); + } +} + +criterion_group!(benches, fft); +criterion_main!(benches); diff --git a/src/benchmarked.rs b/src/benchmarked.rs new file mode 100644 index 000000000..ffd75df4a --- /dev/null +++ b/src/benchmarked.rs @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! This package provides wrappers around internal components of this crate that we want to +//! benchmark, but which we don't want to expose in the public API. + +use crate::fft::discrete_fourier_transform; +use crate::finite_field::{Field, FieldElement}; +use crate::polynomial::{poly_fft, PolyAuxMemory}; + +/// Sets `outp` to the Discrete Fourier Transform (DFT) using an iterative FFT algorithm. +pub fn benchmarked_iterative_fft(outp: &mut [F], inp: &[F]) { + discrete_fourier_transform(outp, inp).expect("encountered unexpected error"); +} + +/// Sets `outp` to the Discrete Fourier Transform (DFT) using a recursive FFT algorithm. +pub fn benchmarked_recursive_fft(outp: &mut [Field], inp: &[Field]) { + let mut mem = PolyAuxMemory::new(inp.len() / 2); + poly_fft( + outp, + inp, + &mem.roots_2n, + inp.len(), + false, + &mut mem.fft_memory, + ) +} diff --git a/src/client.rs b/src/client.rs index ca3a9676d..d7209b2ef 100644 --- a/src/client.rs +++ b/src/client.rs @@ -28,7 +28,7 @@ impl Client { pub fn new(dimension: usize, public_key1: PublicKey, public_key2: PublicKey) -> Option { let n = (dimension + 1).next_power_of_two(); - if 2 * n > Field::num_roots() as usize { + if 2 * n > Field::generator_order() as usize { // too many elements for this field, not enough roots of unity return None; } diff --git a/src/fft.rs b/src/fft.rs new file mode 100644 index 000000000..370ce2fff --- /dev/null +++ b/src/fft.rs @@ -0,0 +1,172 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! This module implements an iterative FFT algorithm for computing the (inverse) Discrete Fourier +//! Transform (DFT) over a slice of field elements. + +use crate::finite_field::FieldElement; +use crate::fp::{log2, MAX_ROOTS}; + +use std::convert::TryFrom; + +/// An error returned by DFT or DFT inverse computation. +#[derive(Debug, thiserror::Error)] +pub enum FftError { + /// The output is too small. + #[error("output slice is smaller than the input")] + OutputTooSmall, + /// The input is too large. + #[error("input slice is larger than than maximum permitted")] + InputTooLarge, + /// The input length is not a power of 2. + #[error("input size is not a power of 2")] + InputSizeInvalid, +} + +/// Sets `outp` to the DFT of `inp`. +pub fn discrete_fourier_transform( + outp: &mut [F], + inp: &[F], +) -> Result<(), FftError> { + let n = inp.len(); + let d = usize::try_from(log2(n as u128)).unwrap(); + + if n > outp.len() { + return Err(FftError::OutputTooSmall); + } + + if n > 1 << MAX_ROOTS { + return Err(FftError::InputTooLarge); + } + + if n != 1 << d { + return Err(FftError::InputSizeInvalid); + } + + for i in 0..n { + outp[i] = inp[bitrev(d, i)]; + } + + let mut w: F; + for l in 1..d + 1 { + w = F::root(0).unwrap(); // one + let r = F::root(l).unwrap(); + let y = 1 << (l - 1); + for i in 0..y { + for j in 0..(n / y) >> 1 { + let x = (1 << l) * j + i; + let u = outp[x]; + let v = w * outp[x + y]; + outp[x] = u + v; + outp[x + y] = u - v; + } + w *= r; + } + } + + Ok(()) +} + +/// Sets `outp` to the inverse of the DFT of `inp`. +#[allow(dead_code)] +pub fn discrete_fourier_transform_inv( + outp: &mut [F], + inp: &[F], +) -> Result<(), FftError> { + discrete_fourier_transform(outp, inp)?; + let n = inp.len(); + let m = F::from(F::Integer::try_from(n).unwrap()).inv(); + let mut tmp: F; + + outp[0] *= m; + outp[n >> 1] *= m; + for i in 1..n >> 1 { + tmp = outp[i] * m; + outp[i] = outp[n - i] * m; + outp[n - i] = tmp; + } + + Ok(()) +} + +// bitrev returns the first d bits of x in reverse order. (Thanks, OEIS! https://oeis.org/A030109) +fn bitrev(d: usize, x: usize) -> usize { + let mut y = 0; + for i in 0..d { + y += ((x >> i) & 1) << (d - i); + } + y >> 1 +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::finite_field::{Field, Field126, Field64, Field80}; + use crate::polynomial::{poly_fft, PolyAuxMemory}; + + fn discrete_fourier_transform_then_inv_test() -> Result<(), FftError> { + let mut rng = rand::thread_rng(); + let test_sizes = [1, 2, 4, 8, 16, 256, 1024, 2048]; + + for size in test_sizes.iter() { + let mut want = vec![F::zero(); *size]; + let mut tmp = vec![F::zero(); *size]; + let mut got = vec![F::zero(); *size]; + for i in 0..*size { + want[i] = F::rand(&mut rng); + } + + discrete_fourier_transform(&mut tmp, &want)?; + discrete_fourier_transform_inv(&mut got, &tmp)?; + assert_eq!(got, want); + } + + Ok(()) + } + + #[test] + fn test_field32() { + discrete_fourier_transform_then_inv_test::().expect("unexpected error"); + } + + #[test] + fn test_field64() { + discrete_fourier_transform_then_inv_test::().expect("unexpected error"); + } + + #[test] + fn test_field80() { + discrete_fourier_transform_then_inv_test::().expect("unexpected error"); + } + + #[test] + fn test_field126() { + discrete_fourier_transform_then_inv_test::().expect("unexpected error"); + } + + #[test] + fn test_recursive_fft() { + let size = 128; + let mut rng = rand::thread_rng(); + let mut mem = PolyAuxMemory::new(size / 2); + + let mut inp = vec![Field::zero(); size]; + let mut want = vec![Field::zero(); size]; + let mut got = vec![Field::zero(); size]; + for i in 0..size { + inp[i] = Field::rand(&mut rng); + } + + discrete_fourier_transform::(&mut want, &inp).expect("unexpected error"); + + poly_fft( + &mut got, + &inp, + &mem.roots_2n, + size, + false, + &mut mem.fft_memory, + ); + + assert_eq!(got, want); + } +} diff --git a/src/finite_field.rs b/src/finite_field.rs index 0134a179a..676b8e1ea 100644 --- a/src/finite_field.rs +++ b/src/finite_field.rs @@ -5,11 +5,14 @@ use crate::fp::{FP126, FP32, FP64, FP80}; use std::{ + cmp::min, convert::TryFrom, - fmt::{Display, Formatter}, + fmt::{Debug, Display, Formatter}, ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}, }; +use rand::Rng; + /// Possible errors from finite field operations. #[derive(Debug, thiserror::Error)] pub enum FiniteFieldError { @@ -21,24 +24,33 @@ pub enum FiniteFieldError { /// Objects with this trait represent an element of `GF(p)` for some prime `p`. pub trait FieldElement: Sized + + Debug + + Copy + PartialEq + Eq - + Add + + Add + AddAssign - + Sub + + Sub + SubAssign - + Mul + + Mul + MulAssign - + Div + + Div + DivAssign - + Neg + + Neg + Display + + From<::Integer> { + /// The error returned if converting `usize` to an `Int` fails. + type IntegerTryFromError: std::fmt::Debug; + /// The integer representation of the field element. - type Integer; + type Integer: Copy + + Debug + + Sub::Integer> + + TryFrom; /// Modular exponentation, i.e., `self^exp (mod p)`. - fn pow(&self, exp: Self) -> Self; + fn pow(&self, exp: Self) -> Self; // TODO(cjpatton) exp should have type Self::Integer /// Modular inversion, i.e., `self^-1 (mod p)`. If `self` is 0, then the output is undefined. fn inv(&self) -> Self; @@ -46,11 +58,21 @@ pub trait FieldElement: /// Returns the prime modulus `p`. fn modulus() -> Self::Integer; - /// Returns a generator of the multiplicative subgroup of size `FieldElement::num_roots()`. + /// Returns the size of the multiplicative subgroup generated by `generator()`. + fn generator_order() -> Self::Integer; + + /// Returns the generator of the multiplicative subgroup of size `generator_order()`. fn generator() -> Self; - /// Returns the size of the multiplicative subgroup generated by `FieldElement::generator()`. - fn num_roots() -> Self::Integer; + /// Returns the `2^l`-th principal root of unity for any `l <= 20`. Note that the `2^0`-th + /// prinicpal root of unity is 1 by definition. + fn root(l: usize) -> Option; + + /// Returns a random field element distributed uniformly over all field elements. + fn rand(rng: &mut R) -> Self; + + /// Returns the additive identity. + fn zero() -> Self; } macro_rules! make_field { @@ -190,6 +212,7 @@ macro_rules! make_field { impl FieldElement for $elem { type Integer = $int; + type IntegerTryFromError = >::Error; fn pow(&self, exp: Self) -> Self { Self($fp.pow(self.0, $fp.from_elem(exp.0))) @@ -207,8 +230,24 @@ macro_rules! make_field { Self($fp.g) } - fn num_roots() -> Self::Integer { - $fp.num_roots as Self::Integer + fn generator_order() -> Self::Integer { + 1 << (Self::Integer::try_from($fp.num_roots).unwrap()) + } + + fn root(l: usize) -> Option { + if l < min($fp.roots.len(), $fp.num_roots+1) { + Some(Self($fp.roots[l])) + } else { + None + } + } + + fn rand(rng: &mut R) -> Self { + Self($fp.rand_elem(rng)) + } + + fn zero() -> Self { + Self(0) } } }; @@ -245,7 +284,7 @@ make_field!( #[test] fn test_arithmetic() { - // TODO(cjpatton) Add tests for Field64, Field80, and Field126. + // TODO(cjpatton) Add tests for the other fields. use rand::prelude::*; let modulus = Field::modulus(); diff --git a/src/fp.rs b/src/fp.rs index 3a99db6db..29d435fa4 100644 --- a/src/fp.rs +++ b/src/fp.rs @@ -2,9 +2,13 @@ //! Finite field arithmetic for any field GF(p) for which p < 2^126. -#[cfg(test)] use rand::{prelude::*, Rng}; +/// For each set of field parameters we pre-compute the 1st, 2nd, 4th, ..., 2^20-th principal roots +/// of unity. The largest of these is used to run the FFT algorithm on an input of size 2^20. This +/// is the largest input size we would ever need for the cryptographic applications in this crate. +pub(crate) const MAX_ROOTS: usize = 20; + /// This structure represents the parameters of a finite field GF(p) for which p < 2^126. #[derive(Debug, PartialEq, Eq)] pub(crate) struct FieldParameters { @@ -16,49 +20,16 @@ pub(crate) struct FieldParameters { pub mu: u64, /// `r2 = (2^128)^2 mod p`. pub r2: u128, - /// The generator of a multiplicative subgroup of order `num_roots`. The value is mapped to the - /// Montgomeryh domain. + /// The `2^num_roots`-th -principal root of unity. This element is used to generate the + /// elements of `roots`. pub g: u128, - /// The order of the multiplicaitve subgroup of generated by `g`. - pub num_roots: u128, + /// The number of principal roots of unity in `roots`. + pub num_roots: usize, + /// `roots[l]` is the `2^l`-th principal root of unity, i.e., `roots[l]` has order `2^l` in the + /// multiplicative group. `root[l]` is equal to one by definition. + pub roots: [u128; MAX_ROOTS + 1], } -pub(crate) const FP32: FieldParameters = FieldParameters { - p: 4293918721, // 32-bit prime - p2: 8587837442, - mu: 17302828673139736575, - r2: 1676699750, - g: 1074114499, - num_roots: 1 << 20, -}; - -pub(crate) const FP64: FieldParameters = FieldParameters { - p: 15564440312192434177, // 64-bit prime - p2: 31128880624384868354, - mu: 15564440312192434175, - r2: 13031533328350459868, - g: 8693478717884812021, - num_roots: 1 << 59, -}; - -pub(crate) const FP80: FieldParameters = FieldParameters { - p: 779190469673491460259841, // 80-bit prime - p2: 1558380939346982920519682, - mu: 18446744073709551615, - r2: 699883506621195336351723, - g: 470015708362303528848629, - num_roots: 1 << 72, -}; - -pub(crate) const FP126: FieldParameters = FieldParameters { - p: 74769074762901517850839147140769382401, // 126-bit prime - p2: 149538149525803035701678294281538764802, - mu: 18446744073709551615, - r2: 27801541991839173768379182336352451464, - g: 63245316532470582112420298384754157617, - num_roots: 1 << 118, -}; - impl FieldParameters { /// Addition. pub fn add(&self, x: u128, y: u128) -> u128 { @@ -218,7 +189,6 @@ impl FieldParameters { } /// Returns a random field element mapped. - #[cfg(test)] pub fn rand_elem(&self, rng: &mut R) -> u128 { let uniform = rand::distributions::Uniform::from(0..self.p); self.elem(uniform.sample(rng)) @@ -241,23 +211,26 @@ impl FieldParameters { } #[cfg(test)] - pub fn new(p: u128, g: u128, num_roots: u128) -> Result { + pub fn check(&self, p: u128, g: u128, order: u128) { use modinverse::modinverse; use num_bigint::{BigInt, ToBigInt}; + use std::cmp::max; - let err_modulus_too_large = "p > 2^126"; if let Some(x) = p.checked_next_power_of_two() { if x > 1 << 126 { - return Err(err_modulus_too_large); + panic!("p >= 2^126"); } } else { - return Err(err_modulus_too_large); + panic!("p >= 2^126"); } + assert_eq!(self.p, p, "p mismatch"); + assert_eq!(self.p2, p << 1, "p2 mismatch"); let mu = match modinverse((-(p as i128)).rem_euclid(1 << 64), 1 << 64) { Some(mu) => mu as u64, - None => return Err("inverse of -p (mod 2^64) is undefined"), + None => panic!("inverse of -p (mod 2^64) is undefined"), }; + assert_eq!(self.mu, mu, "mu mismatch"); let big_p = &p.to_bigint().unwrap(); let big_r: &BigInt = &(&(BigInt::from(1) << 128) % big_p); @@ -268,18 +241,26 @@ impl FieldParameters { if let Some(x) = it.next() { r2 |= (x as u128) << 64; } - - let mut fp = FieldParameters { - p: p, - p2: p << 1, - mu: mu, - r2: r2, - g: 0, - num_roots: num_roots, - }; - - fp.g = fp.elem(g); - Ok(fp) + assert_eq!(self.r2, r2, "r2 mismatch"); + + assert_eq!(self.g, self.elem(g), "g mismatch"); + assert_eq!( + self.from_elem(self.pow(self.g, order)), + 1, + "g order incorrect" + ); + + let num_roots = log2(order) as usize; + assert_eq!(order, 1 << num_roots, "order not a power of 2"); + assert_eq!(self.num_roots, num_roots, "num_roots mismatch"); + + let mut roots = vec![0; max(num_roots, MAX_ROOTS) + 1]; + roots[num_roots] = self.elem(g); + for i in (0..num_roots).rev() { + roots[i] = self.mul(roots[i + 1], roots[i + 1]); + } + assert_eq!(&self.roots, &roots[..MAX_ROOTS + 1], "roots mismatch"); + assert_eq!(self.from_elem(self.roots[0]), 1, "first root is not one"); } } @@ -297,6 +278,121 @@ fn modp(x: u128, p: u128) -> u128 { z.wrapping_add(m & p) } +// Compute the ceiling of the base-2 logarithm of `x`. +pub(crate) fn log2(x: u128) -> u128 { + (128 - x.leading_zeros() - 1) as u128 +} + +pub(crate) const FP32: FieldParameters = FieldParameters { + p: 4293918721, // 32-bit prime + p2: 8587837442, + mu: 17302828673139736575, + r2: 1676699750, + g: 1074114499, + num_roots: 20, + roots: [ + 2564090464, 1729828257, 306605458, 2294308040, 1648889905, 57098624, 2788941825, + 2779858277, 368200145, 2760217336, 594450960, 4255832533, 1372848488, 721329415, + 3873251478, 1134002069, 7138597, 2004587313, 2989350643, 725214187, 1074114499, + ], +}; + +pub(crate) const FP64: FieldParameters = FieldParameters { + p: 15564440312192434177, // 64-bit prime + p2: 31128880624384868354, + mu: 15564440312192434175, + r2: 13031533328350459868, + g: 8693478717884812021, + num_roots: 59, + roots: [ + 3501465310287461188, + 12062975001904972989, + 14847933823983913979, + 5743743733744043357, + 12036183376424650304, + 1310071208805268988, + 351359342873885390, + 760642505652925971, + 8075457983432319221, + 14554120515039960006, + 9277695709938157757, + 5146332056710123439, + 9547487945110664452, + 1379816102304800478, + 8461341165309158767, + 12152693588256515089, + 9516424165972384563, + 8278889272850348764, + 6847784946159064188, + 875721217475244711, + 3028669228647031529, + ], +}; + +pub(crate) const FP80: FieldParameters = FieldParameters { + p: 779190469673491460259841, // 80-bit prime + p2: 1558380939346982920519682, + mu: 18446744073709551615, + r2: 699883506621195336351723, + g: 470015708362303528848629, + num_roots: 72, + roots: [ + 146393360532246310485619, + 632797109141245149774222, + 671768715528862959481, + 155287852188866912681838, + 84398650169430234366422, + 591732619446824370107997, + 369489067863767193117628, + 65351307276236357745139, + 250263845222966534834802, + 615370028124972287024172, + 428271082931219526829234, + 82144483146855494501530, + 655790508505248218964487, + 715547187733913654852114, + 29653674159319497967645, + 208078234303463777930443, + 495449125070884366403280, + 409220521346165172951210, + 134217175002192449913815, + 87718316256013518265593, + 261278801525790549618040, + ], +}; + +pub(crate) const FP126: FieldParameters = FieldParameters { + p: 74769074762901517850839147140769382401, // 126-bit prime + p2: 149538149525803035701678294281538764802, + mu: 18446744073709551615, + r2: 27801541991839173768379182336352451464, + g: 63245316532470582112420298384754157617, + num_roots: 118, + roots: [ + 41206067869332392060018018868690681852, + 33563006893569125790821128272078700549, + 9969209968386869007425498928188874206, + 26245577744033816872542400585149646017, + 53536320213034809573447803273264211942, + 27613962195776955012920796583378240442, + 32365734403831264958530930421153577004, + 13579354626561224539372784961933801433, + 57316758837288076943811104544124917759, + 70913423672054213072910590891105064074, + 71265034959502540558500186666669444000, + 34207722676470700263211551887273866594, + 37340170148681921863826823402458410577, + 35009585531414332540073382665488435215, + 70329412074928482115163094157328536788, + 39119429759852994810554872198104013087, + 47573549675073661838420354629772140200, + 77849817677037388106638164185970185092, + 37853717993704464400736177978677308170, + 83509620839139853788963077680031940984, + 64573608437864873942981348294630891347, + ], +}; + #[cfg(test)] mod tests { use super::*; @@ -346,12 +442,7 @@ mod tests { for t in test_fps.into_iter() { // Check that the field parameters have been constructed properly. - assert_eq!( - t.fp, - FieldParameters::new(t.expected_p, t.expected_g, t.expected_order).unwrap(), - "error for GF({})", - t.expected_p, - ); + t.fp.check(t.expected_p, t.expected_g, t.expected_order); // Check that the field element size is computed correctly. assert_eq!( @@ -365,13 +456,14 @@ mod tests { assert_eq!(t.fp.from_elem(t.fp.pow(t.fp.g, t.expected_order)), 1); // Test arithmetic using the field parameters. - test_arithmetic(t.fp); + test_arithmetic(&t.fp); } } - fn test_arithmetic(fp: FieldParameters) { + fn test_arithmetic(fp: &FieldParameters) { let mut rng = rand::thread_rng(); let big_p = &fp.p.to_bigint().unwrap(); + for _ in 0..100 { let x = fp.rand_elem(&mut rng); let y = fp.rand_elem(&mut rng); diff --git a/src/lib.rs b/src/lib.rs index 7481c809d..965c8ae50 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,8 +9,10 @@ //! //! For now we only support 0 / 1 vectors. +pub mod benchmarked; pub mod client; pub mod encrypt; +mod fft; pub mod finite_field; mod fp; mod polynomial; diff --git a/src/polynomial.rs b/src/polynomial.rs index b546997e6..6e4b8e7e4 100644 --- a/src/polynomial.rs +++ b/src/polynomial.rs @@ -114,7 +114,7 @@ fn fft_get_roots(count: usize, invert: bool) -> Vec { } roots[0] = 1.into(); - let step_size: u32 = (Field::num_roots() as u32) / (count as u32); + let step_size: u32 = (Field::generator_order() as u32) / (count as u32); // generator for subgroup of order count gen = gen.pow(step_size.into());