Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use any FieldElement and clean up field API #22

Merged
merged 1 commit into from
Apr 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions benches/fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@
use criterion::{criterion_group, criterion_main, Criterion};

use prio::benchmarked::{benchmarked_iterative_fft, benchmarked_recursive_fft};
use prio::finite_field::{Field, FieldElement};
use prio::field::{Field126, FieldElement};

pub fn fft(c: &mut Criterion) {
let test_sizes = [16, 256, 1024, 4096];
for size in test_sizes.iter() {
let mut rng = rand::thread_rng();
let mut inp = vec![Field::zero(); *size];
let mut outp = vec![Field::zero(); *size];
let mut inp = vec![Field126::zero(); *size];
let mut outp = vec![Field126::zero(); *size];
for i in 0..*size {
inp[i] = Field::rand(&mut rng);
inp[i] = Field126::rand();
}

c.bench_function(&format!("iterative/{}", *size), |b| {
Expand Down
12 changes: 6 additions & 6 deletions examples/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use prio::client::*;
use prio::encrypt::*;
use prio::finite_field::*;
use prio::field::*;
use prio::server::*;

fn main() {
Expand Down Expand Up @@ -30,20 +30,20 @@ fn main() {

let data1 = data1_u32
.iter()
.map(|x| Field::from(*x))
.collect::<Vec<Field>>();
.map(|x| Field32::from(*x))
.collect::<Vec<Field32>>();

let data2_u32 = [0, 0, 1, 0, 0, 0, 0, 0];
println!("Client 2 Input: {:?}", data2_u32);

let data2 = data2_u32
.iter()
.map(|x| Field::from(*x))
.collect::<Vec<Field>>();
.map(|x| Field32::from(*x))
.collect::<Vec<Field32>>();

let (share1_1, share1_2) = client1.encode_simple(&data1).unwrap();
let (share2_1, share2_2) = client2.encode_simple(&data2).unwrap();
let eval_at = Field::from(12313);
let eval_at = Field32::from(12313);

let mut server1 = Server::new(dim, true, priv_key1.clone());
let mut server2 = Server::new(dim, false, priv_key2.clone());
Expand Down
4 changes: 2 additions & 2 deletions src/benchmarked.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
//! benchmark, but which we don't want to expose in the public API.

use crate::fft::discrete_fourier_transform;
use crate::finite_field::{Field, FieldElement};
use crate::field::FieldElement;
use crate::polynomial::{poly_fft, PolyAuxMemory};

/// Sets `outp` to the Discrete Fourier Transform (DFT) using an iterative FFT algorithm.
Expand All @@ -13,7 +13,7 @@ pub fn benchmarked_iterative_fft<F: FieldElement>(outp: &mut [F], inp: &[F]) {
}

/// Sets `outp` to the Discrete Fourier Transform (DFT) using a recursive FFT algorithm.
pub fn benchmarked_recursive_fft(outp: &mut [Field], inp: &[Field]) {
pub fn benchmarked_recursive_fft<F: FieldElement>(outp: &mut [F], inp: &[F]) {
let mut mem = PolyAuxMemory::new(inp.len() / 2);
poly_fft(
outp,
Expand Down
105 changes: 66 additions & 39 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,61 @@
//! Prio client

use crate::encrypt::*;
use crate::finite_field::*;
use crate::field::FieldElement;
use crate::polynomial::*;
use crate::util::*;

use std::convert::TryFrom;

/// The main object that can be used to create Prio shares
///
/// Client is used to create Prio shares.
#[derive(Debug)]
pub struct Client {
pub struct Client<F: FieldElement> {
dimension: usize,
points_f: Vec<Field>,
points_g: Vec<Field>,
evals_f: Vec<Field>,
evals_g: Vec<Field>,
poly_mem: PolyAuxMemory,
points_f: Vec<F>,
points_g: Vec<F>,
evals_f: Vec<F>,
evals_g: Vec<F>,
poly_mem: PolyAuxMemory<F>,
public_key1: PublicKey,
public_key2: PublicKey,
}

impl Client {
/// Errors that might be emitted by the client.
#[derive(Debug, thiserror::Error)]
pub enum ClientError {
/// Thes error is output by `Client<F>::new()` if the length of the proof would exceed the
/// number of roots of unity that can be generated in the field.
#[error("input size exceeds field capacity")]
InputSizeExceedsFieldCapacity,
/// Thes error is output by `Client<F>::new()` if the length of the proof would exceed the
/// ssytem's addressible memory.
#[error("input size exceeds field capacity")]
InputSizeExceedsMemoryCapacity,
/// Encryption/decryption error
#[error("encryption/decryption error")]
Encrypt(#[from] EncryptError),
}

impl<F: FieldElement> Client<F> {
/// Construct a new Prio client
pub fn new(dimension: usize, public_key1: PublicKey, public_key2: PublicKey) -> Option<Self> {
pub fn new(
dimension: usize,
public_key1: PublicKey,
public_key2: PublicKey,
) -> Result<Self, ClientError> {
let n = (dimension + 1).next_power_of_two();

if 2 * n > Field::generator_order() as usize {
// too many elements for this field, not enough roots of unity
return None;
if let Ok(size) = F::Integer::try_from(2 * n) {
if size > F::generator_order() {
return Err(ClientError::InputSizeExceedsFieldCapacity);
}
} else {
return Err(ClientError::InputSizeExceedsMemoryCapacity);
}

Some(Client {
Ok(Client {
dimension,
points_f: vector_with_length(n),
points_g: vector_with_length(n),
Expand All @@ -46,20 +71,20 @@ impl Client {
}

/// Construct a pair of encrypted shares based on the input data.
pub fn encode_simple(&mut self, data: &[Field]) -> Result<(Vec<u8>, Vec<u8>), EncryptError> {
let copy_data = |share_data: &mut [Field]| {
pub fn encode_simple(&mut self, data: &[F]) -> Result<(Vec<u8>, Vec<u8>), ClientError> {
let copy_data = |share_data: &mut [F]| {
share_data[..].clone_from_slice(data);
};
self.encode_with(copy_data)
Ok(self.encode_with(copy_data)?)
}

/// Construct a pair of encrypted shares using a initilization function.
///
/// This might be slightly more efficient on large vectors, because one can
/// avoid copying the input data.
pub fn encode_with<F>(&mut self, init_function: F) -> Result<(Vec<u8>, Vec<u8>), EncryptError>
pub fn encode_with<G>(&mut self, init_function: G) -> Result<(Vec<u8>, Vec<u8>), EncryptError>
where
F: FnOnce(&mut [Field]),
G: FnOnce(&mut [F]),
{
let mut proof = vector_with_length(proof_length(self.dimension));
// unpack one long vector to different subparts
Expand Down Expand Up @@ -90,21 +115,21 @@ impl Client {

/// Convenience function if one does not want to reuse
/// [`Client`](struct.Client.html).
pub fn encode_simple(
data: &[Field],
pub fn encode_simple<F: FieldElement>(
data: &[F],
public_key1: PublicKey,
public_key2: PublicKey,
) -> Option<(Vec<u8>, Vec<u8>)> {
) -> Result<(Vec<u8>, Vec<u8>), ClientError> {
let dimension = data.len();
let mut client_memory = Client::new(dimension, public_key1, public_key2)?;
client_memory.encode_simple(data).ok()
client_memory.encode_simple(data)
}

fn interpolate_and_evaluate_at_2n(
fn interpolate_and_evaluate_at_2n<F: FieldElement>(
n: usize,
points_in: &[Field],
evals_out: &mut [Field],
mem: &mut PolyAuxMemory,
points_in: &[F],
evals_out: &mut [F],
mem: &mut PolyAuxMemory<F>,
) {
// interpolate through roots of unity
poly_fft(
Expand All @@ -130,20 +155,20 @@ fn interpolate_and_evaluate_at_2n(
///
/// Based on Theorem 2.3.3 from Henry Corrigan-Gibbs' dissertation
/// This constructs the output \pi by doing the necessesary calculations
fn construct_proof(
data: &[Field],
fn construct_proof<F: FieldElement>(
data: &[F],
dimension: usize,
f0: &mut Field,
g0: &mut Field,
h0: &mut Field,
points_h_packed: &mut [Field],
mem: &mut Client,
f0: &mut F,
g0: &mut F,
h0: &mut F,
points_h_packed: &mut [F],
mem: &mut Client<F>,
) {
let n = (dimension + 1).next_power_of_two();

// set zero terms to random
*f0 = Field::from(rand::random::<u32>());
*g0 = Field::from(rand::random::<u32>());
*f0 = F::rand();
*g0 = F::rand();
mem.points_f[0] = *f0;
mem.points_g[0] = *g0;

Expand All @@ -154,7 +179,7 @@ fn construct_proof(
// set g_i = f_i - 1
for i in 0..dimension {
mem.points_f[i + 1] = data[i];
mem.points_g[i + 1] = data[i] - 1.into();
mem.points_g[i + 1] = data[i] - F::one();
}

// interpolate and evaluate at roots of unity
Expand All @@ -174,6 +199,8 @@ fn construct_proof(

#[test]
fn test_encode() {
use crate::field::Field32;

let pub_key1 = PublicKey::from_base64(
"BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgNt9HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVQ=",
)
Expand All @@ -186,8 +213,8 @@ fn test_encode() {
let data_u32 = [0u32, 1, 0, 1, 1, 0, 0, 0, 1];
let data = data_u32
.iter()
.map(|x| Field::from(*x))
.collect::<Vec<Field>>();
.map(|x| Field32::from(*x))
.collect::<Vec<Field32>>();
let encoded_shares = encode_simple(&data, pub_key1, pub_key2);
assert_eq!(encoded_shares.is_some(), true);
assert_eq!(encoded_shares.is_ok(), true);
}
22 changes: 10 additions & 12 deletions src/fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//! This module implements an iterative FFT algorithm for computing the (inverse) Discrete Fourier
//! Transform (DFT) over a slice of field elements.

use crate::finite_field::FieldElement;
use crate::field::FieldElement;
use crate::fp::{log2, MAX_ROOTS};

use std::convert::TryFrom;
Expand Down Expand Up @@ -48,7 +48,7 @@ pub fn discrete_fourier_transform<F: FieldElement>(

let mut w: F;
for l in 1..d + 1 {
w = F::root(0).unwrap(); // one
w = F::one();
let r = F::root(l).unwrap();
let y = 1 << (l - 1);
for i in 0..y {
Expand Down Expand Up @@ -100,19 +100,18 @@ fn bitrev(d: usize, x: usize) -> usize {
#[cfg(test)]
mod tests {
use super::*;
use crate::finite_field::{Field, Field126, Field64, Field80};
use crate::field::{Field126, Field32, Field64, Field80};
use crate::polynomial::{poly_fft, PolyAuxMemory};

fn discrete_fourier_transform_then_inv_test<F: FieldElement>() -> Result<(), FftError> {
let mut rng = rand::thread_rng();
let test_sizes = [1, 2, 4, 8, 16, 256, 1024, 2048];

for size in test_sizes.iter() {
let mut want = vec![F::zero(); *size];
let mut tmp = vec![F::zero(); *size];
let mut got = vec![F::zero(); *size];
for i in 0..*size {
want[i] = F::rand(&mut rng);
want[i] = F::rand();
}

discrete_fourier_transform(&mut tmp, &want)?;
Expand All @@ -125,7 +124,7 @@ mod tests {

#[test]
fn test_field32() {
discrete_fourier_transform_then_inv_test::<Field>().expect("unexpected error");
discrete_fourier_transform_then_inv_test::<Field32>().expect("unexpected error");
}

#[test]
Expand All @@ -146,17 +145,16 @@ mod tests {
#[test]
fn test_recursive_fft() {
let size = 128;
let mut rng = rand::thread_rng();
let mut mem = PolyAuxMemory::new(size / 2);

let mut inp = vec![Field::zero(); size];
let mut want = vec![Field::zero(); size];
let mut got = vec![Field::zero(); size];
let mut inp = vec![Field32::zero(); size];
let mut want = vec![Field32::zero(); size];
let mut got = vec![Field32::zero(); size];
for i in 0..size {
inp[i] = Field::rand(&mut rng);
inp[i] = Field32::rand();
}

discrete_fourier_transform::<Field>(&mut want, &inp).expect("unexpected error");
discrete_fourier_transform::<Field32>(&mut want, &inp).expect("unexpected error");

poly_fft(
&mut got,
Expand Down
Loading