From 4533590a2e0f413baaa388eddac8cf9c823ceffc Mon Sep 17 00:00:00 2001 From: Christopher Patton Date: Sun, 4 Apr 2021 08:12:52 -0700 Subject: [PATCH] Use any FieldElement and clean up API This change generalizes the code for generating and validating proofs to work with any implementation of the FieldElement trait. It also makes the following API-breaking changes: * Rename Field to Field32 in order to signify that it implements a 32-bit field * Rename FiniteFieldError to FieldError * Move finite_field to field * Simplify the FieldElement API a bit --- benches/fft.rs | 9 +- examples/sum.rs | 12 +- src/benchmarked.rs | 4 +- src/client.rs | 105 ++++++----- src/fft.rs | 22 ++- src/{finite_field.rs => field.rs} | 279 ++++++++++++++++++++---------- src/fp.rs | 23 +-- src/lib.rs | 2 +- src/polynomial.rs | 136 ++++++++------- src/prng.rs | 45 +++-- src/server.rs | 81 +++++---- src/util.rs | 109 ++++++------ tests/accumulating.rs | 10 +- tests/tweaks.rs | 12 +- 14 files changed, 483 insertions(+), 366 deletions(-) rename src/{finite_field.rs => field.rs} (55%) diff --git a/benches/fft.rs b/benches/fft.rs index e51276e0a..818b438ba 100644 --- a/benches/fft.rs +++ b/benches/fft.rs @@ -3,16 +3,15 @@ use criterion::{criterion_group, criterion_main, Criterion}; use prio::benchmarked::{benchmarked_iterative_fft, benchmarked_recursive_fft}; -use prio::finite_field::{Field, FieldElement}; +use prio::field::{Field126, FieldElement}; pub fn fft(c: &mut Criterion) { let test_sizes = [16, 256, 1024, 4096]; for size in test_sizes.iter() { - let mut rng = rand::thread_rng(); - let mut inp = vec![Field::zero(); *size]; - let mut outp = vec![Field::zero(); *size]; + let mut inp = vec![Field126::zero(); *size]; + let mut outp = vec![Field126::zero(); *size]; for i in 0..*size { - inp[i] = Field::rand(&mut rng); + inp[i] = Field126::rand(); } c.bench_function(&format!("iterative/{}", *size), |b| { diff --git a/examples/sum.rs b/examples/sum.rs index 063cd09d4..d48230068 100644 --- a/examples/sum.rs +++ b/examples/sum.rs @@ -2,7 +2,7 @@ use prio::client::*; use prio::encrypt::*; -use prio::finite_field::*; +use prio::field::*; use prio::server::*; fn main() { @@ -30,20 +30,20 @@ fn main() { let data1 = data1_u32 .iter() - .map(|x| Field::from(*x)) - .collect::>(); + .map(|x| Field32::from(*x)) + .collect::>(); let data2_u32 = [0, 0, 1, 0, 0, 0, 0, 0]; println!("Client 2 Input: {:?}", data2_u32); let data2 = data2_u32 .iter() - .map(|x| Field::from(*x)) - .collect::>(); + .map(|x| Field32::from(*x)) + .collect::>(); let (share1_1, share1_2) = client1.encode_simple(&data1).unwrap(); let (share2_1, share2_2) = client2.encode_simple(&data2).unwrap(); - let eval_at = Field::from(12313); + let eval_at = Field32::from(12313); let mut server1 = Server::new(dim, true, priv_key1.clone()); let mut server2 = Server::new(dim, false, priv_key2.clone()); diff --git a/src/benchmarked.rs b/src/benchmarked.rs index ffd75df4a..b831e3579 100644 --- a/src/benchmarked.rs +++ b/src/benchmarked.rs @@ -4,7 +4,7 @@ //! benchmark, but which we don't want to expose in the public API. use crate::fft::discrete_fourier_transform; -use crate::finite_field::{Field, FieldElement}; +use crate::field::FieldElement; use crate::polynomial::{poly_fft, PolyAuxMemory}; /// Sets `outp` to the Discrete Fourier Transform (DFT) using an iterative FFT algorithm. @@ -13,7 +13,7 @@ pub fn benchmarked_iterative_fft(outp: &mut [F], inp: &[F]) { } /// Sets `outp` to the Discrete Fourier Transform (DFT) using a recursive FFT algorithm. -pub fn benchmarked_recursive_fft(outp: &mut [Field], inp: &[Field]) { +pub fn benchmarked_recursive_fft(outp: &mut [F], inp: &[F]) { let mut mem = PolyAuxMemory::new(inp.len() / 2); poly_fft( outp, diff --git a/src/client.rs b/src/client.rs index d7209b2ef..b69b2e6e7 100644 --- a/src/client.rs +++ b/src/client.rs @@ -4,36 +4,61 @@ //! Prio client use crate::encrypt::*; -use crate::finite_field::*; +use crate::field::FieldElement; use crate::polynomial::*; use crate::util::*; +use std::convert::TryFrom; + /// The main object that can be used to create Prio shares /// /// Client is used to create Prio shares. #[derive(Debug)] -pub struct Client { +pub struct Client { dimension: usize, - points_f: Vec, - points_g: Vec, - evals_f: Vec, - evals_g: Vec, - poly_mem: PolyAuxMemory, + points_f: Vec, + points_g: Vec, + evals_f: Vec, + evals_g: Vec, + poly_mem: PolyAuxMemory, public_key1: PublicKey, public_key2: PublicKey, } -impl Client { +/// Errors that might be emitted by the client. +#[derive(Debug, thiserror::Error)] +pub enum ClientError { + /// Thes error is output by `Client::new()` if the length of the proof would exceed the + /// number of roots of unity that can be generated in the field. + #[error("input size exceeds field capacity")] + InputSizeExceedsFieldCapacity, + /// Thes error is output by `Client::new()` if the length of the proof would exceed the + /// ssytem's addressible memory. + #[error("input size exceeds field capacity")] + InputSizeExceedsMemoryCapacity, + /// Encryption/decryption error + #[error("encryption/decryption error")] + Encrypt(#[from] EncryptError), +} + +impl Client { /// Construct a new Prio client - pub fn new(dimension: usize, public_key1: PublicKey, public_key2: PublicKey) -> Option { + pub fn new( + dimension: usize, + public_key1: PublicKey, + public_key2: PublicKey, + ) -> Result { let n = (dimension + 1).next_power_of_two(); - if 2 * n > Field::generator_order() as usize { - // too many elements for this field, not enough roots of unity - return None; + if let Ok(size) = F::Integer::try_from(2 * n) { + if size > F::generator_order() { + return Err(ClientError::InputSizeExceedsFieldCapacity); + } + } else { + return Err(ClientError::InputSizeExceedsMemoryCapacity); } - Some(Client { + Ok(Client { dimension, points_f: vector_with_length(n), points_g: vector_with_length(n), @@ -46,20 +71,20 @@ impl Client { } /// Construct a pair of encrypted shares based on the input data. - pub fn encode_simple(&mut self, data: &[Field]) -> Result<(Vec, Vec), EncryptError> { - let copy_data = |share_data: &mut [Field]| { + pub fn encode_simple(&mut self, data: &[F]) -> Result<(Vec, Vec), ClientError> { + let copy_data = |share_data: &mut [F]| { share_data[..].clone_from_slice(data); }; - self.encode_with(copy_data) + Ok(self.encode_with(copy_data)?) } /// Construct a pair of encrypted shares using a initilization function. /// /// This might be slightly more efficient on large vectors, because one can /// avoid copying the input data. - pub fn encode_with(&mut self, init_function: F) -> Result<(Vec, Vec), EncryptError> + pub fn encode_with(&mut self, init_function: G) -> Result<(Vec, Vec), EncryptError> where - F: FnOnce(&mut [Field]), + G: FnOnce(&mut [F]), { let mut proof = vector_with_length(proof_length(self.dimension)); // unpack one long vector to different subparts @@ -90,21 +115,21 @@ impl Client { /// Convenience function if one does not want to reuse /// [`Client`](struct.Client.html). -pub fn encode_simple( - data: &[Field], +pub fn encode_simple( + data: &[F], public_key1: PublicKey, public_key2: PublicKey, -) -> Option<(Vec, Vec)> { +) -> Result<(Vec, Vec), ClientError> { let dimension = data.len(); let mut client_memory = Client::new(dimension, public_key1, public_key2)?; - client_memory.encode_simple(data).ok() + client_memory.encode_simple(data) } -fn interpolate_and_evaluate_at_2n( +fn interpolate_and_evaluate_at_2n( n: usize, - points_in: &[Field], - evals_out: &mut [Field], - mem: &mut PolyAuxMemory, + points_in: &[F], + evals_out: &mut [F], + mem: &mut PolyAuxMemory, ) { // interpolate through roots of unity poly_fft( @@ -130,20 +155,20 @@ fn interpolate_and_evaluate_at_2n( /// /// Based on Theorem 2.3.3 from Henry Corrigan-Gibbs' dissertation /// This constructs the output \pi by doing the necessesary calculations -fn construct_proof( - data: &[Field], +fn construct_proof( + data: &[F], dimension: usize, - f0: &mut Field, - g0: &mut Field, - h0: &mut Field, - points_h_packed: &mut [Field], - mem: &mut Client, + f0: &mut F, + g0: &mut F, + h0: &mut F, + points_h_packed: &mut [F], + mem: &mut Client, ) { let n = (dimension + 1).next_power_of_two(); // set zero terms to random - *f0 = Field::from(rand::random::()); - *g0 = Field::from(rand::random::()); + *f0 = F::rand(); + *g0 = F::rand(); mem.points_f[0] = *f0; mem.points_g[0] = *g0; @@ -154,7 +179,7 @@ fn construct_proof( // set g_i = f_i - 1 for i in 0..dimension { mem.points_f[i + 1] = data[i]; - mem.points_g[i + 1] = data[i] - 1.into(); + mem.points_g[i + 1] = data[i] - F::one(); } // interpolate and evaluate at roots of unity @@ -174,6 +199,8 @@ fn construct_proof( #[test] fn test_encode() { + use crate::field::Field32; + let pub_key1 = PublicKey::from_base64( "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgNt9HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVQ=", ) @@ -186,8 +213,8 @@ fn test_encode() { let data_u32 = [0u32, 1, 0, 1, 1, 0, 0, 0, 1]; let data = data_u32 .iter() - .map(|x| Field::from(*x)) - .collect::>(); + .map(|x| Field32::from(*x)) + .collect::>(); let encoded_shares = encode_simple(&data, pub_key1, pub_key2); - assert_eq!(encoded_shares.is_some(), true); + assert_eq!(encoded_shares.is_ok(), true); } diff --git a/src/fft.rs b/src/fft.rs index 370ce2fff..763b68fe7 100644 --- a/src/fft.rs +++ b/src/fft.rs @@ -3,7 +3,7 @@ //! This module implements an iterative FFT algorithm for computing the (inverse) Discrete Fourier //! Transform (DFT) over a slice of field elements. -use crate::finite_field::FieldElement; +use crate::field::FieldElement; use crate::fp::{log2, MAX_ROOTS}; use std::convert::TryFrom; @@ -48,7 +48,7 @@ pub fn discrete_fourier_transform( let mut w: F; for l in 1..d + 1 { - w = F::root(0).unwrap(); // one + w = F::one(); let r = F::root(l).unwrap(); let y = 1 << (l - 1); for i in 0..y { @@ -100,11 +100,10 @@ fn bitrev(d: usize, x: usize) -> usize { #[cfg(test)] mod tests { use super::*; - use crate::finite_field::{Field, Field126, Field64, Field80}; + use crate::field::{Field126, Field32, Field64, Field80}; use crate::polynomial::{poly_fft, PolyAuxMemory}; fn discrete_fourier_transform_then_inv_test() -> Result<(), FftError> { - let mut rng = rand::thread_rng(); let test_sizes = [1, 2, 4, 8, 16, 256, 1024, 2048]; for size in test_sizes.iter() { @@ -112,7 +111,7 @@ mod tests { let mut tmp = vec![F::zero(); *size]; let mut got = vec![F::zero(); *size]; for i in 0..*size { - want[i] = F::rand(&mut rng); + want[i] = F::rand(); } discrete_fourier_transform(&mut tmp, &want)?; @@ -125,7 +124,7 @@ mod tests { #[test] fn test_field32() { - discrete_fourier_transform_then_inv_test::().expect("unexpected error"); + discrete_fourier_transform_then_inv_test::().expect("unexpected error"); } #[test] @@ -146,17 +145,16 @@ mod tests { #[test] fn test_recursive_fft() { let size = 128; - let mut rng = rand::thread_rng(); let mut mem = PolyAuxMemory::new(size / 2); - let mut inp = vec![Field::zero(); size]; - let mut want = vec![Field::zero(); size]; - let mut got = vec![Field::zero(); size]; + let mut inp = vec![Field32::zero(); size]; + let mut want = vec![Field32::zero(); size]; + let mut got = vec![Field32::zero(); size]; for i in 0..size { - inp[i] = Field::rand(&mut rng); + inp[i] = Field32::rand(); } - discrete_fourier_transform::(&mut want, &inp).expect("unexpected error"); + discrete_fourier_transform::(&mut want, &inp).expect("unexpected error"); poly_fft( &mut got, diff --git a/src/finite_field.rs b/src/field.rs similarity index 55% rename from src/finite_field.rs rename to src/field.rs index 676b8e1ea..18cb6c225 100644 --- a/src/finite_field.rs +++ b/src/field.rs @@ -8,17 +8,21 @@ use std::{ cmp::min, convert::TryFrom, fmt::{Debug, Display, Formatter}, - ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}, + ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Shr, Sub, SubAssign}, }; -use rand::Rng; - /// Possible errors from finite field operations. #[derive(Debug, thiserror::Error)] -pub enum FiniteFieldError { +pub enum FieldError { /// Input sizes do not match #[error("input sizes do not match")] InputSizeMismatch, + /// Returned by `FieldElement::read_from()` if the input buffer is too short. + #[error("short read from byte slice")] + FromBytesShortRead, + /// Returned by `FieldElement::read_from()` if the input is larger than the modulus. + #[error("read from byte slice exceeds modulus")] + FromBytesModulusOverflow, } /// Objects with this trait represent an element of `GF(p)` for some prime `p`. @@ -40,17 +44,23 @@ pub trait FieldElement: + Display + From<::Integer> { + /// Size of each field element in bytes. + const BYTES: usize; + /// The error returned if converting `usize` to an `Int` fails. type IntegerTryFromError: std::fmt::Debug; /// The integer representation of the field element. type Integer: Copy + Debug + + PartialOrd + + Div::Integer> + + Shr::Integer> + Sub::Integer> + TryFrom; /// Modular exponentation, i.e., `self^exp (mod p)`. - fn pow(&self, exp: Self) -> Self; // TODO(cjpatton) exp should have type Self::Integer + fn pow(&self, exp: Self::Integer) -> Self; /// Modular inversion, i.e., `self^-1 (mod p)`. If `self` is 0, then the output is undefined. fn inv(&self) -> Self; @@ -58,6 +68,16 @@ pub trait FieldElement: /// Returns the prime modulus `p`. fn modulus() -> Self::Integer; + /// Writes the field element to the end of input buffer. + /// + /// TODO(acmiyaguchi) Replace this with an implementation of the corresponding serde trait + fn append_to(&self, bytes: &mut Vec); + + /// Reads the next field element from the buffer. + /// + /// TODO(acmiyaguchi) Replace this with an implementation of the corresponding serde trait + fn read_from(bytes: &[u8]) -> Result; + /// Returns the size of the multiplicative subgroup generated by `generator()`. fn generator_order() -> Self::Integer; @@ -69,16 +89,19 @@ pub trait FieldElement: fn root(l: usize) -> Option; /// Returns a random field element distributed uniformly over all field elements. - fn rand(rng: &mut R) -> Self; + fn rand() -> Self; /// Returns the additive identity. fn zero() -> Self; + + /// Returns the multiplicative identity. + fn one() -> Self; } macro_rules! make_field { ( $(#[$meta:meta])* - $elem:ident, $int:ident, $fp:ident + $elem:ident, $int:ident, $fp:ident, $bytes:literal ) => { $(#[$meta])* #[derive(Clone, Copy, Debug, PartialOrd, Ord, Hash, Default)] @@ -211,11 +234,12 @@ macro_rules! make_field { } impl FieldElement for $elem { + const BYTES: usize = $bytes; type Integer = $int; type IntegerTryFromError = >::Error; - fn pow(&self, exp: Self) -> Self { - Self($fp.pow(self.0, $fp.from_elem(exp.0))) + fn pow(&self, exp: Self::Integer) -> Self { + Self($fp.pow(self.0, u128::try_from(exp).unwrap())) } fn inv(&self) -> Self { @@ -226,6 +250,31 @@ macro_rules! make_field { $fp.p as $int } + fn append_to(&self, bytes: &mut Vec) { + let int = $fp.from_elem(self.0); + let mut slice = [0; Self::BYTES]; + for i in 0..Self::BYTES { + slice[i] = ((int >> (i << 3)) & 0xff) as u8; + } + bytes.extend_from_slice(&slice); + } + + fn read_from(bytes: &[u8]) -> Result { + if Self::BYTES > bytes.len() { + return Err(FieldError::FromBytesShortRead); + } + + let mut int = 0; + for i in 0..Self::BYTES { + int |= (bytes[i] as u128) << (i << 3); + } + + if int >= $fp.p { + return Err(FieldError::FromBytesModulusOverflow); + } + Ok(Self($fp.elem(int))) + } + fn generator() -> Self { Self($fp.g) } @@ -242,122 +291,65 @@ macro_rules! make_field { } } - fn rand(rng: &mut R) -> Self { - Self($fp.rand_elem(rng)) + fn rand() -> Self { + let mut rng = rand::thread_rng(); + Self($fp.rand_elem(&mut rng)) } fn zero() -> Self { Self(0) } + + fn one() -> Self { + Self($fp.roots[0]) + } } }; } -// TODO(cjpatton) Rename this to Field32. make_field!( /// `GF(4293918721)`, a 32-bit field. The generator has order `2^20`. - Field, + Field32, u32, - FP32 + FP32, + 4 ); make_field!( /// `GF(15564440312192434177)`, a 64-bit field. The generator has order `2^59`. Field64, u64, - FP64 + FP64, + 8 ); make_field!( /// `GF(779190469673491460259841)`, an 80-bit field. The generator has order `2^72`. Field80, u128, - FP80 + FP80, + 10 ); make_field!( /// `GF(74769074762901517850839147140769382401)`, a 126-bit field. The generator has order `2^118`. Field126, u128, - FP126 + FP126, + 16 ); -#[test] -fn test_arithmetic() { - // TODO(cjpatton) Add tests for the other fields. - use rand::prelude::*; - - let modulus = Field::modulus(); - - // add - assert_eq!(Field::from(modulus - 1) + Field::from(1), 0); - assert_eq!(Field::from(modulus - 2) + Field::from(2), 0); - assert_eq!(Field::from(modulus - 2) + Field::from(3), 1); - assert_eq!(Field::from(1) + Field::from(1), 2); - assert_eq!(Field::from(2) + Field::from(modulus), 2); - assert_eq!(Field::from(3) + Field::from(modulus - 1), 2); - - // sub - assert_eq!(Field::from(0) - Field::from(1), modulus - 1); - assert_eq!(Field::from(1) - Field::from(2), modulus - 1); - assert_eq!(Field::from(15) - Field::from(3), 12); - assert_eq!(Field::from(1) - Field::from(1), 0); - assert_eq!(Field::from(2) - Field::from(modulus), 2); - assert_eq!(Field::from(3) - Field::from(modulus - 1), 4); - - // add + sub - for _ in 0..100 { - let f = Field::from(random::()); - let g = Field::from(random::()); - assert_eq!(f + g - f - g, 0); - assert_eq!(f + g - g, f); - assert_eq!(f + g - f, g); - } - - // mul - assert_eq!(Field::from(35) * Field::from(123), 4305); - assert_eq!(Field::from(1) * Field::from(modulus), 0); - assert_eq!(Field::from(0) * Field::from(123), 0); - assert_eq!(Field::from(123) * Field::from(0), 0); - assert_eq!(Field::from(123123123) * Field::from(123123123), 1237630077); - - // div - assert_eq!(Field::from(35) / Field::from(5), 7); - assert_eq!(Field::from(35) / Field::from(0), 0); - assert_eq!(Field::from(0) / Field::from(5), 0); - assert_eq!(Field::from(1237630077) / Field::from(123123123), 123123123); - - assert_eq!(Field::from(0).inv(), 0); - - // mul and div - let uniform = rand::distributions::Uniform::from(1..modulus); - let mut rng = thread_rng(); - for _ in 0..100 { - // non-zero element - let f = Field::from(uniform.sample(&mut rng)); - assert_eq!(f * f.inv(), 1); - assert_eq!(f.inv() * f, 1); - } - - // pow - assert_eq!(Field::from(2).pow(3.into()), 8); - assert_eq!(Field::from(3).pow(9.into()), 19683); - assert_eq!(Field::from(51).pow(27.into()), 3760729523); - assert_eq!(Field::from(432).pow(0.into()), 1); - assert_eq!(Field(0).pow(123.into()), 0); -} - /// Merge two vectors of fields by summing other_vector into accumulator. /// /// # Errors /// /// Fails if the two vectors do not have the same length. -pub fn merge_vector( - accumulator: &mut [Field], - other_vector: &[Field], -) -> Result<(), FiniteFieldError> { +pub fn merge_vector( + accumulator: &mut [F], + other_vector: &[F], +) -> Result<(), FieldError> { if accumulator.len() != other_vector.len() { - return Err(FiniteFieldError::InputSizeMismatch); + return Err(FieldError::InputSizeMismatch); } for (a, o) in accumulator.iter_mut().zip(other_vector.iter()) { *a += *o; @@ -369,23 +361,126 @@ pub fn merge_vector( #[cfg(test)] mod tests { use super::*; + use crate::fp::MAX_ROOTS; use crate::util::vector_with_length; use assert_matches::assert_matches; #[test] fn test_accumulate() { let mut lhs = vector_with_length(10); - lhs.iter_mut().for_each(|f| *f = Field(1)); + lhs.iter_mut().for_each(|f| *f = Field32(1)); let mut rhs = vector_with_length(10); - rhs.iter_mut().for_each(|f| *f = Field(2)); + rhs.iter_mut().for_each(|f| *f = Field32(2)); merge_vector(&mut lhs, &rhs).unwrap(); - lhs.iter().for_each(|f| assert_eq!(*f, Field(3))); - rhs.iter().for_each(|f| assert_eq!(*f, Field(2))); + lhs.iter().for_each(|f| assert_eq!(*f, Field32(3))); + rhs.iter().for_each(|f| assert_eq!(*f, Field32(2))); let wrong_len = vector_with_length(9); let result = merge_vector(&mut lhs, &wrong_len); - assert_matches!(result, Err(FiniteFieldError::InputSizeMismatch)); + assert_matches!(result, Err(FieldError::InputSizeMismatch)); + } + + fn field_element_test() { + let int_modulus = F::modulus(); + let int_one = F::Integer::try_from(1).unwrap(); + let zero = F::zero(); + let one = F::one(); + let two = F::from(F::Integer::try_from(2).unwrap()); + let four = F::from(F::Integer::try_from(4).unwrap()); + + // add + assert_eq!(F::from(int_modulus - int_one) + one, zero); + assert_eq!(one + one, two); + assert_eq!(two + F::from(int_modulus), two); + + // sub + assert_eq!(zero - one, F::from(int_modulus - int_one)); + assert_eq!(one - one, zero); + assert_eq!(two - F::from(int_modulus), two); + assert_eq!(one - F::from(int_modulus - int_one), two); + + // add + sub + for _ in 0..100 { + let f = F::rand(); + let g = F::rand(); + assert_eq!(f + g - f - g, zero); + assert_eq!(f + g - g, f); + assert_eq!(f + g - f, g); + } + + // mul + assert_eq!(two * two, four); + assert_eq!(two * one, two); + assert_eq!(two * zero, zero); + assert_eq!(one * F::from(int_modulus), zero); + + // div + assert_eq!(four / two, two); + assert_eq!(two / two, one); + assert_eq!(zero / two, zero); + assert_eq!(two / zero, zero); // Undefined behavior + assert_eq!(zero.inv(), zero); // Undefined behavior + + // mul + div + for _ in 0..100 { + let f = F::rand(); + if f == zero { + println!("skipped zero"); + continue; + } + assert_eq!(f * f.inv(), one); + assert_eq!(f.inv() * f, one); + } + + // pow + assert_eq!(two.pow(F::Integer::try_from(0).unwrap()), one); + assert_eq!(two.pow(int_one), two); + assert_eq!(two.pow(F::Integer::try_from(2).unwrap()), four); + assert_eq!(two.pow(int_modulus - int_one), one); + assert_eq!(two.pow(int_modulus), two); + + // roots + let mut int_order = F::generator_order(); + for l in 0..MAX_ROOTS + 1 { + assert_eq!( + F::generator().pow(int_order), + F::root(l).unwrap(), + "failure for F::root({})", + l + ); + int_order = int_order >> int_one; + } + + // serialization + let test_inputs = vec![zero, one, F::rand(), F::from(int_modulus - int_one)]; + for want in test_inputs.iter() { + let mut bytes = vec![]; + want.append_to(&mut bytes); + let got = F::read_from(&bytes).unwrap(); + assert_eq!(got, *want); + assert_eq!(bytes.len(), F::BYTES); + } + } + + #[test] + fn test_field32() { + field_element_test::(); + } + + #[test] + fn test_field64() { + field_element_test::(); + } + + #[test] + fn test_field80() { + field_element_test::(); + } + + #[test] + fn test_field126() { + field_element_test::(); } } diff --git a/src/fp.rs b/src/fp.rs index 29d435fa4..272e05550 100644 --- a/src/fp.rs +++ b/src/fp.rs @@ -204,12 +204,6 @@ impl FieldParameters { modp(self.mul(x, 1), self.p) } - /// Returns the number of bytes required to encode field elements. - #[cfg(test)] // This code is only used by tests for now. - pub fn size(&self) -> usize { - (16 - (self.p.leading_zeros() / 8)) as usize - } - #[cfg(test)] pub fn check(&self, p: u128, g: u128, order: u128) { use modinverse::modinverse; @@ -404,7 +398,6 @@ mod tests { expected_p: u128, // Expected fp.p expected_g: u128, // Expected fp.from_elem(fp.g) expected_order: u128, // Expect fp.from_elem(fp.pow(fp.g, expected_order)) == 1 - expected_size: usize, // expected fp.size() } #[test] @@ -415,28 +408,24 @@ mod tests { expected_p: 4293918721, expected_g: 3925978153, expected_order: 1 << 20, - expected_size: 4, }, TestFieldParametersData { fp: FP64, expected_p: 15564440312192434177, expected_g: 7450580596923828125, expected_order: 1 << 59, - expected_size: 8, }, TestFieldParametersData { fp: FP80, expected_p: 779190469673491460259841, expected_g: 41782115852031095118226, expected_order: 1 << 72, - expected_size: 10, }, TestFieldParametersData { fp: FP126, expected_p: 74769074762901517850839147140769382401, expected_g: 43421413544015439978138831414974882540, expected_order: 1 << 118, - expected_size: 16, }, ]; @@ -444,23 +433,15 @@ mod tests { // Check that the field parameters have been constructed properly. t.fp.check(t.expected_p, t.expected_g, t.expected_order); - // Check that the field element size is computed correctly. - assert_eq!( - t.fp.size(), - t.expected_size, - "error for GF({})", - t.expected_p - ); - // Check that the generator has the correct order. assert_eq!(t.fp.from_elem(t.fp.pow(t.fp.g, t.expected_order)), 1); // Test arithmetic using the field parameters. - test_arithmetic(&t.fp); + arithmetic_test(&t.fp); } } - fn test_arithmetic(fp: &FieldParameters) { + fn arithmetic_test(fp: &FieldParameters) { let mut rng = rand::thread_rng(); let big_p = &fp.p.to_bigint().unwrap(); diff --git a/src/lib.rs b/src/lib.rs index 965c8ae50..018c1a99f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,7 +13,7 @@ pub mod benchmarked; pub mod client; pub mod encrypt; mod fft; -pub mod finite_field; +pub mod field; mod fp; mod polynomial; mod prng; diff --git a/src/polynomial.rs b/src/polynomial.rs index 6e4b8e7e4..00b0666f4 100644 --- a/src/polynomial.rs +++ b/src/polynomial.rs @@ -3,59 +3,60 @@ //! Functions for polynomial interpolation and evaluation -use crate::finite_field::*; -use crate::util::*; +use crate::field::FieldElement; + +use std::convert::TryFrom; /// Temporary memory used for FFT #[derive(Debug)] -pub struct PolyFFTTempMemory { - fft_tmp: Vec, - fft_y_sub: Vec, - fft_roots_sub: Vec, +pub struct PolyFFTTempMemory { + fft_tmp: Vec, + fft_y_sub: Vec, + fft_roots_sub: Vec, } -impl PolyFFTTempMemory { +impl PolyFFTTempMemory { fn new(length: usize) -> Self { PolyFFTTempMemory { - fft_tmp: vector_with_length(length), - fft_y_sub: vector_with_length(length), - fft_roots_sub: vector_with_length(length), + fft_tmp: vec![F::zero(); length], + fft_y_sub: vec![F::zero(); length], + fft_roots_sub: vec![F::zero(); length], } } } /// Auxiliary memory for polynomial interpolation and evaluation #[derive(Debug)] -pub struct PolyAuxMemory { - pub roots_2n: Vec, - pub roots_2n_inverted: Vec, - pub roots_n: Vec, - pub roots_n_inverted: Vec, - pub coeffs: Vec, - pub fft_memory: PolyFFTTempMemory, +pub struct PolyAuxMemory { + pub roots_2n: Vec, + pub roots_2n_inverted: Vec, + pub roots_n: Vec, + pub roots_n_inverted: Vec, + pub coeffs: Vec, + pub fft_memory: PolyFFTTempMemory, } -impl PolyAuxMemory { +impl PolyAuxMemory { pub fn new(n: usize) -> Self { PolyAuxMemory { roots_2n: fft_get_roots(2 * n, false), roots_2n_inverted: fft_get_roots(2 * n, true), roots_n: fft_get_roots(n, false), roots_n_inverted: fft_get_roots(n, true), - coeffs: vector_with_length(2 * n), + coeffs: vec![F::zero(); 2 * n], fft_memory: PolyFFTTempMemory::new(2 * n), } } } -fn fft_recurse( - out: &mut [Field], +fn fft_recurse( + out: &mut [F], n: usize, - roots: &[Field], - ys: &[Field], - tmp: &mut [Field], - y_sub: &mut [Field], - roots_sub: &mut [Field], + roots: &[F], + ys: &[F], + tmp: &mut [F], + y_sub: &mut [F], + roots_sub: &mut [F], ) { if n == 1 { out[0] = ys[0]; @@ -106,17 +107,17 @@ fn fft_recurse( } /// Calculate `count` number of roots of unity of order `count` -fn fft_get_roots(count: usize, invert: bool) -> Vec { - let mut roots = vec![Field::from(0); count]; - let mut gen = Field::generator(); +fn fft_get_roots(count: usize, invert: bool) -> Vec { + let mut roots = vec![F::zero(); count]; + let mut gen = F::generator(); if invert { gen = gen.inv(); } - roots[0] = 1.into(); - let step_size: u32 = (Field::generator_order() as u32) / (count as u32); + roots[0] = F::one(); + let step_size = F::generator_order() / F::Integer::try_from(count).unwrap(); // generator for subgroup of order count - gen = gen.pow(step_size.into()); + gen = gen.pow(step_size); roots[1] = gen; @@ -127,13 +128,13 @@ fn fft_get_roots(count: usize, invert: bool) -> Vec { roots } -fn fft_interpolate_raw( - out: &mut [Field], - ys: &[Field], +fn fft_interpolate_raw( + out: &mut [F], + ys: &[F], n_points: usize, - roots: &[Field], + roots: &[F], invert: bool, - mem: &mut PolyFFTTempMemory, + mem: &mut PolyFFTTempMemory, ) { fft_recurse( out, @@ -145,25 +146,25 @@ fn fft_interpolate_raw( &mut mem.fft_roots_sub, ); if invert { - let n_inverse = Field::from(n_points as u32).inv(); + let n_inverse = F::from(F::Integer::try_from(n_points).unwrap()).inv(); for i in 0..n_points { out[i] *= n_inverse; } } } -pub fn poly_fft( - points_out: &mut [Field], - points_in: &[Field], - scaled_roots: &[Field], +pub fn poly_fft( + points_out: &mut [F], + points_in: &[F], + scaled_roots: &[F], n_points: usize, invert: bool, - mem: &mut PolyFFTTempMemory, + mem: &mut PolyFFTTempMemory, ) { fft_interpolate_raw(points_out, points_in, n_points, scaled_roots, invert, mem) } -pub fn poly_horner_eval(poly: &[Field], eval_at: Field, len: usize) -> Field { +pub fn poly_horner_eval(poly: &[F], eval_at: F, len: usize) -> F { let mut result = poly[len - 1]; for i in (0..(len - 1)).rev() { @@ -174,33 +175,37 @@ pub fn poly_horner_eval(poly: &[Field], eval_at: Field, len: usize) -> Field { result } -pub fn poly_interpret_eval( - points: &[Field], - roots: &[Field], - eval_at: Field, - tmp_coeffs: &mut [Field], - fft_memory: &mut PolyFFTTempMemory, -) -> Field { +pub fn poly_interpret_eval( + points: &[F], + roots: &[F], + eval_at: F, + tmp_coeffs: &mut [F], + fft_memory: &mut PolyFFTTempMemory, +) -> F { poly_fft(tmp_coeffs, points, roots, points.len(), true, fft_memory); poly_horner_eval(&tmp_coeffs, eval_at, points.len()) } #[test] fn test_roots() { + use crate::field::Field32; + let count = 128; - let roots = fft_get_roots(count, false); - let roots_inv = fft_get_roots(count, true); + let roots = fft_get_roots::(count, false); + let roots_inv = fft_get_roots::(count, true); for i in 0..count { assert_eq!(roots[i] * roots_inv[i], 1); - assert_eq!(roots[i].pow(Field::from(count as u32)), 1); - assert_eq!(roots_inv[i].pow(Field::from(count as u32)), 1); + assert_eq!(roots[i].pow(u32::try_from(count).unwrap()), 1); + assert_eq!(roots_inv[i].pow(u32::try_from(count).unwrap()), 1); } } #[test] fn test_horner_eval() { - let mut poly = vec![Field::from(0); 4]; + use crate::field::Field32; + + let mut poly = vec![Field32::from(0); 4]; poly[0] = 2.into(); poly[1] = 1.into(); poly[2] = 5.into(); @@ -213,18 +218,21 @@ fn test_horner_eval() { #[test] fn test_fft() { - let count = 128; - let mut mem = PolyAuxMemory::new(count / 2); + use crate::field::Field32; use rand::prelude::*; + use std::convert::TryFrom; + + let count = 128; + let mut mem = PolyAuxMemory::new(count / 2); - let mut poly = vec![Field::from(0); count]; - let mut points2 = vec![Field::from(0); count]; + let mut poly = vec![Field32::from(0); count]; + let mut points2 = vec![Field32::from(0); count]; let points = (0..count) .into_iter() - .map(|_| Field::from(random::())) - .collect::>(); + .map(|_| Field32::from(random::())) + .collect::>(); // From points to coeffs and back poly_fft( @@ -256,9 +264,9 @@ fn test_fft() { &mut mem.fft_memory, ); for i in 0..count { - let mut should_be = Field::from(0); + let mut should_be = Field32::from(0); for j in 0..count { - should_be = mem.roots_2n[i].pow(Field::from(j as u32)) * points[j] + should_be; + should_be = mem.roots_2n[i].pow(u32::try_from(j).unwrap()) * points[j] + should_be; } assert_eq!(should_be, poly[i]); } diff --git a/src/prng.rs b/src/prng.rs index de78b0e26..b6bb898c1 100644 --- a/src/prng.rs +++ b/src/prng.rs @@ -1,19 +1,18 @@ // Copyright (c) 2020 Apple Inc. // SPDX-License-Identifier: MPL-2.0 -use super::finite_field::{Field, FieldElement}; +use super::field::{FieldElement, FieldError}; use aes_ctr::stream_cipher::generic_array::GenericArray; use aes_ctr::stream_cipher::NewStreamCipher; use aes_ctr::stream_cipher::SyncStreamCipher; use aes_ctr::Aes128Ctr; use rand::RngCore; -use std::convert::TryInto; const BLOCK_SIZE: usize = 16; const MAXIMUM_BUFFER_SIZE_IN_BLOCKS: usize = 4096; pub const SEED_LENGTH: usize = 2 * BLOCK_SIZE; -pub fn secret_share(share1: &mut [Field]) -> Vec { +pub fn secret_share(share1: &mut [F]) -> Vec { // get prng array let (data, seed) = random_field_and_seed(share1.len()); @@ -25,26 +24,26 @@ pub fn secret_share(share1: &mut [Field]) -> Vec { seed } -pub fn extract_share_from_seed(length: usize, seed: &[u8]) -> Vec { +pub fn extract_share_from_seed(length: usize, seed: &[u8]) -> Vec { random_field_from_seed(seed, length) } -fn random_field_and_seed(length: usize) -> (Vec, Vec) { +fn random_field_and_seed(length: usize) -> (Vec, Vec) { let mut seed = vec![0u8; SEED_LENGTH]; rand::thread_rng().fill_bytes(&mut seed); let data = random_field_from_seed(&seed, length); (data, seed) } -fn random_field_from_seed(seed: &[u8], length: usize) -> Vec { +fn random_field_from_seed(seed: &[u8], length: usize) -> Vec { let key = GenericArray::from_slice(&seed[..BLOCK_SIZE]); let nonce = GenericArray::from_slice(&seed[BLOCK_SIZE..]); let mut cipher = Aes128Ctr::new(&key, &nonce); - let mut output = super::util::vector_with_length(length); + let mut output = vec![F::zero(); length]; let mut output_written = 0; - let length_in_blocks = length * std::mem::size_of::() / BLOCK_SIZE; + let length_in_blocks = length * F::BYTES / BLOCK_SIZE; // add one more block to account for rejection and roundoff errors let buffer_len_in_blocks = std::cmp::min(MAXIMUM_BUFFER_SIZE_IN_BLOCKS, length_in_blocks + 1); // Note: buffer_len must be a multiple of BLOCK_SIZE, so that different buffer @@ -60,20 +59,17 @@ fn random_field_from_seed(seed: &[u8], length: usize) -> Vec { cipher.apply_keystream(&mut buffer); - // rejection sampling - // - // TODO(cjpatton): Once implemented, use FieldParameters::size() and - // FieldElement::form_bytes() to implement this loop. - let field_size = std::mem::size_of::<::Integer>(); - for chunk in buffer.chunks_exact(field_size) { - let integer = - ::Integer::from_le_bytes(chunk.try_into().unwrap()); - if integer < Field::modulus() { - output[output_written] = Field::from(integer); - output_written += 1; - if output_written == length { - break; + for chunk in buffer.chunks_exact(F::BYTES) { + match F::read_from(chunk) { + Ok(x) => { + output[output_written] = x; + output_written += 1; + if output_written == length { + break; + } } + Err(FieldError::FromBytesModulusOverflow) => (), // reject this sample + Err(err) => panic!("unexpected error: {}", err), } } } @@ -86,10 +82,11 @@ fn random_field_from_seed(seed: &[u8], length: usize) -> Vec { #[cfg(test)] mod tests { use super::*; + use crate::field::Field32; #[test] fn secret_sharing() { - let mut data = vec![Field::from(0); 123]; + let mut data = vec![Field32::from(0); 123]; data[3] = 23.into(); let data_clone = data.clone(); @@ -125,7 +122,7 @@ mod tests { 0xacb8b748, 0x6f5b9d49, 0x887d061b, 0x86db0c58, ]; - let share2 = extract_share_from_seed(reference.len(), &seed); + let share2 = extract_share_from_seed::(reference.len(), &seed); assert_eq!(share2, reference); } @@ -133,7 +130,7 @@ mod tests { /// takes a seed and hash as base64 encoded strings fn random_data_interop(seed_base64: &str, hash_base64: &str, len: usize) { let seed = base64::decode(seed_base64).unwrap(); - let random_data = extract_share_from_seed(len, &seed); + let random_data = extract_share_from_seed::(len, &seed); let random_bytes = crate::util::serialize(&random_data); diff --git a/src/server.rs b/src/server.rs index 18caca43e..adfbbabff 100644 --- a/src/server.rs +++ b/src/server.rs @@ -5,10 +5,10 @@ use crate::{ encrypt::{decrypt_share, EncryptError, PrivateKey}, - finite_field::{merge_vector, Field, FiniteFieldError}, + field::{merge_vector, FieldElement, FieldError}, polynomial::{poly_interpret_eval, PolyAuxMemory}, prng::extract_share_from_seed, - util::{deserialize, proof_length, unpack_proof, vector_with_length}, + util::{deserialize, proof_length, unpack_proof, vector_with_length, SerializeError}, }; /// Possible errors from server operations @@ -19,20 +19,23 @@ pub enum ServerError { Encrypt(#[from] EncryptError), /// Finite field operation error #[error("finite field operation error")] - FiniteField(#[from] FiniteFieldError), + Field(#[from] FieldError), + /// Serialization/deserialization error + #[error("serialization/deserialization error")] + Serialize(#[from] SerializeError), } /// Auxiliary memory for constructing a /// [`VerificationMessage`](struct.VerificationMessage.html) #[derive(Debug)] -pub struct ValidationMemory { - points_f: Vec, - points_g: Vec, - points_h: Vec, - poly_mem: PolyAuxMemory, +pub struct ValidationMemory { + points_f: Vec, + points_g: Vec, + points_h: Vec, + poly_mem: PolyAuxMemory, } -impl ValidationMemory { +impl ValidationMemory { /// Construct a new ValidationMemory object for validating proof shares of /// length `dimension`. pub fn new(dimension: usize) -> Self { @@ -48,22 +51,22 @@ impl ValidationMemory { /// Main workhorse of the server. #[derive(Debug)] -pub struct Server { +pub struct Server { dimension: usize, is_first_server: bool, - accumulator: Vec, - validation_mem: ValidationMemory, + accumulator: Vec, + validation_mem: ValidationMemory, private_key: PrivateKey, } -impl Server { +impl Server { /// Construct a new server instance /// /// Params: /// * `dimension`: the number of elements in the aggregation vector. /// * `is_first_server`: only one of the servers should have this true. /// * `private_key`: the private key for decrypting the share of the proof. - pub fn new(dimension: usize, is_first_server: bool, private_key: PrivateKey) -> Server { + pub fn new(dimension: usize, is_first_server: bool, private_key: PrivateKey) -> Server { Server { dimension, is_first_server, @@ -74,10 +77,10 @@ impl Server { } /// Decrypt and deserialize - fn deserialize_share(&self, encrypted_share: &[u8]) -> Result, ServerError> { + fn deserialize_share(&self, encrypted_share: &[u8]) -> Result, ServerError> { let share = decrypt_share(encrypted_share, &self.private_key)?; Ok(if self.is_first_server { - deserialize(&share) + deserialize(&share)? } else { let len = proof_length(self.dimension); extract_share_from_seed(len, &share) @@ -92,9 +95,9 @@ impl Server { /// [choose_eval_at](#method.choose_eval_at). pub fn generate_verification_message( &mut self, - eval_at: Field, + eval_at: F, share: &[u8], - ) -> Option { + ) -> Option> { let share_field = self.deserialize_share(share).ok()?; generate_verification_message( self.dimension, @@ -112,8 +115,8 @@ impl Server { pub fn aggregate( &mut self, share: &[u8], - v1: &VerificationMessage, - v2: &VerificationMessage, + v1: &VerificationMessage, + v2: &VerificationMessage, ) -> Result { let share_field = self.deserialize_share(share)?; let is_valid = is_valid_share(v1, v2); @@ -131,7 +134,7 @@ impl Server { /// /// These can be merged together using /// [`reconstruct_shares`](../util/fn.reconstruct_shares.html). - pub fn total_shares(&self) -> &[Field] { + pub fn total_shares(&self) -> &[F] { &self.accumulator } @@ -143,7 +146,7 @@ impl Server { /// /// Returns an error if `other_total_shares.len()` is not equal to this //// server's `dimension`. - pub fn merge_total_shares(&mut self, other_total_shares: &[Field]) -> Result<(), ServerError> { + pub fn merge_total_shares(&mut self, other_total_shares: &[F]) -> Result<(), ServerError> { Ok(merge_vector(&mut self.accumulator, other_total_shares)?) } @@ -151,9 +154,9 @@ impl Server { /// /// The point returned is not one of the roots used for polynomial /// evaluation. - pub fn choose_eval_at(&self) -> Field { + pub fn choose_eval_at(&self) -> F { loop { - let eval_at = Field::from(rand::random::()); + let eval_at = F::rand(); if !self.validation_mem.poly_mem.roots_2n.contains(&eval_at) { break eval_at; } @@ -162,24 +165,24 @@ impl Server { } /// Verification message for proof validation -pub struct VerificationMessage { +pub struct VerificationMessage { /// f evaluated at random point - pub f_r: Field, + pub f_r: F, /// g evaluated at random point - pub g_r: Field, + pub g_r: F, /// h evaluated at random point - pub h_r: Field, + pub h_r: F, } /// Given a proof and evaluation point, this constructs the verification /// message. -pub fn generate_verification_message( +pub fn generate_verification_message( dimension: usize, - eval_at: Field, - proof: &[Field], + eval_at: F, + proof: &[F], is_first_server: bool, - mem: &mut ValidationMemory, -) -> Option { + mem: &mut ValidationMemory, +) -> Option> { let unpacked = unpack_proof(proof, dimension)?; let proof_length = 2 * (dimension + 1).next_power_of_two(); @@ -194,7 +197,7 @@ pub fn generate_verification_message( if is_first_server { // only one server needs to subtract one for point_g - mem.points_g[i + 1] = *x - 1.into(); + mem.points_g[i + 1] = *x - F::one(); } else { mem.points_g[i + 1] = *x; } @@ -237,7 +240,10 @@ pub fn generate_verification_message( } /// Decides if the distributed proof is valid -pub fn is_valid_share(v1: &VerificationMessage, v2: &VerificationMessage) -> bool { +pub fn is_valid_share( + v1: &VerificationMessage, + v2: &VerificationMessage, +) -> bool { // reconstruct f_r, g_r, h_r let f_r = v1.f_r + v2.f_r; let g_r = v1.g_r + v2.g_r; @@ -249,6 +255,7 @@ pub fn is_valid_share(v1: &VerificationMessage, v2: &VerificationMessage) -> boo #[cfg(test)] mod tests { use super::*; + use crate::field::Field32; use crate::util; #[test] @@ -260,9 +267,9 @@ mod tests { 2567182742, 3542857140, 124017604, 4201373647, 431621210, 1618555683, 267689149, ]; - let mut proof: Vec = proof_u32.iter().map(|x| Field::from(*x)).collect(); + let mut proof: Vec = proof_u32.iter().map(|x| Field32::from(*x)).collect(); let share2 = util::tests::secret_share(&mut proof); - let eval_at = Field::from(12313); + let eval_at = Field32::from(12313); let mut validation_mem = ValidationMemory::new(dim); diff --git a/src/util.rs b/src/util.rs index 84382aa4e..9fb14f5e7 100644 --- a/src/util.rs +++ b/src/util.rs @@ -3,11 +3,18 @@ //! Utility functions for handling Prio stuff. -use crate::finite_field::{Field, FieldElement}; - -/// Convenience function for initializing fixed sized vectors of Field elements. -pub fn vector_with_length(len: usize) -> Vec { - vec![Field::from(0); len] +use crate::field::{FieldElement, FieldError}; + +/// Serialization errors +#[derive(Debug, thiserror::Error)] +pub enum SerializeError { + /// Emitted by `deserialize()` if the last chunk of input is not long enough to encode an + /// element of the field. + #[error("last chunk of bytes is incomplete")] + IncompleteChunk, + /// Finite field operation error. + #[error("finite field operation error")] + Field(#[from] FieldError), } /// Returns the number of field elements in the proof for given dimension of @@ -21,36 +28,41 @@ pub fn proof_length(dimension: usize) -> usize { dimension + 3 + (dimension + 1).next_power_of_two() } +/// Convenience function for initializing fixed sized vectors of Field elements. +pub fn vector_with_length(len: usize) -> Vec { + vec![F::zero(); len] +} + /// Unpacked proof with subcomponents -pub struct UnpackedProof<'a> { +pub struct UnpackedProof<'a, F: FieldElement> { /// Data - pub data: &'a [Field], + pub data: &'a [F], /// Zeroth coefficient of polynomial f - pub f0: &'a Field, + pub f0: &'a F, /// Zeroth coefficient of polynomial g - pub g0: &'a Field, + pub g0: &'a F, /// Zeroth coefficient of polynomial h - pub h0: &'a Field, + pub h0: &'a F, /// Non-zero points of polynomial h - pub points_h_packed: &'a [Field], + pub points_h_packed: &'a [F], } /// Unpacked proof with mutable subcomponents -pub struct UnpackedProofMut<'a> { +pub struct UnpackedProofMut<'a, F: FieldElement> { /// Data - pub data: &'a mut [Field], + pub data: &'a mut [F], /// Zeroth coefficient of polynomial f - pub f0: &'a mut Field, + pub f0: &'a mut F, /// Zeroth coefficient of polynomial g - pub g0: &'a mut Field, + pub g0: &'a mut F, /// Zeroth coefficient of polynomial h - pub h0: &'a mut Field, + pub h0: &'a mut F, /// Non-zero points of polynomial h - pub points_h_packed: &'a mut [Field], + pub points_h_packed: &'a mut [F], } /// Unpacks the proof vector into subcomponents -pub fn unpack_proof(proof: &[Field], dimension: usize) -> Option { +pub fn unpack_proof(proof: &[F], dimension: usize) -> Option> { // check the proof length if proof.len() != proof_length(dimension) { return None; @@ -73,7 +85,10 @@ pub fn unpack_proof(proof: &[Field], dimension: usize) -> Option } /// Unpacks a mutable proof vector into mutable subcomponents -pub fn unpack_proof_mut(proof: &mut [Field], dimension: usize) -> Option { +pub fn unpack_proof_mut( + proof: &mut [F], + dimension: usize, +) -> Option> { // check the share length if proof.len() != proof_length(dimension) { return None; @@ -96,46 +111,35 @@ pub fn unpack_proof_mut(proof: &mut [Field], dimension: usize) -> Option Vec { - let field_size = std::mem::size_of::<::Integer>(); - let mut vec = Vec::with_capacity(data.len() * field_size); - +pub fn serialize(data: &[F]) -> Vec { + let mut vec = Vec::::with_capacity(data.len() * F::BYTES); for elem in data.iter() { - // TODO(cjpatton) Implement FieldElement::bytes() that encodes each field element using - // FieldParameters::size() bytes. - let int = ::Integer::from(*elem); - vec.extend(int.to_le_bytes().iter()); + elem.append_to(&mut vec); } - vec } /// Get a vector of field elements from a byte slice -pub fn deserialize(data: &[u8]) -> Vec { - let field_size = std::mem::size_of::<::Integer>(); - - let mut vec = Vec::with_capacity(data.len() / field_size); - use std::convert::TryInto; - - for chunk in data.chunks_exact(field_size) { - // TODO(cjpatton) Implement FieldElement::from_bytes() that decodes field elements from a - // string with FieldParameters::size() bytes. - let integer = ::Integer::from_le_bytes(chunk.try_into().unwrap()); - vec.push(Field::from(integer)); +pub fn deserialize(data: &[u8]) -> Result, SerializeError> { + if data.len() % F::BYTES != 0 { + return Err(SerializeError::IncompleteChunk); } - - vec + let mut vec = Vec::::with_capacity(data.len() / F::BYTES); + for chunk in data.chunks_exact(F::BYTES) { + vec.push(F::read_from(chunk).or_else(|err| Err(SerializeError::Field(err)))?); + } + Ok(vec) } -/// Add two Field element arrays together elementwise. +/// Add two field element arrays together elementwise. /// /// Returns None, when array lengths are not equal. -pub fn reconstruct_shares(share1: &[Field], share2: &[Field]) -> Option> { +pub fn reconstruct_shares(share1: &[F], share2: &[F]) -> Option> { if share1.len() != share2.len() { return None; } - let mut reconstructed = vector_with_length(share1.len()); + let mut reconstructed: Vec = vector_with_length(share1.len()); for (r, (s1, s2)) in reconstructed .iter_mut() @@ -150,17 +154,18 @@ pub fn reconstruct_shares(share1: &[Field], share2: &[Field]) -> Option Vec { + pub fn secret_share(share: &mut [Field32]) -> Vec { use rand::Rng; let mut rng = rand::thread_rng(); let mut random = vec![0u32; share.len()]; - let mut share2 = vector_with_length(share.len()); + let mut share2 = vec![Field32::zero(); share.len()]; rng.fill(&mut random[..]); for (r, f) in random.iter().zip(share2.iter_mut()) { - *f = Field::from(*r); + *f = Field32::from(*r); } for (f1, f2) in share.iter_mut().zip(share2.iter()) { @@ -175,15 +180,15 @@ pub mod tests { let dim = 15; let len = proof_length(dim); - let mut share = vec![Field::from(0); len]; + let mut share = vec![Field32::from(0); len]; let unpacked = unpack_proof_mut(&mut share, dim).unwrap(); - *unpacked.f0 = Field::from(12); + *unpacked.f0 = Field32::from(12); assert_eq!(share[dim], 12); } #[test] fn secret_sharing() { - let mut share1 = vector_with_length(10); + let mut share1 = vec![Field32::zero(); 10]; share1[3] = 21.into(); share1[8] = 123.into(); @@ -197,9 +202,9 @@ pub mod tests { #[test] fn serialization() { - let field = [Field::from(1), Field::from(0x99997)]; + let field = [Field32::from(1), Field32::from(0x99997)]; let bytes = serialize(&field); - let field_deserialized = deserialize(&bytes); + let field_deserialized = deserialize::(&bytes).unwrap(); assert_eq!(field_deserialized, field); } } diff --git a/tests/accumulating.rs b/tests/accumulating.rs index a96562e16..b11296ab6 100644 --- a/tests/accumulating.rs +++ b/tests/accumulating.rs @@ -3,7 +3,7 @@ use prio::client::*; use prio::encrypt::*; -use prio::finite_field::Field; +use prio::field::Field32; use prio::server::*; #[test] @@ -26,8 +26,8 @@ fn accumulation() { let mut reference_count = vec![0u32; dim]; - let mut server1 = Server::new(dim, true, priv_key1); - let mut server2 = Server::new(dim, false, priv_key2); + let mut server1: Server = Server::new(dim, true, priv_key1); + let mut server2: Server = Server::new(dim, false, priv_key2); let mut client_mem = Client::new(dim, pub_key1, pub_key2).unwrap(); @@ -36,8 +36,8 @@ fn accumulation() { for _ in 0..number_of_clients { // some random data let data = (0..dim) - .map(|_| Field::from(rng.gen_range(0, 2))) - .collect::>(); + .map(|_| Field32::from(rng.gen_range(0, 2))) + .collect::>(); // update reference count for (r, d) in reference_count.iter_mut().zip(data.iter()) { diff --git a/tests/tweaks.rs b/tests/tweaks.rs index 9cf5ee16b..c94041e4a 100644 --- a/tests/tweaks.rs +++ b/tests/tweaks.rs @@ -3,7 +3,7 @@ use prio::client::*; use prio::encrypt::*; -use prio::finite_field::Field; +use prio::field::Field32; use prio::server::*; use prio::util::*; @@ -42,8 +42,8 @@ fn tweaks(tweak: Tweak) { let priv_key1_clone = priv_key1.clone(); let pub_key1_clone = pub_key1.clone(); - let mut server1 = Server::new(dim, true, priv_key1); - let mut server2 = Server::new(dim, false, priv_key2); + let mut server1: Server = Server::new(dim, true, priv_key1); + let mut server2: Server = Server::new(dim, false, priv_key2); let mut client_mem = Client::new(dim, pub_key1, pub_key2).unwrap(); @@ -51,16 +51,16 @@ fn tweaks(tweak: Tweak) { let mut data = vector_with_length(dim); if let Tweak::WrongInput = tweak { - data[0] = Field::from(2); + data[0] = Field32::from(2); } let (share1_original, share2) = client_mem.encode_simple(&data).unwrap(); let decrypted_share1 = decrypt_share(&share1_original, &priv_key1_clone).unwrap(); - let mut share1_field = deserialize(&decrypted_share1); + let mut share1_field: Vec = deserialize(&decrypted_share1).unwrap(); let unpacked_share1 = unpack_proof_mut(&mut share1_field, dim).unwrap(); - let one = Field::from(1); + let one = Field32::from(1); match tweak { Tweak::DataPartOfShare => unpacked_share1.data[0] += one,