diff --git a/benches/fft.rs b/benches/fft.rs index e51276e0a..818b438ba 100644 --- a/benches/fft.rs +++ b/benches/fft.rs @@ -3,16 +3,15 @@ use criterion::{criterion_group, criterion_main, Criterion}; use prio::benchmarked::{benchmarked_iterative_fft, benchmarked_recursive_fft}; -use prio::finite_field::{Field, FieldElement}; +use prio::field::{Field126, FieldElement}; pub fn fft(c: &mut Criterion) { let test_sizes = [16, 256, 1024, 4096]; for size in test_sizes.iter() { - let mut rng = rand::thread_rng(); - let mut inp = vec![Field::zero(); *size]; - let mut outp = vec![Field::zero(); *size]; + let mut inp = vec![Field126::zero(); *size]; + let mut outp = vec![Field126::zero(); *size]; for i in 0..*size { - inp[i] = Field::rand(&mut rng); + inp[i] = Field126::rand(); } c.bench_function(&format!("iterative/{}", *size), |b| { diff --git a/examples/sum.rs b/examples/sum.rs index 063cd09d4..d48230068 100644 --- a/examples/sum.rs +++ b/examples/sum.rs @@ -2,7 +2,7 @@ use prio::client::*; use prio::encrypt::*; -use prio::finite_field::*; +use prio::field::*; use prio::server::*; fn main() { @@ -30,20 +30,20 @@ fn main() { let data1 = data1_u32 .iter() - .map(|x| Field::from(*x)) - .collect::>(); + .map(|x| Field32::from(*x)) + .collect::>(); let data2_u32 = [0, 0, 1, 0, 0, 0, 0, 0]; println!("Client 2 Input: {:?}", data2_u32); let data2 = data2_u32 .iter() - .map(|x| Field::from(*x)) - .collect::>(); + .map(|x| Field32::from(*x)) + .collect::>(); let (share1_1, share1_2) = client1.encode_simple(&data1).unwrap(); let (share2_1, share2_2) = client2.encode_simple(&data2).unwrap(); - let eval_at = Field::from(12313); + let eval_at = Field32::from(12313); let mut server1 = Server::new(dim, true, priv_key1.clone()); let mut server2 = Server::new(dim, false, priv_key2.clone()); diff --git a/src/benchmarked.rs b/src/benchmarked.rs index ffd75df4a..b831e3579 100644 --- a/src/benchmarked.rs +++ b/src/benchmarked.rs @@ -4,7 +4,7 @@ //! benchmark, but which we don't want to expose in the public API. use crate::fft::discrete_fourier_transform; -use crate::finite_field::{Field, FieldElement}; +use crate::field::FieldElement; use crate::polynomial::{poly_fft, PolyAuxMemory}; /// Sets `outp` to the Discrete Fourier Transform (DFT) using an iterative FFT algorithm. @@ -13,7 +13,7 @@ pub fn benchmarked_iterative_fft(outp: &mut [F], inp: &[F]) { } /// Sets `outp` to the Discrete Fourier Transform (DFT) using a recursive FFT algorithm. -pub fn benchmarked_recursive_fft(outp: &mut [Field], inp: &[Field]) { +pub fn benchmarked_recursive_fft(outp: &mut [F], inp: &[F]) { let mut mem = PolyAuxMemory::new(inp.len() / 2); poly_fft( outp, diff --git a/src/client.rs b/src/client.rs index d7209b2ef..b69b2e6e7 100644 --- a/src/client.rs +++ b/src/client.rs @@ -4,36 +4,61 @@ //! Prio client use crate::encrypt::*; -use crate::finite_field::*; +use crate::field::FieldElement; use crate::polynomial::*; use crate::util::*; +use std::convert::TryFrom; + /// The main object that can be used to create Prio shares /// /// Client is used to create Prio shares. #[derive(Debug)] -pub struct Client { +pub struct Client { dimension: usize, - points_f: Vec, - points_g: Vec, - evals_f: Vec, - evals_g: Vec, - poly_mem: PolyAuxMemory, + points_f: Vec, + points_g: Vec, + evals_f: Vec, + evals_g: Vec, + poly_mem: PolyAuxMemory, public_key1: PublicKey, public_key2: PublicKey, } -impl Client { +/// Errors that might be emitted by the client. +#[derive(Debug, thiserror::Error)] +pub enum ClientError { + /// Thes error is output by `Client::new()` if the length of the proof would exceed the + /// number of roots of unity that can be generated in the field. + #[error("input size exceeds field capacity")] + InputSizeExceedsFieldCapacity, + /// Thes error is output by `Client::new()` if the length of the proof would exceed the + /// ssytem's addressible memory. + #[error("input size exceeds field capacity")] + InputSizeExceedsMemoryCapacity, + /// Encryption/decryption error + #[error("encryption/decryption error")] + Encrypt(#[from] EncryptError), +} + +impl Client { /// Construct a new Prio client - pub fn new(dimension: usize, public_key1: PublicKey, public_key2: PublicKey) -> Option { + pub fn new( + dimension: usize, + public_key1: PublicKey, + public_key2: PublicKey, + ) -> Result { let n = (dimension + 1).next_power_of_two(); - if 2 * n > Field::generator_order() as usize { - // too many elements for this field, not enough roots of unity - return None; + if let Ok(size) = F::Integer::try_from(2 * n) { + if size > F::generator_order() { + return Err(ClientError::InputSizeExceedsFieldCapacity); + } + } else { + return Err(ClientError::InputSizeExceedsMemoryCapacity); } - Some(Client { + Ok(Client { dimension, points_f: vector_with_length(n), points_g: vector_with_length(n), @@ -46,20 +71,20 @@ impl Client { } /// Construct a pair of encrypted shares based on the input data. - pub fn encode_simple(&mut self, data: &[Field]) -> Result<(Vec, Vec), EncryptError> { - let copy_data = |share_data: &mut [Field]| { + pub fn encode_simple(&mut self, data: &[F]) -> Result<(Vec, Vec), ClientError> { + let copy_data = |share_data: &mut [F]| { share_data[..].clone_from_slice(data); }; - self.encode_with(copy_data) + Ok(self.encode_with(copy_data)?) } /// Construct a pair of encrypted shares using a initilization function. /// /// This might be slightly more efficient on large vectors, because one can /// avoid copying the input data. - pub fn encode_with(&mut self, init_function: F) -> Result<(Vec, Vec), EncryptError> + pub fn encode_with(&mut self, init_function: G) -> Result<(Vec, Vec), EncryptError> where - F: FnOnce(&mut [Field]), + G: FnOnce(&mut [F]), { let mut proof = vector_with_length(proof_length(self.dimension)); // unpack one long vector to different subparts @@ -90,21 +115,21 @@ impl Client { /// Convenience function if one does not want to reuse /// [`Client`](struct.Client.html). -pub fn encode_simple( - data: &[Field], +pub fn encode_simple( + data: &[F], public_key1: PublicKey, public_key2: PublicKey, -) -> Option<(Vec, Vec)> { +) -> Result<(Vec, Vec), ClientError> { let dimension = data.len(); let mut client_memory = Client::new(dimension, public_key1, public_key2)?; - client_memory.encode_simple(data).ok() + client_memory.encode_simple(data) } -fn interpolate_and_evaluate_at_2n( +fn interpolate_and_evaluate_at_2n( n: usize, - points_in: &[Field], - evals_out: &mut [Field], - mem: &mut PolyAuxMemory, + points_in: &[F], + evals_out: &mut [F], + mem: &mut PolyAuxMemory, ) { // interpolate through roots of unity poly_fft( @@ -130,20 +155,20 @@ fn interpolate_and_evaluate_at_2n( /// /// Based on Theorem 2.3.3 from Henry Corrigan-Gibbs' dissertation /// This constructs the output \pi by doing the necessesary calculations -fn construct_proof( - data: &[Field], +fn construct_proof( + data: &[F], dimension: usize, - f0: &mut Field, - g0: &mut Field, - h0: &mut Field, - points_h_packed: &mut [Field], - mem: &mut Client, + f0: &mut F, + g0: &mut F, + h0: &mut F, + points_h_packed: &mut [F], + mem: &mut Client, ) { let n = (dimension + 1).next_power_of_two(); // set zero terms to random - *f0 = Field::from(rand::random::()); - *g0 = Field::from(rand::random::()); + *f0 = F::rand(); + *g0 = F::rand(); mem.points_f[0] = *f0; mem.points_g[0] = *g0; @@ -154,7 +179,7 @@ fn construct_proof( // set g_i = f_i - 1 for i in 0..dimension { mem.points_f[i + 1] = data[i]; - mem.points_g[i + 1] = data[i] - 1.into(); + mem.points_g[i + 1] = data[i] - F::one(); } // interpolate and evaluate at roots of unity @@ -174,6 +199,8 @@ fn construct_proof( #[test] fn test_encode() { + use crate::field::Field32; + let pub_key1 = PublicKey::from_base64( "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgNt9HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVQ=", ) @@ -186,8 +213,8 @@ fn test_encode() { let data_u32 = [0u32, 1, 0, 1, 1, 0, 0, 0, 1]; let data = data_u32 .iter() - .map(|x| Field::from(*x)) - .collect::>(); + .map(|x| Field32::from(*x)) + .collect::>(); let encoded_shares = encode_simple(&data, pub_key1, pub_key2); - assert_eq!(encoded_shares.is_some(), true); + assert_eq!(encoded_shares.is_ok(), true); } diff --git a/src/fft.rs b/src/fft.rs index 370ce2fff..763b68fe7 100644 --- a/src/fft.rs +++ b/src/fft.rs @@ -3,7 +3,7 @@ //! This module implements an iterative FFT algorithm for computing the (inverse) Discrete Fourier //! Transform (DFT) over a slice of field elements. -use crate::finite_field::FieldElement; +use crate::field::FieldElement; use crate::fp::{log2, MAX_ROOTS}; use std::convert::TryFrom; @@ -48,7 +48,7 @@ pub fn discrete_fourier_transform( let mut w: F; for l in 1..d + 1 { - w = F::root(0).unwrap(); // one + w = F::one(); let r = F::root(l).unwrap(); let y = 1 << (l - 1); for i in 0..y { @@ -100,11 +100,10 @@ fn bitrev(d: usize, x: usize) -> usize { #[cfg(test)] mod tests { use super::*; - use crate::finite_field::{Field, Field126, Field64, Field80}; + use crate::field::{Field126, Field32, Field64, Field80}; use crate::polynomial::{poly_fft, PolyAuxMemory}; fn discrete_fourier_transform_then_inv_test() -> Result<(), FftError> { - let mut rng = rand::thread_rng(); let test_sizes = [1, 2, 4, 8, 16, 256, 1024, 2048]; for size in test_sizes.iter() { @@ -112,7 +111,7 @@ mod tests { let mut tmp = vec![F::zero(); *size]; let mut got = vec![F::zero(); *size]; for i in 0..*size { - want[i] = F::rand(&mut rng); + want[i] = F::rand(); } discrete_fourier_transform(&mut tmp, &want)?; @@ -125,7 +124,7 @@ mod tests { #[test] fn test_field32() { - discrete_fourier_transform_then_inv_test::().expect("unexpected error"); + discrete_fourier_transform_then_inv_test::().expect("unexpected error"); } #[test] @@ -146,17 +145,16 @@ mod tests { #[test] fn test_recursive_fft() { let size = 128; - let mut rng = rand::thread_rng(); let mut mem = PolyAuxMemory::new(size / 2); - let mut inp = vec![Field::zero(); size]; - let mut want = vec![Field::zero(); size]; - let mut got = vec![Field::zero(); size]; + let mut inp = vec![Field32::zero(); size]; + let mut want = vec![Field32::zero(); size]; + let mut got = vec![Field32::zero(); size]; for i in 0..size { - inp[i] = Field::rand(&mut rng); + inp[i] = Field32::rand(); } - discrete_fourier_transform::(&mut want, &inp).expect("unexpected error"); + discrete_fourier_transform::(&mut want, &inp).expect("unexpected error"); poly_fft( &mut got, diff --git a/src/finite_field.rs b/src/field.rs similarity index 55% rename from src/finite_field.rs rename to src/field.rs index 676b8e1ea..18cb6c225 100644 --- a/src/finite_field.rs +++ b/src/field.rs @@ -8,17 +8,21 @@ use std::{ cmp::min, convert::TryFrom, fmt::{Debug, Display, Formatter}, - ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}, + ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Shr, Sub, SubAssign}, }; -use rand::Rng; - /// Possible errors from finite field operations. #[derive(Debug, thiserror::Error)] -pub enum FiniteFieldError { +pub enum FieldError { /// Input sizes do not match #[error("input sizes do not match")] InputSizeMismatch, + /// Returned by `FieldElement::read_from()` if the input buffer is too short. + #[error("short read from byte slice")] + FromBytesShortRead, + /// Returned by `FieldElement::read_from()` if the input is larger than the modulus. + #[error("read from byte slice exceeds modulus")] + FromBytesModulusOverflow, } /// Objects with this trait represent an element of `GF(p)` for some prime `p`. @@ -40,17 +44,23 @@ pub trait FieldElement: + Display + From<::Integer> { + /// Size of each field element in bytes. + const BYTES: usize; + /// The error returned if converting `usize` to an `Int` fails. type IntegerTryFromError: std::fmt::Debug; /// The integer representation of the field element. type Integer: Copy + Debug + + PartialOrd + + Div::Integer> + + Shr::Integer> + Sub::Integer> + TryFrom; /// Modular exponentation, i.e., `self^exp (mod p)`. - fn pow(&self, exp: Self) -> Self; // TODO(cjpatton) exp should have type Self::Integer + fn pow(&self, exp: Self::Integer) -> Self; /// Modular inversion, i.e., `self^-1 (mod p)`. If `self` is 0, then the output is undefined. fn inv(&self) -> Self; @@ -58,6 +68,16 @@ pub trait FieldElement: /// Returns the prime modulus `p`. fn modulus() -> Self::Integer; + /// Writes the field element to the end of input buffer. + /// + /// TODO(acmiyaguchi) Replace this with an implementation of the corresponding serde trait + fn append_to(&self, bytes: &mut Vec); + + /// Reads the next field element from the buffer. + /// + /// TODO(acmiyaguchi) Replace this with an implementation of the corresponding serde trait + fn read_from(bytes: &[u8]) -> Result; + /// Returns the size of the multiplicative subgroup generated by `generator()`. fn generator_order() -> Self::Integer; @@ -69,16 +89,19 @@ pub trait FieldElement: fn root(l: usize) -> Option; /// Returns a random field element distributed uniformly over all field elements. - fn rand(rng: &mut R) -> Self; + fn rand() -> Self; /// Returns the additive identity. fn zero() -> Self; + + /// Returns the multiplicative identity. + fn one() -> Self; } macro_rules! make_field { ( $(#[$meta:meta])* - $elem:ident, $int:ident, $fp:ident + $elem:ident, $int:ident, $fp:ident, $bytes:literal ) => { $(#[$meta])* #[derive(Clone, Copy, Debug, PartialOrd, Ord, Hash, Default)] @@ -211,11 +234,12 @@ macro_rules! make_field { } impl FieldElement for $elem { + const BYTES: usize = $bytes; type Integer = $int; type IntegerTryFromError = >::Error; - fn pow(&self, exp: Self) -> Self { - Self($fp.pow(self.0, $fp.from_elem(exp.0))) + fn pow(&self, exp: Self::Integer) -> Self { + Self($fp.pow(self.0, u128::try_from(exp).unwrap())) } fn inv(&self) -> Self { @@ -226,6 +250,31 @@ macro_rules! make_field { $fp.p as $int } + fn append_to(&self, bytes: &mut Vec) { + let int = $fp.from_elem(self.0); + let mut slice = [0; Self::BYTES]; + for i in 0..Self::BYTES { + slice[i] = ((int >> (i << 3)) & 0xff) as u8; + } + bytes.extend_from_slice(&slice); + } + + fn read_from(bytes: &[u8]) -> Result { + if Self::BYTES > bytes.len() { + return Err(FieldError::FromBytesShortRead); + } + + let mut int = 0; + for i in 0..Self::BYTES { + int |= (bytes[i] as u128) << (i << 3); + } + + if int >= $fp.p { + return Err(FieldError::FromBytesModulusOverflow); + } + Ok(Self($fp.elem(int))) + } + fn generator() -> Self { Self($fp.g) } @@ -242,122 +291,65 @@ macro_rules! make_field { } } - fn rand(rng: &mut R) -> Self { - Self($fp.rand_elem(rng)) + fn rand() -> Self { + let mut rng = rand::thread_rng(); + Self($fp.rand_elem(&mut rng)) } fn zero() -> Self { Self(0) } + + fn one() -> Self { + Self($fp.roots[0]) + } } }; } -// TODO(cjpatton) Rename this to Field32. make_field!( /// `GF(4293918721)`, a 32-bit field. The generator has order `2^20`. - Field, + Field32, u32, - FP32 + FP32, + 4 ); make_field!( /// `GF(15564440312192434177)`, a 64-bit field. The generator has order `2^59`. Field64, u64, - FP64 + FP64, + 8 ); make_field!( /// `GF(779190469673491460259841)`, an 80-bit field. The generator has order `2^72`. Field80, u128, - FP80 + FP80, + 10 ); make_field!( /// `GF(74769074762901517850839147140769382401)`, a 126-bit field. The generator has order `2^118`. Field126, u128, - FP126 + FP126, + 16 ); -#[test] -fn test_arithmetic() { - // TODO(cjpatton) Add tests for the other fields. - use rand::prelude::*; - - let modulus = Field::modulus(); - - // add - assert_eq!(Field::from(modulus - 1) + Field::from(1), 0); - assert_eq!(Field::from(modulus - 2) + Field::from(2), 0); - assert_eq!(Field::from(modulus - 2) + Field::from(3), 1); - assert_eq!(Field::from(1) + Field::from(1), 2); - assert_eq!(Field::from(2) + Field::from(modulus), 2); - assert_eq!(Field::from(3) + Field::from(modulus - 1), 2); - - // sub - assert_eq!(Field::from(0) - Field::from(1), modulus - 1); - assert_eq!(Field::from(1) - Field::from(2), modulus - 1); - assert_eq!(Field::from(15) - Field::from(3), 12); - assert_eq!(Field::from(1) - Field::from(1), 0); - assert_eq!(Field::from(2) - Field::from(modulus), 2); - assert_eq!(Field::from(3) - Field::from(modulus - 1), 4); - - // add + sub - for _ in 0..100 { - let f = Field::from(random::()); - let g = Field::from(random::()); - assert_eq!(f + g - f - g, 0); - assert_eq!(f + g - g, f); - assert_eq!(f + g - f, g); - } - - // mul - assert_eq!(Field::from(35) * Field::from(123), 4305); - assert_eq!(Field::from(1) * Field::from(modulus), 0); - assert_eq!(Field::from(0) * Field::from(123), 0); - assert_eq!(Field::from(123) * Field::from(0), 0); - assert_eq!(Field::from(123123123) * Field::from(123123123), 1237630077); - - // div - assert_eq!(Field::from(35) / Field::from(5), 7); - assert_eq!(Field::from(35) / Field::from(0), 0); - assert_eq!(Field::from(0) / Field::from(5), 0); - assert_eq!(Field::from(1237630077) / Field::from(123123123), 123123123); - - assert_eq!(Field::from(0).inv(), 0); - - // mul and div - let uniform = rand::distributions::Uniform::from(1..modulus); - let mut rng = thread_rng(); - for _ in 0..100 { - // non-zero element - let f = Field::from(uniform.sample(&mut rng)); - assert_eq!(f * f.inv(), 1); - assert_eq!(f.inv() * f, 1); - } - - // pow - assert_eq!(Field::from(2).pow(3.into()), 8); - assert_eq!(Field::from(3).pow(9.into()), 19683); - assert_eq!(Field::from(51).pow(27.into()), 3760729523); - assert_eq!(Field::from(432).pow(0.into()), 1); - assert_eq!(Field(0).pow(123.into()), 0); -} - /// Merge two vectors of fields by summing other_vector into accumulator. /// /// # Errors /// /// Fails if the two vectors do not have the same length. -pub fn merge_vector( - accumulator: &mut [Field], - other_vector: &[Field], -) -> Result<(), FiniteFieldError> { +pub fn merge_vector( + accumulator: &mut [F], + other_vector: &[F], +) -> Result<(), FieldError> { if accumulator.len() != other_vector.len() { - return Err(FiniteFieldError::InputSizeMismatch); + return Err(FieldError::InputSizeMismatch); } for (a, o) in accumulator.iter_mut().zip(other_vector.iter()) { *a += *o; @@ -369,23 +361,126 @@ pub fn merge_vector( #[cfg(test)] mod tests { use super::*; + use crate::fp::MAX_ROOTS; use crate::util::vector_with_length; use assert_matches::assert_matches; #[test] fn test_accumulate() { let mut lhs = vector_with_length(10); - lhs.iter_mut().for_each(|f| *f = Field(1)); + lhs.iter_mut().for_each(|f| *f = Field32(1)); let mut rhs = vector_with_length(10); - rhs.iter_mut().for_each(|f| *f = Field(2)); + rhs.iter_mut().for_each(|f| *f = Field32(2)); merge_vector(&mut lhs, &rhs).unwrap(); - lhs.iter().for_each(|f| assert_eq!(*f, Field(3))); - rhs.iter().for_each(|f| assert_eq!(*f, Field(2))); + lhs.iter().for_each(|f| assert_eq!(*f, Field32(3))); + rhs.iter().for_each(|f| assert_eq!(*f, Field32(2))); let wrong_len = vector_with_length(9); let result = merge_vector(&mut lhs, &wrong_len); - assert_matches!(result, Err(FiniteFieldError::InputSizeMismatch)); + assert_matches!(result, Err(FieldError::InputSizeMismatch)); + } + + fn field_element_test() { + let int_modulus = F::modulus(); + let int_one = F::Integer::try_from(1).unwrap(); + let zero = F::zero(); + let one = F::one(); + let two = F::from(F::Integer::try_from(2).unwrap()); + let four = F::from(F::Integer::try_from(4).unwrap()); + + // add + assert_eq!(F::from(int_modulus - int_one) + one, zero); + assert_eq!(one + one, two); + assert_eq!(two + F::from(int_modulus), two); + + // sub + assert_eq!(zero - one, F::from(int_modulus - int_one)); + assert_eq!(one - one, zero); + assert_eq!(two - F::from(int_modulus), two); + assert_eq!(one - F::from(int_modulus - int_one), two); + + // add + sub + for _ in 0..100 { + let f = F::rand(); + let g = F::rand(); + assert_eq!(f + g - f - g, zero); + assert_eq!(f + g - g, f); + assert_eq!(f + g - f, g); + } + + // mul + assert_eq!(two * two, four); + assert_eq!(two * one, two); + assert_eq!(two * zero, zero); + assert_eq!(one * F::from(int_modulus), zero); + + // div + assert_eq!(four / two, two); + assert_eq!(two / two, one); + assert_eq!(zero / two, zero); + assert_eq!(two / zero, zero); // Undefined behavior + assert_eq!(zero.inv(), zero); // Undefined behavior + + // mul + div + for _ in 0..100 { + let f = F::rand(); + if f == zero { + println!("skipped zero"); + continue; + } + assert_eq!(f * f.inv(), one); + assert_eq!(f.inv() * f, one); + } + + // pow + assert_eq!(two.pow(F::Integer::try_from(0).unwrap()), one); + assert_eq!(two.pow(int_one), two); + assert_eq!(two.pow(F::Integer::try_from(2).unwrap()), four); + assert_eq!(two.pow(int_modulus - int_one), one); + assert_eq!(two.pow(int_modulus), two); + + // roots + let mut int_order = F::generator_order(); + for l in 0..MAX_ROOTS + 1 { + assert_eq!( + F::generator().pow(int_order), + F::root(l).unwrap(), + "failure for F::root({})", + l + ); + int_order = int_order >> int_one; + } + + // serialization + let test_inputs = vec![zero, one, F::rand(), F::from(int_modulus - int_one)]; + for want in test_inputs.iter() { + let mut bytes = vec![]; + want.append_to(&mut bytes); + let got = F::read_from(&bytes).unwrap(); + assert_eq!(got, *want); + assert_eq!(bytes.len(), F::BYTES); + } + } + + #[test] + fn test_field32() { + field_element_test::(); + } + + #[test] + fn test_field64() { + field_element_test::(); + } + + #[test] + fn test_field80() { + field_element_test::(); + } + + #[test] + fn test_field126() { + field_element_test::(); } } diff --git a/src/fp.rs b/src/fp.rs index 29d435fa4..272e05550 100644 --- a/src/fp.rs +++ b/src/fp.rs @@ -204,12 +204,6 @@ impl FieldParameters { modp(self.mul(x, 1), self.p) } - /// Returns the number of bytes required to encode field elements. - #[cfg(test)] // This code is only used by tests for now. - pub fn size(&self) -> usize { - (16 - (self.p.leading_zeros() / 8)) as usize - } - #[cfg(test)] pub fn check(&self, p: u128, g: u128, order: u128) { use modinverse::modinverse; @@ -404,7 +398,6 @@ mod tests { expected_p: u128, // Expected fp.p expected_g: u128, // Expected fp.from_elem(fp.g) expected_order: u128, // Expect fp.from_elem(fp.pow(fp.g, expected_order)) == 1 - expected_size: usize, // expected fp.size() } #[test] @@ -415,28 +408,24 @@ mod tests { expected_p: 4293918721, expected_g: 3925978153, expected_order: 1 << 20, - expected_size: 4, }, TestFieldParametersData { fp: FP64, expected_p: 15564440312192434177, expected_g: 7450580596923828125, expected_order: 1 << 59, - expected_size: 8, }, TestFieldParametersData { fp: FP80, expected_p: 779190469673491460259841, expected_g: 41782115852031095118226, expected_order: 1 << 72, - expected_size: 10, }, TestFieldParametersData { fp: FP126, expected_p: 74769074762901517850839147140769382401, expected_g: 43421413544015439978138831414974882540, expected_order: 1 << 118, - expected_size: 16, }, ]; @@ -444,23 +433,15 @@ mod tests { // Check that the field parameters have been constructed properly. t.fp.check(t.expected_p, t.expected_g, t.expected_order); - // Check that the field element size is computed correctly. - assert_eq!( - t.fp.size(), - t.expected_size, - "error for GF({})", - t.expected_p - ); - // Check that the generator has the correct order. assert_eq!(t.fp.from_elem(t.fp.pow(t.fp.g, t.expected_order)), 1); // Test arithmetic using the field parameters. - test_arithmetic(&t.fp); + arithmetic_test(&t.fp); } } - fn test_arithmetic(fp: &FieldParameters) { + fn arithmetic_test(fp: &FieldParameters) { let mut rng = rand::thread_rng(); let big_p = &fp.p.to_bigint().unwrap(); diff --git a/src/lib.rs b/src/lib.rs index 965c8ae50..018c1a99f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,7 +13,7 @@ pub mod benchmarked; pub mod client; pub mod encrypt; mod fft; -pub mod finite_field; +pub mod field; mod fp; mod polynomial; mod prng; diff --git a/src/polynomial.rs b/src/polynomial.rs index 6e4b8e7e4..00b0666f4 100644 --- a/src/polynomial.rs +++ b/src/polynomial.rs @@ -3,59 +3,60 @@ //! Functions for polynomial interpolation and evaluation -use crate::finite_field::*; -use crate::util::*; +use crate::field::FieldElement; + +use std::convert::TryFrom; /// Temporary memory used for FFT #[derive(Debug)] -pub struct PolyFFTTempMemory { - fft_tmp: Vec, - fft_y_sub: Vec, - fft_roots_sub: Vec, +pub struct PolyFFTTempMemory { + fft_tmp: Vec, + fft_y_sub: Vec, + fft_roots_sub: Vec, } -impl PolyFFTTempMemory { +impl PolyFFTTempMemory { fn new(length: usize) -> Self { PolyFFTTempMemory { - fft_tmp: vector_with_length(length), - fft_y_sub: vector_with_length(length), - fft_roots_sub: vector_with_length(length), + fft_tmp: vec![F::zero(); length], + fft_y_sub: vec![F::zero(); length], + fft_roots_sub: vec![F::zero(); length], } } } /// Auxiliary memory for polynomial interpolation and evaluation #[derive(Debug)] -pub struct PolyAuxMemory { - pub roots_2n: Vec, - pub roots_2n_inverted: Vec, - pub roots_n: Vec, - pub roots_n_inverted: Vec, - pub coeffs: Vec, - pub fft_memory: PolyFFTTempMemory, +pub struct PolyAuxMemory { + pub roots_2n: Vec, + pub roots_2n_inverted: Vec, + pub roots_n: Vec, + pub roots_n_inverted: Vec, + pub coeffs: Vec, + pub fft_memory: PolyFFTTempMemory, } -impl PolyAuxMemory { +impl PolyAuxMemory { pub fn new(n: usize) -> Self { PolyAuxMemory { roots_2n: fft_get_roots(2 * n, false), roots_2n_inverted: fft_get_roots(2 * n, true), roots_n: fft_get_roots(n, false), roots_n_inverted: fft_get_roots(n, true), - coeffs: vector_with_length(2 * n), + coeffs: vec![F::zero(); 2 * n], fft_memory: PolyFFTTempMemory::new(2 * n), } } } -fn fft_recurse( - out: &mut [Field], +fn fft_recurse( + out: &mut [F], n: usize, - roots: &[Field], - ys: &[Field], - tmp: &mut [Field], - y_sub: &mut [Field], - roots_sub: &mut [Field], + roots: &[F], + ys: &[F], + tmp: &mut [F], + y_sub: &mut [F], + roots_sub: &mut [F], ) { if n == 1 { out[0] = ys[0]; @@ -106,17 +107,17 @@ fn fft_recurse( } /// Calculate `count` number of roots of unity of order `count` -fn fft_get_roots(count: usize, invert: bool) -> Vec { - let mut roots = vec![Field::from(0); count]; - let mut gen = Field::generator(); +fn fft_get_roots(count: usize, invert: bool) -> Vec { + let mut roots = vec![F::zero(); count]; + let mut gen = F::generator(); if invert { gen = gen.inv(); } - roots[0] = 1.into(); - let step_size: u32 = (Field::generator_order() as u32) / (count as u32); + roots[0] = F::one(); + let step_size = F::generator_order() / F::Integer::try_from(count).unwrap(); // generator for subgroup of order count - gen = gen.pow(step_size.into()); + gen = gen.pow(step_size); roots[1] = gen; @@ -127,13 +128,13 @@ fn fft_get_roots(count: usize, invert: bool) -> Vec { roots } -fn fft_interpolate_raw( - out: &mut [Field], - ys: &[Field], +fn fft_interpolate_raw( + out: &mut [F], + ys: &[F], n_points: usize, - roots: &[Field], + roots: &[F], invert: bool, - mem: &mut PolyFFTTempMemory, + mem: &mut PolyFFTTempMemory, ) { fft_recurse( out, @@ -145,25 +146,25 @@ fn fft_interpolate_raw( &mut mem.fft_roots_sub, ); if invert { - let n_inverse = Field::from(n_points as u32).inv(); + let n_inverse = F::from(F::Integer::try_from(n_points).unwrap()).inv(); for i in 0..n_points { out[i] *= n_inverse; } } } -pub fn poly_fft( - points_out: &mut [Field], - points_in: &[Field], - scaled_roots: &[Field], +pub fn poly_fft( + points_out: &mut [F], + points_in: &[F], + scaled_roots: &[F], n_points: usize, invert: bool, - mem: &mut PolyFFTTempMemory, + mem: &mut PolyFFTTempMemory, ) { fft_interpolate_raw(points_out, points_in, n_points, scaled_roots, invert, mem) } -pub fn poly_horner_eval(poly: &[Field], eval_at: Field, len: usize) -> Field { +pub fn poly_horner_eval(poly: &[F], eval_at: F, len: usize) -> F { let mut result = poly[len - 1]; for i in (0..(len - 1)).rev() { @@ -174,33 +175,37 @@ pub fn poly_horner_eval(poly: &[Field], eval_at: Field, len: usize) -> Field { result } -pub fn poly_interpret_eval( - points: &[Field], - roots: &[Field], - eval_at: Field, - tmp_coeffs: &mut [Field], - fft_memory: &mut PolyFFTTempMemory, -) -> Field { +pub fn poly_interpret_eval( + points: &[F], + roots: &[F], + eval_at: F, + tmp_coeffs: &mut [F], + fft_memory: &mut PolyFFTTempMemory, +) -> F { poly_fft(tmp_coeffs, points, roots, points.len(), true, fft_memory); poly_horner_eval(&tmp_coeffs, eval_at, points.len()) } #[test] fn test_roots() { + use crate::field::Field32; + let count = 128; - let roots = fft_get_roots(count, false); - let roots_inv = fft_get_roots(count, true); + let roots = fft_get_roots::(count, false); + let roots_inv = fft_get_roots::(count, true); for i in 0..count { assert_eq!(roots[i] * roots_inv[i], 1); - assert_eq!(roots[i].pow(Field::from(count as u32)), 1); - assert_eq!(roots_inv[i].pow(Field::from(count as u32)), 1); + assert_eq!(roots[i].pow(u32::try_from(count).unwrap()), 1); + assert_eq!(roots_inv[i].pow(u32::try_from(count).unwrap()), 1); } } #[test] fn test_horner_eval() { - let mut poly = vec![Field::from(0); 4]; + use crate::field::Field32; + + let mut poly = vec![Field32::from(0); 4]; poly[0] = 2.into(); poly[1] = 1.into(); poly[2] = 5.into(); @@ -213,18 +218,21 @@ fn test_horner_eval() { #[test] fn test_fft() { - let count = 128; - let mut mem = PolyAuxMemory::new(count / 2); + use crate::field::Field32; use rand::prelude::*; + use std::convert::TryFrom; + + let count = 128; + let mut mem = PolyAuxMemory::new(count / 2); - let mut poly = vec![Field::from(0); count]; - let mut points2 = vec![Field::from(0); count]; + let mut poly = vec![Field32::from(0); count]; + let mut points2 = vec![Field32::from(0); count]; let points = (0..count) .into_iter() - .map(|_| Field::from(random::())) - .collect::>(); + .map(|_| Field32::from(random::())) + .collect::>(); // From points to coeffs and back poly_fft( @@ -256,9 +264,9 @@ fn test_fft() { &mut mem.fft_memory, ); for i in 0..count { - let mut should_be = Field::from(0); + let mut should_be = Field32::from(0); for j in 0..count { - should_be = mem.roots_2n[i].pow(Field::from(j as u32)) * points[j] + should_be; + should_be = mem.roots_2n[i].pow(u32::try_from(j).unwrap()) * points[j] + should_be; } assert_eq!(should_be, poly[i]); } diff --git a/src/prng.rs b/src/prng.rs index de78b0e26..b6bb898c1 100644 --- a/src/prng.rs +++ b/src/prng.rs @@ -1,19 +1,18 @@ // Copyright (c) 2020 Apple Inc. // SPDX-License-Identifier: MPL-2.0 -use super::finite_field::{Field, FieldElement}; +use super::field::{FieldElement, FieldError}; use aes_ctr::stream_cipher::generic_array::GenericArray; use aes_ctr::stream_cipher::NewStreamCipher; use aes_ctr::stream_cipher::SyncStreamCipher; use aes_ctr::Aes128Ctr; use rand::RngCore; -use std::convert::TryInto; const BLOCK_SIZE: usize = 16; const MAXIMUM_BUFFER_SIZE_IN_BLOCKS: usize = 4096; pub const SEED_LENGTH: usize = 2 * BLOCK_SIZE; -pub fn secret_share(share1: &mut [Field]) -> Vec { +pub fn secret_share(share1: &mut [F]) -> Vec { // get prng array let (data, seed) = random_field_and_seed(share1.len()); @@ -25,26 +24,26 @@ pub fn secret_share(share1: &mut [Field]) -> Vec { seed } -pub fn extract_share_from_seed(length: usize, seed: &[u8]) -> Vec { +pub fn extract_share_from_seed(length: usize, seed: &[u8]) -> Vec { random_field_from_seed(seed, length) } -fn random_field_and_seed(length: usize) -> (Vec, Vec) { +fn random_field_and_seed(length: usize) -> (Vec, Vec) { let mut seed = vec![0u8; SEED_LENGTH]; rand::thread_rng().fill_bytes(&mut seed); let data = random_field_from_seed(&seed, length); (data, seed) } -fn random_field_from_seed(seed: &[u8], length: usize) -> Vec { +fn random_field_from_seed(seed: &[u8], length: usize) -> Vec { let key = GenericArray::from_slice(&seed[..BLOCK_SIZE]); let nonce = GenericArray::from_slice(&seed[BLOCK_SIZE..]); let mut cipher = Aes128Ctr::new(&key, &nonce); - let mut output = super::util::vector_with_length(length); + let mut output = vec![F::zero(); length]; let mut output_written = 0; - let length_in_blocks = length * std::mem::size_of::() / BLOCK_SIZE; + let length_in_blocks = length * F::BYTES / BLOCK_SIZE; // add one more block to account for rejection and roundoff errors let buffer_len_in_blocks = std::cmp::min(MAXIMUM_BUFFER_SIZE_IN_BLOCKS, length_in_blocks + 1); // Note: buffer_len must be a multiple of BLOCK_SIZE, so that different buffer @@ -60,20 +59,17 @@ fn random_field_from_seed(seed: &[u8], length: usize) -> Vec { cipher.apply_keystream(&mut buffer); - // rejection sampling - // - // TODO(cjpatton): Once implemented, use FieldParameters::size() and - // FieldElement::form_bytes() to implement this loop. - let field_size = std::mem::size_of::<::Integer>(); - for chunk in buffer.chunks_exact(field_size) { - let integer = - ::Integer::from_le_bytes(chunk.try_into().unwrap()); - if integer < Field::modulus() { - output[output_written] = Field::from(integer); - output_written += 1; - if output_written == length { - break; + for chunk in buffer.chunks_exact(F::BYTES) { + match F::read_from(chunk) { + Ok(x) => { + output[output_written] = x; + output_written += 1; + if output_written == length { + break; + } } + Err(FieldError::FromBytesModulusOverflow) => (), // reject this sample + Err(err) => panic!("unexpected error: {}", err), } } } @@ -86,10 +82,11 @@ fn random_field_from_seed(seed: &[u8], length: usize) -> Vec { #[cfg(test)] mod tests { use super::*; + use crate::field::Field32; #[test] fn secret_sharing() { - let mut data = vec![Field::from(0); 123]; + let mut data = vec![Field32::from(0); 123]; data[3] = 23.into(); let data_clone = data.clone(); @@ -125,7 +122,7 @@ mod tests { 0xacb8b748, 0x6f5b9d49, 0x887d061b, 0x86db0c58, ]; - let share2 = extract_share_from_seed(reference.len(), &seed); + let share2 = extract_share_from_seed::(reference.len(), &seed); assert_eq!(share2, reference); } @@ -133,7 +130,7 @@ mod tests { /// takes a seed and hash as base64 encoded strings fn random_data_interop(seed_base64: &str, hash_base64: &str, len: usize) { let seed = base64::decode(seed_base64).unwrap(); - let random_data = extract_share_from_seed(len, &seed); + let random_data = extract_share_from_seed::(len, &seed); let random_bytes = crate::util::serialize(&random_data); diff --git a/src/server.rs b/src/server.rs index 18caca43e..adfbbabff 100644 --- a/src/server.rs +++ b/src/server.rs @@ -5,10 +5,10 @@ use crate::{ encrypt::{decrypt_share, EncryptError, PrivateKey}, - finite_field::{merge_vector, Field, FiniteFieldError}, + field::{merge_vector, FieldElement, FieldError}, polynomial::{poly_interpret_eval, PolyAuxMemory}, prng::extract_share_from_seed, - util::{deserialize, proof_length, unpack_proof, vector_with_length}, + util::{deserialize, proof_length, unpack_proof, vector_with_length, SerializeError}, }; /// Possible errors from server operations @@ -19,20 +19,23 @@ pub enum ServerError { Encrypt(#[from] EncryptError), /// Finite field operation error #[error("finite field operation error")] - FiniteField(#[from] FiniteFieldError), + Field(#[from] FieldError), + /// Serialization/deserialization error + #[error("serialization/deserialization error")] + Serialize(#[from] SerializeError), } /// Auxiliary memory for constructing a /// [`VerificationMessage`](struct.VerificationMessage.html) #[derive(Debug)] -pub struct ValidationMemory { - points_f: Vec, - points_g: Vec, - points_h: Vec, - poly_mem: PolyAuxMemory, +pub struct ValidationMemory { + points_f: Vec, + points_g: Vec, + points_h: Vec, + poly_mem: PolyAuxMemory, } -impl ValidationMemory { +impl ValidationMemory { /// Construct a new ValidationMemory object for validating proof shares of /// length `dimension`. pub fn new(dimension: usize) -> Self { @@ -48,22 +51,22 @@ impl ValidationMemory { /// Main workhorse of the server. #[derive(Debug)] -pub struct Server { +pub struct Server { dimension: usize, is_first_server: bool, - accumulator: Vec, - validation_mem: ValidationMemory, + accumulator: Vec, + validation_mem: ValidationMemory, private_key: PrivateKey, } -impl Server { +impl Server { /// Construct a new server instance /// /// Params: /// * `dimension`: the number of elements in the aggregation vector. /// * `is_first_server`: only one of the servers should have this true. /// * `private_key`: the private key for decrypting the share of the proof. - pub fn new(dimension: usize, is_first_server: bool, private_key: PrivateKey) -> Server { + pub fn new(dimension: usize, is_first_server: bool, private_key: PrivateKey) -> Server { Server { dimension, is_first_server, @@ -74,10 +77,10 @@ impl Server { } /// Decrypt and deserialize - fn deserialize_share(&self, encrypted_share: &[u8]) -> Result, ServerError> { + fn deserialize_share(&self, encrypted_share: &[u8]) -> Result, ServerError> { let share = decrypt_share(encrypted_share, &self.private_key)?; Ok(if self.is_first_server { - deserialize(&share) + deserialize(&share)? } else { let len = proof_length(self.dimension); extract_share_from_seed(len, &share) @@ -92,9 +95,9 @@ impl Server { /// [choose_eval_at](#method.choose_eval_at). pub fn generate_verification_message( &mut self, - eval_at: Field, + eval_at: F, share: &[u8], - ) -> Option { + ) -> Option> { let share_field = self.deserialize_share(share).ok()?; generate_verification_message( self.dimension, @@ -112,8 +115,8 @@ impl Server { pub fn aggregate( &mut self, share: &[u8], - v1: &VerificationMessage, - v2: &VerificationMessage, + v1: &VerificationMessage, + v2: &VerificationMessage, ) -> Result { let share_field = self.deserialize_share(share)?; let is_valid = is_valid_share(v1, v2); @@ -131,7 +134,7 @@ impl Server { /// /// These can be merged together using /// [`reconstruct_shares`](../util/fn.reconstruct_shares.html). - pub fn total_shares(&self) -> &[Field] { + pub fn total_shares(&self) -> &[F] { &self.accumulator } @@ -143,7 +146,7 @@ impl Server { /// /// Returns an error if `other_total_shares.len()` is not equal to this //// server's `dimension`. - pub fn merge_total_shares(&mut self, other_total_shares: &[Field]) -> Result<(), ServerError> { + pub fn merge_total_shares(&mut self, other_total_shares: &[F]) -> Result<(), ServerError> { Ok(merge_vector(&mut self.accumulator, other_total_shares)?) } @@ -151,9 +154,9 @@ impl Server { /// /// The point returned is not one of the roots used for polynomial /// evaluation. - pub fn choose_eval_at(&self) -> Field { + pub fn choose_eval_at(&self) -> F { loop { - let eval_at = Field::from(rand::random::()); + let eval_at = F::rand(); if !self.validation_mem.poly_mem.roots_2n.contains(&eval_at) { break eval_at; } @@ -162,24 +165,24 @@ impl Server { } /// Verification message for proof validation -pub struct VerificationMessage { +pub struct VerificationMessage { /// f evaluated at random point - pub f_r: Field, + pub f_r: F, /// g evaluated at random point - pub g_r: Field, + pub g_r: F, /// h evaluated at random point - pub h_r: Field, + pub h_r: F, } /// Given a proof and evaluation point, this constructs the verification /// message. -pub fn generate_verification_message( +pub fn generate_verification_message( dimension: usize, - eval_at: Field, - proof: &[Field], + eval_at: F, + proof: &[F], is_first_server: bool, - mem: &mut ValidationMemory, -) -> Option { + mem: &mut ValidationMemory, +) -> Option> { let unpacked = unpack_proof(proof, dimension)?; let proof_length = 2 * (dimension + 1).next_power_of_two(); @@ -194,7 +197,7 @@ pub fn generate_verification_message( if is_first_server { // only one server needs to subtract one for point_g - mem.points_g[i + 1] = *x - 1.into(); + mem.points_g[i + 1] = *x - F::one(); } else { mem.points_g[i + 1] = *x; } @@ -237,7 +240,10 @@ pub fn generate_verification_message( } /// Decides if the distributed proof is valid -pub fn is_valid_share(v1: &VerificationMessage, v2: &VerificationMessage) -> bool { +pub fn is_valid_share( + v1: &VerificationMessage, + v2: &VerificationMessage, +) -> bool { // reconstruct f_r, g_r, h_r let f_r = v1.f_r + v2.f_r; let g_r = v1.g_r + v2.g_r; @@ -249,6 +255,7 @@ pub fn is_valid_share(v1: &VerificationMessage, v2: &VerificationMessage) -> boo #[cfg(test)] mod tests { use super::*; + use crate::field::Field32; use crate::util; #[test] @@ -260,9 +267,9 @@ mod tests { 2567182742, 3542857140, 124017604, 4201373647, 431621210, 1618555683, 267689149, ]; - let mut proof: Vec = proof_u32.iter().map(|x| Field::from(*x)).collect(); + let mut proof: Vec = proof_u32.iter().map(|x| Field32::from(*x)).collect(); let share2 = util::tests::secret_share(&mut proof); - let eval_at = Field::from(12313); + let eval_at = Field32::from(12313); let mut validation_mem = ValidationMemory::new(dim); diff --git a/src/util.rs b/src/util.rs index 84382aa4e..9fb14f5e7 100644 --- a/src/util.rs +++ b/src/util.rs @@ -3,11 +3,18 @@ //! Utility functions for handling Prio stuff. -use crate::finite_field::{Field, FieldElement}; - -/// Convenience function for initializing fixed sized vectors of Field elements. -pub fn vector_with_length(len: usize) -> Vec { - vec![Field::from(0); len] +use crate::field::{FieldElement, FieldError}; + +/// Serialization errors +#[derive(Debug, thiserror::Error)] +pub enum SerializeError { + /// Emitted by `deserialize()` if the last chunk of input is not long enough to encode an + /// element of the field. + #[error("last chunk of bytes is incomplete")] + IncompleteChunk, + /// Finite field operation error. + #[error("finite field operation error")] + Field(#[from] FieldError), } /// Returns the number of field elements in the proof for given dimension of @@ -21,36 +28,41 @@ pub fn proof_length(dimension: usize) -> usize { dimension + 3 + (dimension + 1).next_power_of_two() } +/// Convenience function for initializing fixed sized vectors of Field elements. +pub fn vector_with_length(len: usize) -> Vec { + vec![F::zero(); len] +} + /// Unpacked proof with subcomponents -pub struct UnpackedProof<'a> { +pub struct UnpackedProof<'a, F: FieldElement> { /// Data - pub data: &'a [Field], + pub data: &'a [F], /// Zeroth coefficient of polynomial f - pub f0: &'a Field, + pub f0: &'a F, /// Zeroth coefficient of polynomial g - pub g0: &'a Field, + pub g0: &'a F, /// Zeroth coefficient of polynomial h - pub h0: &'a Field, + pub h0: &'a F, /// Non-zero points of polynomial h - pub points_h_packed: &'a [Field], + pub points_h_packed: &'a [F], } /// Unpacked proof with mutable subcomponents -pub struct UnpackedProofMut<'a> { +pub struct UnpackedProofMut<'a, F: FieldElement> { /// Data - pub data: &'a mut [Field], + pub data: &'a mut [F], /// Zeroth coefficient of polynomial f - pub f0: &'a mut Field, + pub f0: &'a mut F, /// Zeroth coefficient of polynomial g - pub g0: &'a mut Field, + pub g0: &'a mut F, /// Zeroth coefficient of polynomial h - pub h0: &'a mut Field, + pub h0: &'a mut F, /// Non-zero points of polynomial h - pub points_h_packed: &'a mut [Field], + pub points_h_packed: &'a mut [F], } /// Unpacks the proof vector into subcomponents -pub fn unpack_proof(proof: &[Field], dimension: usize) -> Option { +pub fn unpack_proof(proof: &[F], dimension: usize) -> Option> { // check the proof length if proof.len() != proof_length(dimension) { return None; @@ -73,7 +85,10 @@ pub fn unpack_proof(proof: &[Field], dimension: usize) -> Option } /// Unpacks a mutable proof vector into mutable subcomponents -pub fn unpack_proof_mut(proof: &mut [Field], dimension: usize) -> Option { +pub fn unpack_proof_mut( + proof: &mut [F], + dimension: usize, +) -> Option> { // check the share length if proof.len() != proof_length(dimension) { return None; @@ -96,46 +111,35 @@ pub fn unpack_proof_mut(proof: &mut [Field], dimension: usize) -> Option Vec { - let field_size = std::mem::size_of::<::Integer>(); - let mut vec = Vec::with_capacity(data.len() * field_size); - +pub fn serialize(data: &[F]) -> Vec { + let mut vec = Vec::::with_capacity(data.len() * F::BYTES); for elem in data.iter() { - // TODO(cjpatton) Implement FieldElement::bytes() that encodes each field element using - // FieldParameters::size() bytes. - let int = ::Integer::from(*elem); - vec.extend(int.to_le_bytes().iter()); + elem.append_to(&mut vec); } - vec } /// Get a vector of field elements from a byte slice -pub fn deserialize(data: &[u8]) -> Vec { - let field_size = std::mem::size_of::<::Integer>(); - - let mut vec = Vec::with_capacity(data.len() / field_size); - use std::convert::TryInto; - - for chunk in data.chunks_exact(field_size) { - // TODO(cjpatton) Implement FieldElement::from_bytes() that decodes field elements from a - // string with FieldParameters::size() bytes. - let integer = ::Integer::from_le_bytes(chunk.try_into().unwrap()); - vec.push(Field::from(integer)); +pub fn deserialize(data: &[u8]) -> Result, SerializeError> { + if data.len() % F::BYTES != 0 { + return Err(SerializeError::IncompleteChunk); } - - vec + let mut vec = Vec::::with_capacity(data.len() / F::BYTES); + for chunk in data.chunks_exact(F::BYTES) { + vec.push(F::read_from(chunk).or_else(|err| Err(SerializeError::Field(err)))?); + } + Ok(vec) } -/// Add two Field element arrays together elementwise. +/// Add two field element arrays together elementwise. /// /// Returns None, when array lengths are not equal. -pub fn reconstruct_shares(share1: &[Field], share2: &[Field]) -> Option> { +pub fn reconstruct_shares(share1: &[F], share2: &[F]) -> Option> { if share1.len() != share2.len() { return None; } - let mut reconstructed = vector_with_length(share1.len()); + let mut reconstructed: Vec = vector_with_length(share1.len()); for (r, (s1, s2)) in reconstructed .iter_mut() @@ -150,17 +154,18 @@ pub fn reconstruct_shares(share1: &[Field], share2: &[Field]) -> Option Vec { + pub fn secret_share(share: &mut [Field32]) -> Vec { use rand::Rng; let mut rng = rand::thread_rng(); let mut random = vec![0u32; share.len()]; - let mut share2 = vector_with_length(share.len()); + let mut share2 = vec![Field32::zero(); share.len()]; rng.fill(&mut random[..]); for (r, f) in random.iter().zip(share2.iter_mut()) { - *f = Field::from(*r); + *f = Field32::from(*r); } for (f1, f2) in share.iter_mut().zip(share2.iter()) { @@ -175,15 +180,15 @@ pub mod tests { let dim = 15; let len = proof_length(dim); - let mut share = vec![Field::from(0); len]; + let mut share = vec![Field32::from(0); len]; let unpacked = unpack_proof_mut(&mut share, dim).unwrap(); - *unpacked.f0 = Field::from(12); + *unpacked.f0 = Field32::from(12); assert_eq!(share[dim], 12); } #[test] fn secret_sharing() { - let mut share1 = vector_with_length(10); + let mut share1 = vec![Field32::zero(); 10]; share1[3] = 21.into(); share1[8] = 123.into(); @@ -197,9 +202,9 @@ pub mod tests { #[test] fn serialization() { - let field = [Field::from(1), Field::from(0x99997)]; + let field = [Field32::from(1), Field32::from(0x99997)]; let bytes = serialize(&field); - let field_deserialized = deserialize(&bytes); + let field_deserialized = deserialize::(&bytes).unwrap(); assert_eq!(field_deserialized, field); } } diff --git a/tests/accumulating.rs b/tests/accumulating.rs index a96562e16..b11296ab6 100644 --- a/tests/accumulating.rs +++ b/tests/accumulating.rs @@ -3,7 +3,7 @@ use prio::client::*; use prio::encrypt::*; -use prio::finite_field::Field; +use prio::field::Field32; use prio::server::*; #[test] @@ -26,8 +26,8 @@ fn accumulation() { let mut reference_count = vec![0u32; dim]; - let mut server1 = Server::new(dim, true, priv_key1); - let mut server2 = Server::new(dim, false, priv_key2); + let mut server1: Server = Server::new(dim, true, priv_key1); + let mut server2: Server = Server::new(dim, false, priv_key2); let mut client_mem = Client::new(dim, pub_key1, pub_key2).unwrap(); @@ -36,8 +36,8 @@ fn accumulation() { for _ in 0..number_of_clients { // some random data let data = (0..dim) - .map(|_| Field::from(rng.gen_range(0, 2))) - .collect::>(); + .map(|_| Field32::from(rng.gen_range(0, 2))) + .collect::>(); // update reference count for (r, d) in reference_count.iter_mut().zip(data.iter()) { diff --git a/tests/tweaks.rs b/tests/tweaks.rs index 9cf5ee16b..c94041e4a 100644 --- a/tests/tweaks.rs +++ b/tests/tweaks.rs @@ -3,7 +3,7 @@ use prio::client::*; use prio::encrypt::*; -use prio::finite_field::Field; +use prio::field::Field32; use prio::server::*; use prio::util::*; @@ -42,8 +42,8 @@ fn tweaks(tweak: Tweak) { let priv_key1_clone = priv_key1.clone(); let pub_key1_clone = pub_key1.clone(); - let mut server1 = Server::new(dim, true, priv_key1); - let mut server2 = Server::new(dim, false, priv_key2); + let mut server1: Server = Server::new(dim, true, priv_key1); + let mut server2: Server = Server::new(dim, false, priv_key2); let mut client_mem = Client::new(dim, pub_key1, pub_key2).unwrap(); @@ -51,16 +51,16 @@ fn tweaks(tweak: Tweak) { let mut data = vector_with_length(dim); if let Tweak::WrongInput = tweak { - data[0] = Field::from(2); + data[0] = Field32::from(2); } let (share1_original, share2) = client_mem.encode_simple(&data).unwrap(); let decrypted_share1 = decrypt_share(&share1_original, &priv_key1_clone).unwrap(); - let mut share1_field = deserialize(&decrypted_share1); + let mut share1_field: Vec = deserialize(&decrypted_share1).unwrap(); let unpacked_share1 = unpack_proof_mut(&mut share1_field, dim).unwrap(); - let one = Field::from(1); + let one = Field32::from(1); match tweak { Tweak::DataPartOfShare => unpacked_share1.data[0] += one,