Skip to content

Commit

Permalink
Return Result from Server methods
Browse files Browse the repository at this point in the history
Various methods on `Server` used to return `Option<T>`, making it
possible for them to fail but not to indicate why. This means
`prio-server` can't tell the difference between a packet being rejected
due to a decryption error or because the decrypted payload was invalid,
causing frustrating situations as described in [1]. We now return
`Result<T, ServerError>` from `Server`'s methods.

In support of this, this commit also includes some other cleanups:

 - various structs and functions are moved from `util.rs` (which will
   hopefully go away altogether in the near future) into the new `proof`
   module
 - the functions `unpack_proof` and `unpack_proof_mut` now also return
   errors, drawn from the new `proof::ProofError` enum.
 - a number of `use module::*` style statements are rewritten to be more
   explicit.

[1] divviup/prio-server#550
  • Loading branch information
tgeoghegan committed Apr 7, 2021
1 parent d7e93ae commit 00db7ce
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 114 deletions.
11 changes: 7 additions & 4 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@

//! Prio client
use crate::encrypt::*;
use crate::field::FieldElement;
use crate::polynomial::*;
use crate::util::*;
use crate::{
encrypt::{encrypt_share, EncryptError, PublicKey},
field::FieldElement,
polynomial::{poly_fft, PolyAuxMemory},
proof::{proof_length, unpack_proof_mut},
util::{serialize, vector_with_length},
};

use std::convert::TryFrom;

Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@ pub mod field;
mod fp;
mod polynomial;
mod prng;
pub mod proof;
pub mod server;
pub mod util;
144 changes: 144 additions & 0 deletions src/proof.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
// SPDX-License-Identifier: MPL-2.0

//! Structs and utilities for working with proofs and proof shares
use crate::field::FieldElement;

/// Possible errors from proofs
#[derive(Debug, thiserror::Error)]
pub enum ProofError {
/// Input sizes do not match
#[error("input sizes do not match")]
InputSizeMismatch,
}

/// Returns the number of field elements in the proof for given dimension of
/// data elements
///
/// Proof is a vector, where the first `dimension` elements are the data
/// elements, the next 3 elements are the zero terms for polynomials f, g and h
/// and the remaining elements are non-zero points of h(x).
pub(crate) fn proof_length(dimension: usize) -> usize {
// number of data items + number of zero terms + N
dimension + 3 + (dimension + 1).next_power_of_two()
}

/// Unpacked proof with subcomponents
#[derive(Debug)]
pub(crate) struct UnpackedProof<'a, F: FieldElement> {
/// Data
pub data: &'a [F],
/// Zeroth coefficient of polynomial f
pub f0: &'a F,
/// Zeroth coefficient of polynomial g
pub g0: &'a F,
/// Zeroth coefficient of polynomial h
pub h0: &'a F,
/// Non-zero points of polynomial h
pub points_h_packed: &'a [F],
}

/// Unpacked proof with mutable subcomponents
#[derive(Debug)]
pub struct UnpackedProofMut<'a, F: FieldElement> {
/// Data
pub data: &'a mut [F],
/// Zeroth coefficient of polynomial f
pub f0: &'a mut F,
/// Zeroth coefficient of polynomial g
pub g0: &'a mut F,
/// Zeroth coefficient of polynomial h
pub h0: &'a mut F,
/// Non-zero points of polynomial h
pub points_h_packed: &'a mut [F],
}

/// Unpacks the proof vector into subcomponents
pub(crate) fn unpack_proof<F: FieldElement>(
proof: &[F],
dimension: usize,
) -> Result<UnpackedProof<F>, ProofError> {
// check the proof length
if proof.len() != proof_length(dimension) {
return Err(ProofError::InputSizeMismatch);
}
// split share into components
let (data, rest) = proof.split_at(dimension);
if let ([f0, g0, h0], points_h_packed) = rest.split_at(3) {
Ok(UnpackedProof {
data,
f0,
g0,
h0,
points_h_packed,
})
} else {
Err(ProofError::InputSizeMismatch)
}
}

/// Unpacks a mutable proof vector into mutable subcomponents
// TODO(timg): This is public because it is used by tests/tweaks.rs. We should
// refactor that test so it doesn't require the crate to expose this function or
// UnpackedProofMut.
pub fn unpack_proof_mut<F: FieldElement>(
proof: &mut [F],
dimension: usize,
) -> Result<UnpackedProofMut<F>, ProofError> {
// check the share length
if proof.len() != proof_length(dimension) {
return Err(ProofError::InputSizeMismatch);
}
// split share into components
let (data, rest) = proof.split_at_mut(dimension);
if let ([f0, g0, h0], points_h_packed) = rest.split_at_mut(3) {
Ok(UnpackedProofMut {
data,
f0,
g0,
h0,
points_h_packed,
})
} else {
Err(ProofError::InputSizeMismatch)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::field::{Field32, Field64};
use assert_matches::assert_matches;

#[test]
fn test_unpack_share_mut() {
let dim = 15;
let len = proof_length(dim);

let mut share = vec![Field32::from(0); len];
let unpacked = unpack_proof_mut(&mut share, dim).unwrap();
*unpacked.f0 = Field32::from(12);
assert_eq!(share[dim], 12);

let mut short_share = vec![Field32::from(0); len - 1];
assert_matches!(
unpack_proof_mut(&mut short_share, dim),
Err(ProofError::InputSizeMismatch)
);
}

#[test]
fn test_unpack_share() {
let dim = 15;
let len = proof_length(dim);

let share = vec![Field64::from(0); len];
unpack_proof(&share, dim).unwrap();

let short_share = vec![Field64::from(0); len - 1];
assert_matches!(
unpack_proof(&short_share, dim),
Err(ProofError::InputSizeMismatch)
);
}
}
15 changes: 9 additions & 6 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use crate::{
field::{merge_vector, FieldElement, FieldError},
polynomial::{poly_interpret_eval, PolyAuxMemory},
prng::extract_share_from_seed,
util::{deserialize, proof_length, unpack_proof, vector_with_length, SerializeError},
proof::{proof_length, unpack_proof, ProofError},
util::{deserialize, vector_with_length, SerializeError},
};

/// Possible errors from server operations
Expand All @@ -23,6 +24,9 @@ pub enum ServerError {
/// Serialization/deserialization error
#[error("serialization/deserialization error")]
Serialize(#[from] SerializeError),
/// Proof encoding/decoding error
#[error("proof operation error")]
Proof(#[from] ProofError),
}

/// Auxiliary memory for constructing a
Expand Down Expand Up @@ -97,8 +101,8 @@ impl<F: FieldElement> Server<F> {
&mut self,
eval_at: F,
share: &[u8],
) -> Option<VerificationMessage<F>> {
let share_field = self.deserialize_share(share).ok()?;
) -> Result<VerificationMessage<F>, ServerError> {
let share_field = self.deserialize_share(share)?;
generate_verification_message(
self.dimension,
eval_at,
Expand Down Expand Up @@ -182,7 +186,7 @@ pub fn generate_verification_message<F: FieldElement>(
proof: &[F],
is_first_server: bool,
mem: &mut ValidationMemory<F>,
) -> Option<VerificationMessage<F>> {
) -> Result<VerificationMessage<F>, ServerError> {
let unpacked = unpack_proof(proof, dimension)?;
let proof_length = 2 * (dimension + 1).next_power_of_two();

Expand Down Expand Up @@ -235,8 +239,7 @@ pub fn generate_verification_message<F: FieldElement>(
&mut mem.poly_mem.fft_memory,
);

let vm = VerificationMessage { f_r, g_r, h_r };
Some(vm)
Ok(VerificationMessage { f_r, g_r, h_r })
}

/// Decides if the distributed proof is valid
Expand Down
99 changes: 0 additions & 99 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,99 +17,11 @@ pub enum SerializeError {
Field(#[from] FieldError),
}

/// Returns the number of field elements in the proof for given dimension of
/// data elements
///
/// Proof is a vector, where the first `dimension` elements are the data
/// elements, the next 3 elements are the zero terms for polynomials f, g and h
/// and the remaining elements are non-zero points of h(x).
pub fn proof_length(dimension: usize) -> usize {
// number of data items + number of zero terms + N
dimension + 3 + (dimension + 1).next_power_of_two()
}

/// Convenience function for initializing fixed sized vectors of Field elements.
pub fn vector_with_length<F: FieldElement>(len: usize) -> Vec<F> {
vec![F::zero(); len]
}

/// Unpacked proof with subcomponents
pub struct UnpackedProof<'a, F: FieldElement> {
/// Data
pub data: &'a [F],
/// Zeroth coefficient of polynomial f
pub f0: &'a F,
/// Zeroth coefficient of polynomial g
pub g0: &'a F,
/// Zeroth coefficient of polynomial h
pub h0: &'a F,
/// Non-zero points of polynomial h
pub points_h_packed: &'a [F],
}

/// Unpacked proof with mutable subcomponents
pub struct UnpackedProofMut<'a, F: FieldElement> {
/// Data
pub data: &'a mut [F],
/// Zeroth coefficient of polynomial f
pub f0: &'a mut F,
/// Zeroth coefficient of polynomial g
pub g0: &'a mut F,
/// Zeroth coefficient of polynomial h
pub h0: &'a mut F,
/// Non-zero points of polynomial h
pub points_h_packed: &'a mut [F],
}

/// Unpacks the proof vector into subcomponents
pub fn unpack_proof<F: FieldElement>(proof: &[F], dimension: usize) -> Option<UnpackedProof<F>> {
// check the proof length
if proof.len() != proof_length(dimension) {
return None;
}
// split share into components
let (data, rest) = proof.split_at(dimension);
let (zero_terms, points_h_packed) = rest.split_at(3);
if let [f0, g0, h0] = zero_terms {
let unpacked = UnpackedProof {
data,
f0,
g0,
h0,
points_h_packed,
};
Some(unpacked)
} else {
None
}
}

/// Unpacks a mutable proof vector into mutable subcomponents
pub fn unpack_proof_mut<F: FieldElement>(
proof: &mut [F],
dimension: usize,
) -> Option<UnpackedProofMut<F>> {
// check the share length
if proof.len() != proof_length(dimension) {
return None;
}
// split share into components
let (data, rest) = proof.split_at_mut(dimension);
let (zero_terms, points_h_packed) = rest.split_at_mut(3);
if let [f0, g0, h0] = zero_terms {
let unpacked = UnpackedProofMut {
data,
f0,
g0,
h0,
points_h_packed,
};
Some(unpacked)
} else {
None
}
}

/// Get a byte array from a slice of field elements
pub fn serialize<F: FieldElement>(data: &[F]) -> Vec<u8> {
let mut vec = Vec::<u8>::with_capacity(data.len() * F::BYTES);
Expand Down Expand Up @@ -175,17 +87,6 @@ pub mod tests {
share2
}

#[test]
fn test_unpack_share() {
let dim = 15;
let len = proof_length(dim);

let mut share = vec![Field32::from(0); len];
let unpacked = unpack_proof_mut(&mut share, dim).unwrap();
*unpacked.f0 = Field32::from(12);
assert_eq!(share[dim], 12);
}

#[test]
fn secret_sharing() {
let mut share1 = vec![Field32::zero(); 10];
Expand Down
13 changes: 8 additions & 5 deletions tests/tweaks.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
// Copyright (c) 2020 Apple Inc.
// SPDX-License-Identifier: MPL-2.0

use prio::client::*;
use prio::encrypt::*;
use prio::field::Field32;
use prio::server::*;
use prio::util::*;
use prio::{
client::Client,
encrypt::{decrypt_share, encrypt_share, PrivateKey, PublicKey},
field::Field32,
proof::unpack_proof_mut,
server::Server,
util::{deserialize, serialize, vector_with_length},
};

#[derive(Debug, Clone, Copy, PartialEq)]
enum Tweak {
Expand Down

0 comments on commit 00db7ce

Please sign in to comment.