diff --git a/Cargo.toml b/Cargo.toml index f91fec4f2..73b486dbb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,8 @@ thiserror = "1.0" [dev-dependencies] assert_matches = "1.5.0" +modinverse = "0.1.0" +num-bigint = "0.4.0" [[example]] name = "sum" diff --git a/src/finite_field.rs b/src/finite_field.rs index 59f74209c..3dc36fffa 100644 --- a/src/finite_field.rs +++ b/src/finite_field.rs @@ -3,6 +3,8 @@ //! Finite field arithmetic over a prime field using a 32bit prime. +use crate::fp::FieldParameters; + /// Possible errors from finite field operations. #[derive(Debug, thiserror::Error)] pub enum FiniteFieldError { @@ -11,14 +13,22 @@ pub enum FiniteFieldError { InputSizeMismatch, } -/// Newtype wrapper over u32 +/// Newtype wrapper over u128 /// /// Implements the arithmetic over the finite prime field -#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] -pub struct Field(u32); +#[derive(Clone, Copy, Debug, PartialOrd, Ord, Hash, Default)] +pub struct Field(u128); + +/// Parameters for GF(2^32 - 2^20 + 1). +pub(crate) const SMALL_FP: FieldParameters = FieldParameters { + p: 4293918721, + p2: 8587837442, + mu: 17302828673139736575, + r2: 1676699750, +}; /// Modulus for the field, a FFT friendly prime: 2^32 - 2^20 + 1 -pub const MODULUS: u32 = 4293918721; +pub const MODULUS: u32 = SMALL_FP.p as u32; /// Generator for the multiplicative subgroup pub(crate) const GENERATOR: u32 = 3925978153; /// Number of primitive roots @@ -28,7 +38,7 @@ impl std::ops::Add for Field { type Output = Field; fn add(self, rhs: Self) -> Self { - self - Field(MODULUS - rhs.0) + Self(SMALL_FP.add(self.0, rhs.0)) } } @@ -42,14 +52,7 @@ impl std::ops::Sub for Field { type Output = Field; fn sub(self, rhs: Self) -> Self { - let l = self.0; - let r = rhs.0; - - if l >= r { - Field(l - r) - } else { - Field(MODULUS - r + l) - } + Self(SMALL_FP.sub(self.0, rhs.0)) } } @@ -62,12 +65,8 @@ impl std::ops::SubAssign for Field { impl std::ops::Mul for Field { type Output = Field; - #[allow(clippy::suspicious_arithmetic_impl)] fn mul(self, rhs: Self) -> Self { - let l = self.0 as u64; - let r = rhs.0 as u64; - let mul = l * r; - Field((mul % (MODULUS as u64)) as u32) + Self(SMALL_FP.mul(self.0, rhs.0)) } } @@ -80,7 +79,6 @@ impl std::ops::MulAssign for Field { impl std::ops::Div for Field { type Output = Field; - #[allow(clippy::suspicious_arithmetic_impl)] fn div(self, rhs: Self) -> Self { self * rhs.inv() } @@ -92,95 +90,76 @@ impl std::ops::DivAssign for Field { } } +impl PartialEq for Field { + fn eq(&self, rhs: &Self) -> bool { + SMALL_FP.from_elem(self.0) == SMALL_FP.from_elem(rhs.0) + } +} + +impl Eq for Field {} + impl Field { /// Modular exponentation pub fn pow(self, exp: Self) -> Self { - // repeated squaring - let mut base = self; - let mut exp = exp.0; - let mut result: Field = Field(1); - while exp > 0 { - while (exp & 1) == 0 { - exp /= 2; - base *= base; - } - exp -= 1; - result *= base; - } - result + Self(SMALL_FP.pow(self.0, SMALL_FP.from_elem(exp.0))) } /// Modular inverse /// /// Note: inverse of 0 is defined as 0. pub fn inv(self) -> Self { - // extended Euclidean - let mut x1: i32 = 1; - let mut a1: u32 = self.0; - let mut x0: i32 = 0; - let mut a2: u32 = MODULUS; - let mut q: u32 = 0; - - while a2 != 0 { - let x2 = x0 - (q as i32) * x1; - x0 = x1; - let a0 = a1; - x1 = x2; - a1 = a2; - q = a0 / a1; - a2 = a0 - q * a1; - } - if x1 < 0 { - let (r, _) = MODULUS.overflowing_add(x1 as u32); - Field(r) - } else { - Field(x1 as u32) - } + Self(SMALL_FP.inv(self.0)) } } impl From for Field { fn from(x: u32) -> Self { - Field(x % MODULUS) + Field(SMALL_FP.elem(x as u128)) } } impl From for u32 { fn from(x: Field) -> Self { - x.0 + SMALL_FP.from_elem(x.0) as u32 } } impl PartialEq for Field { fn eq(&self, rhs: &u32) -> bool { - self.0 == *rhs + SMALL_FP.from_elem(self.0) == *rhs as u128 } } impl std::fmt::Display for Field { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}", self.0) + write!(f, "{}", SMALL_FP.from_elem(self.0)) } } +#[test] +fn test_small_fp() { + assert_eq!(SMALL_FP.check(), Ok(())); +} + #[test] fn test_arithmetic() { use rand::prelude::*; + // add - assert_eq!(Field(MODULUS - 1) + Field(1), 0); - assert_eq!(Field(MODULUS - 2) + Field(2), 0); - assert_eq!(Field(MODULUS - 2) + Field(3), 1); - assert_eq!(Field(1) + Field(1), 2); - assert_eq!(Field(2) + Field(MODULUS), 2); - assert_eq!(Field(3) + Field(MODULUS - 1), 2); + assert_eq!(Field::from(MODULUS - 1) + Field::from(1), 0); + assert_eq!(Field::from(MODULUS - 2) + Field::from(2), 0); + assert_eq!(Field::from(MODULUS - 2) + Field::from(3), 1); + assert_eq!(Field::from(1) + Field::from(1), 2); + assert_eq!(Field::from(2) + Field::from(MODULUS), 2); + assert_eq!(Field::from(3) + Field::from(MODULUS - 1), 2); // sub - assert_eq!(Field(0) - Field(1), MODULUS - 1); - assert_eq!(Field(1) - Field(2), MODULUS - 1); - assert_eq!(Field(15) - Field(3), 12); - assert_eq!(Field(1) - Field(1), 0); - assert_eq!(Field(2) - Field(MODULUS), 2); - assert_eq!(Field(3) - Field(MODULUS - 1), 4); + assert_eq!(Field::from(0) - Field::from(1), MODULUS - 1); + assert_eq!(Field::from(1) - Field::from(2), MODULUS - 1); + assert_eq!(Field::from(15) - Field::from(3), 12); + assert_eq!(Field::from(1) - Field::from(1), 0); + assert_eq!(Field::from(2) - Field::from(MODULUS), 2); + assert_eq!(Field::from(3) - Field::from(MODULUS - 1), 4); // add + sub for _ in 0..100 { @@ -192,35 +171,35 @@ fn test_arithmetic() { } // mul - assert_eq!(Field(35) * Field(123), 4305); - assert_eq!(Field(1) * Field(MODULUS), 0); - assert_eq!(Field(0) * Field(123), 0); - assert_eq!(Field(123) * Field(0), 0); - assert_eq!(Field(123123123) * Field(123123123), 1237630077); + assert_eq!(Field::from(35) * Field::from(123), 4305); + assert_eq!(Field::from(1) * Field::from(MODULUS), 0); + assert_eq!(Field::from(0) * Field::from(123), 0); + assert_eq!(Field::from(123) * Field::from(0), 0); + assert_eq!(Field::from(123123123) * Field::from(123123123), 1237630077); // div - assert_eq!(Field(35) / Field(5), 7); - assert_eq!(Field(35) / Field(0), 0); - assert_eq!(Field(0) / Field(5), 0); - assert_eq!(Field(1237630077) / Field(123123123), 123123123); + assert_eq!(Field::from(35) / Field::from(5), 7); + assert_eq!(Field::from(35) / Field::from(0), 0); + assert_eq!(Field::from(0) / Field::from(5), 0); + assert_eq!(Field::from(1237630077) / Field::from(123123123), 123123123); - assert_eq!(Field(0).inv(), 0); + assert_eq!(Field::from(0).inv(), 0); // mul and div let uniform = rand::distributions::Uniform::from(1..MODULUS); let mut rng = thread_rng(); for _ in 0..100 { // non-zero element - let f = Field(uniform.sample(&mut rng)); + let f = Field::from(uniform.sample(&mut rng)); assert_eq!(f * f.inv(), 1); assert_eq!(f.inv() * f, 1); } // pow - assert_eq!(Field(2).pow(3.into()), 8); - assert_eq!(Field(3).pow(9.into()), 19683); - assert_eq!(Field(51).pow(27.into()), 3760729523); - assert_eq!(Field(432).pow(0.into()), 1); + assert_eq!(Field::from(2).pow(3.into()), 8); + assert_eq!(Field::from(3).pow(9.into()), 19683); + assert_eq!(Field::from(51).pow(27.into()), 3760729523); + assert_eq!(Field::from(432).pow(0.into()), 1); assert_eq!(Field(0).pow(123.into()), 0); } diff --git a/src/fp.rs b/src/fp.rs new file mode 100644 index 000000000..55850ce98 --- /dev/null +++ b/src/fp.rs @@ -0,0 +1,357 @@ +// Copyright (c) 2021 The Authors +// SPDX-License-Identifier: MPL-2.0 + +//! Finite field arithmetic for any field GF(p) for which p < 2^126. + +use rand::prelude::*; +use rand::Rng; + +/// This structure represents the parameters of a finite field GF(p) for which p < 2^126. +#[derive(Debug)] +pub(crate) struct FieldParameters { + /// The prime modulus `p`. + pub p: u128, + /// `p * 2`. + pub p2: u128, + /// `mu = -p^(-1) mod 2^64`. + pub mu: u64, + /// `r2 = (2^128)^2 mod p`. + pub r2: u128, +} + +impl FieldParameters { + /// Addition. + pub fn add(&self, x: u128, y: u128) -> u128 { + let (z, carry) = x.wrapping_add(y).overflowing_sub(self.p2); + let m = 0u128.wrapping_sub(carry as u128); + z.wrapping_add(m & self.p2) + } + + /// Subtraction. + pub fn sub(&self, x: u128, y: u128) -> u128 { + let (z, carry) = x.overflowing_sub(y); + let m = 0u128.wrapping_sub(carry as u128); + z.wrapping_add(m & self.p2) + } + + /// Multiplication of field elements in the Montgomery domain. This uses the REDC algorithm + /// described + /// [here](https://www.ams.org/journals/mcom/1985-44-170/S0025-5718-1985-0777282-X/S0025-5718-1985-0777282-X.pdfA). + /// + /// Example usage: + /// assert_eq!(fp.from_elem(fp.mul(fp.elem(23), fp.elem(2))), 46); + pub fn mul(&self, x: u128, y: u128) -> u128 { + let x = [lo64(x), hi64(x)]; + let y = [lo64(y), hi64(y)]; + let p = [lo64(self.p), hi64(self.p)]; + let mut zz = [0; 4]; + let mut result: u128; + let mut carry: u128; + let mut hi: u128; + let mut lo: u128; + let mut cc: u128; + + // Integer multiplication + result = x[0] * y[0]; + carry = hi64(result); + zz[0] = lo64(result); + result = x[0] * y[1]; + hi = hi64(result); + lo = lo64(result); + result = lo + carry; + zz[1] = lo64(result); + cc = hi64(result); + result = hi + cc; + zz[2] = lo64(result); + + result = x[1] * y[0]; + hi = hi64(result); + lo = lo64(result); + result = zz[1] + lo; + zz[1] = lo64(result); + cc = hi64(result); + result = hi + cc; + carry = lo64(result); + + result = x[1] * y[1]; + hi = hi64(result); + lo = lo64(result); + result = lo + carry; + lo = lo64(result); + cc = hi64(result); + result = hi + cc; + hi = lo64(result); + result = zz[2] + lo; + zz[2] = lo64(result); + cc = hi64(result); + result = hi + cc; + zz[3] = lo64(result); + + // Reduction + let w = self.mu.wrapping_mul(zz[0] as u64); + result = p[0] * (w as u128); + hi = hi64(result); + lo = lo64(result); + result = zz[0] + lo; + zz[0] = lo64(result); + cc = hi64(result); + result = hi + cc; + carry = lo64(result); + + result = p[1] * (w as u128); + hi = hi64(result); + lo = lo64(result); + result = lo + carry; + lo = lo64(result); + cc = hi64(result); + result = hi + cc; + hi = lo64(result); + result = zz[1] + lo; + zz[1] = lo64(result); + cc = hi64(result); + result = zz[2] + hi + cc; + zz[2] = lo64(result); + cc = hi64(result); + result = zz[3] + cc; + zz[3] = lo64(result); + + let w = self.mu.wrapping_mul(zz[1] as u64); + result = p[0] * (w as u128); + hi = hi64(result); + lo = lo64(result); + result = zz[1] + lo; + zz[1] = lo64(result); + cc = hi64(result); + result = hi + cc; + carry = lo64(result); + + result = p[1] * (w as u128); + hi = hi64(result); + lo = lo64(result); + result = lo + carry; + lo = lo64(result); + cc = hi64(result); + result = hi + cc; + hi = lo64(result); + result = zz[2] + lo; + zz[2] = lo64(result); + cc = hi64(result); + result = zz[3] + hi + cc; + zz[3] = lo64(result); + + zz[2] | (zz[3] << 64) + } + + /// Modular exponentiation, i.e., `x^exp (mod p)` where `p` is the modulus. Note that the + /// runtime of this algorithm is linear in the bit length of `exp`. + pub fn pow(&self, x: u128, exp: u128) -> u128 { + let mut t = self.elem(1); + for i in (0..128).rev() { + t = self.mul(t, t); + if (exp >> i) & 1 != 0 { + t = self.mul(t, x); + } + } + t + } + + /// Modular inversion, i.e., x^-1 (mod p) where `p` is the modulu. Note that the runtime of + /// this algorithm is linear in the bit length of `p`. + pub fn inv(&self, x: u128) -> u128 { + self.pow(x, self.p - 2) + } + + /// Negation, i.e., `-x (mod p)` where `p` is the modulus. + pub fn neg(&self, x: u128) -> u128 { + self.sub(0, x) + } + + /// Maps an integer to its internal representation. Field elements are mapped to the Montgomery + /// domain in order to carry out field arithmetic. + /// + /// Example usage: + /// let integer = 1; // Standard integer representation + /// let elem = fp.elem(integer); // Internal representation in the Montgomery domain + /// assert_eq!(elem, 2564090464); + pub fn elem(&self, x: u128) -> u128 { + modp(self.mul(x, self.r2), self.p) + } + + /// Returns a random field element mapped. + pub fn rand_elem(&self, rng: &mut R) -> u128 { + let uniform = rand::distributions::Uniform::from(0..self.p); + self.elem(uniform.sample(rng)) + } + + /// Maps a field element to its representation as an integer. + /// + /// Example usage: + /// let elem = 2564090464; // Internal representation in the Montgomery domain + /// let integer = fp.from_elem(elem); // Standard integer representation + /// assert_eq!(integer, 1); + pub fn from_elem(&self, x: u128) -> u128 { + modp(self.mul(x, 1), self.p) + } + + /// Returns the number of bytes required to encode field elements. + pub fn size(&self) -> usize { + (16 - (self.p.leading_zeros() / 8)) as usize + } + + #[cfg(test)] + pub fn check(&self) -> Result<(), &'static str> { + use modinverse::modinverse; + use num_bigint::{BigInt, ToBigInt}; + + let err_modulus_too_large = "p > 2^126"; + if let Some(x) = self.p.checked_next_power_of_two() { + if x > 1 << 126 { + return Err(err_modulus_too_large); + } + } else { + return Err(err_modulus_too_large); + } + + if self.p2 != self.p << 1 { + return Err("value of p2 is incorrect"); + } + + let mu = match modinverse((-(self.p as i128)).rem_euclid(1 << 64), 1 << 64) { + Some(mu) => mu as u64, + None => return Err("inverse of -p (mod 2^64) is undefined"), + }; + if self.mu != mu { + return Err("value of mu is incorrect"); + } + + let big_p = &self.p.to_bigint().unwrap(); + let big_r: &BigInt = &(&(BigInt::from(1) << 128) % big_p); + let big_r2: &BigInt = &(&(big_r * big_r) % big_p); + let mut it = big_r2.iter_u64_digits(); + let mut r2 = 0; + r2 |= it.next().unwrap() as u128; + if let Some(x) = it.next() { + r2 |= (x as u128) << 64; + } + if self.r2 != r2 { + return Err("value of r2 is not correct"); + } + + Ok(()) + } +} + +fn lo64(x: u128) -> u128 { + x & ((1 << 64) - 1) +} + +fn hi64(x: u128) -> u128 { + x >> 64 +} + +fn modp(x: u128, p: u128) -> u128 { + let (z, carry) = x.overflowing_sub(p); + let m = 0u128.wrapping_sub(carry as u128); + z.wrapping_add(m & p) +} + +#[cfg(test)] +mod tests { + use super::*; + use modinverse::modinverse; + use num_bigint::ToBigInt; + + struct TestFieldParametersData { + fp: FieldParameters, + expected_size: usize, + } + + #[test] + fn test_fp() { + let mut rng = rand::thread_rng(); + let test_fps = vec![ + TestFieldParametersData { + fp: FieldParameters { + p: 4293918721, // 32-bit prime + p2: 8587837442, + mu: 17302828673139736575, + r2: 1676699750, + }, + expected_size: 4, + }, + TestFieldParametersData { + fp: FieldParameters { + p: 15564440312192434177, // 64-bit prime + p2: 31128880624384868354, + mu: 15564440312192434175, + r2: 13031533328350459868, + }, + expected_size: 8, + }, + TestFieldParametersData { + fp: FieldParameters { + p: 779190469673491460259841, // 80-bit prime + p2: 1558380939346982920519682, + mu: 18446744073709551615, + r2: 699883506621195336351723, + }, + expected_size: 10, + }, + TestFieldParametersData { + fp: FieldParameters { + p: 74769074762901517850839147140769382401, // 126-bit prime + p2: 149538149525803035701678294281538764802, + mu: 18446744073709551615, + r2: 27801541991839173768379182336352451464, + }, + expected_size: 16, + }, + ]; + + for t in test_fps.into_iter() { + let fp = t.fp; + assert_eq!(fp.size(), t.expected_size); + assert_eq!(fp.check(), Ok(())); + + // Test arithmetic. + let big_p = &fp.p.to_bigint().unwrap(); + for _ in 0..100 { + let x = fp.rand_elem(&mut rng); + let y = fp.rand_elem(&mut rng); + let big_x = &fp.from_elem(x).to_bigint().unwrap(); + let big_y = &fp.from_elem(y).to_bigint().unwrap(); + + // Test addition. + let got = fp.add(x, y); + let want = (big_x + big_y) % big_p; + assert_eq!(fp.from_elem(got).to_bigint().unwrap(), want); + + // Test subtraction. + let got = fp.sub(x, y); + let want = if big_x >= big_y { + big_x - big_y + } else { + big_p - big_y + big_x + }; + assert_eq!(fp.from_elem(got).to_bigint().unwrap(), want); + + // Test multiplication. + let got = fp.mul(x, y); + let want = (big_x * big_y) % big_p; + assert_eq!(fp.from_elem(got).to_bigint().unwrap(), want); + + // Test inversion. + let got = fp.inv(x); + let want = modinverse(fp.from_elem(x) as i128, fp.p as i128).unwrap(); + assert_eq!(fp.from_elem(got) as i128, want); + assert_eq!(fp.from_elem(fp.mul(got, x)), 1); + + // Test negation. + let got = fp.neg(x); + let want = (-(fp.from_elem(x) as i128)).rem_euclid(fp.p as i128); + assert_eq!(fp.from_elem(got) as i128, want); + assert_eq!(fp.from_elem(fp.add(got, x)), 0); + } + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 2b89da5fa..7481c809d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,6 +12,7 @@ pub mod client; pub mod encrypt; pub mod finite_field; +mod fp; mod polynomial; mod prng; pub mod server; diff --git a/src/prng.rs b/src/prng.rs index f701cfa13..b98df0352 100644 --- a/src/prng.rs +++ b/src/prng.rs @@ -1,7 +1,7 @@ // Copyright (c) 2020 Apple Inc. // SPDX-License-Identifier: MPL-2.0 -use super::finite_field::{Field, MODULUS}; +use super::finite_field::{Field, MODULUS, SMALL_FP}; use aes_ctr::stream_cipher::generic_array::GenericArray; use aes_ctr::stream_cipher::NewStreamCipher; use aes_ctr::stream_cipher::SyncStreamCipher; @@ -61,7 +61,7 @@ fn random_field_from_seed(seed: &[u8], length: usize) -> Vec { cipher.apply_keystream(&mut buffer); // rejection sampling - for chunk in buffer.chunks_exact(std::mem::size_of::()) { + for chunk in buffer.chunks_exact(SMALL_FP.size()) { let integer = u32::from_le_bytes(chunk.try_into().unwrap()); if integer < MODULUS { output[output_written] = Field::from(integer); diff --git a/src/util.rs b/src/util.rs index afae2155e..de717dd39 100644 --- a/src/util.rs +++ b/src/util.rs @@ -3,7 +3,7 @@ //! Utility functions for handling Prio stuff. -use crate::finite_field::Field; +use crate::finite_field::{Field, SMALL_FP}; /// Convenience function for initializing fixed sized vectors of Field elements. pub fn vector_with_length(len: usize) -> Vec { @@ -110,7 +110,7 @@ pub fn serialize(data: &[Field]) -> Vec { /// Get a vector of field elements from a byte slice pub fn deserialize(data: &[u8]) -> Vec { - let field_size = std::mem::size_of::(); + let field_size = SMALL_FP.size(); let mut vec = Vec::with_capacity(data.len() / field_size); use std::convert::TryInto;