Skip to content

Commit

Permalink
Merge pull request #13 from abetterinternet/timg/accumulate-vector
Browse files Browse the repository at this point in the history
add accumulate_vector utility function
  • Loading branch information
tgeoghegan authored Mar 16, 2021
2 parents 48e2c9e + cd72ec7 commit e0247d0
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 36 deletions.
7 changes: 5 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "prio"
version = "0.2.0"
version = "0.3.0"
authors = ["Josh Aas <jaas@kflag.net>", "Karl Tarbe <tarbe@apple.com>"]
edition = "2018"
description = "Implementation of the Prio aggregation system core: https://crypto.stanford.edu/prio/"
Expand All @@ -17,5 +17,8 @@ rand = "0.7"
ring = "0.16.15"
thiserror = "1.0"

[dev-dependencies]
assert_matches = "1.5.0"

[[example]]
name = "sum"
name = "sum"
2 changes: 1 addition & 1 deletion examples/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,6 @@ fn main() {
let _ = server1.aggregate(&share2_1, &v2_1, &v2_2).unwrap();
let _ = server2.aggregate(&share2_2, &v2_1, &v2_2).unwrap();

server1.merge_total_shares(server2.total_shares());
server1.merge_total_shares(server2.total_shares()).unwrap();
println!("Final Publication: {:?}", server1.total_shares());
}
26 changes: 13 additions & 13 deletions src/encrypt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,19 @@ const KEY_LENGTH: usize = 16;
pub enum EncryptError {
/// Base64 decoding error
#[error("base64 decoding error")]
DecodeBase64Error(#[from] base64::DecodeError),
DecodeBase64(#[from] base64::DecodeError),
/// Error in ECDH
#[error("error in ECDH")]
KeyAgreementError,
KeyAgreement,
/// Buffer for ciphertext was not large enough
#[error("buffer for ciphertext was not large enough")]
EncryptionError,
Encryption,
/// Authentication tags did not match.
#[error("authentication tags did not match")]
DecryptionError,
Decryption,
/// Input ciphertext was too small
#[error("input ciphertext was too small")]
DecryptionLengthError,
DecryptionLength,
}

/// NIST P-256, public key in X9.62 uncompressed format
Expand Down Expand Up @@ -76,16 +76,16 @@ impl PrivateKey {
pub fn encrypt_share(share: &[u8], key: &PublicKey) -> Result<Vec<u8>, EncryptError> {
let rng = ring::rand::SystemRandom::new();
let ephemeral_priv = agreement::EphemeralPrivateKey::generate(&agreement::ECDH_P256, &rng)
.map_err(|_| EncryptError::KeyAgreementError)?;
.map_err(|_| EncryptError::KeyAgreement)?;
let peer_public = agreement::UnparsedPublicKey::new(&agreement::ECDH_P256, &key.0);
let ephemeral_pub = ephemeral_priv
.compute_public_key()
.map_err(|_| EncryptError::KeyAgreementError)?;
.map_err(|_| EncryptError::KeyAgreement)?;

let symmetric_key_bytes = agreement::agree_ephemeral(
ephemeral_priv,
&peer_public,
EncryptError::KeyAgreementError,
EncryptError::KeyAgreement,
|material| Ok(x963_kdf(material, ephemeral_pub.as_ref())),
)?;

Expand All @@ -109,7 +109,7 @@ pub fn encrypt_share(share: &[u8], key: &PublicKey) -> Result<Vec<u8>, EncryptEr
/// symmetic encryption and MAC.
pub fn decrypt_share(share: &[u8], key: &PrivateKey) -> Result<Vec<u8>, EncryptError> {
if share.len() < PUBLICKEY_LENGTH + TAG_LENGTH {
return Err(EncryptError::DecryptionLengthError);
return Err(EncryptError::DecryptionLength);
}
let empheral_pub_bytes: &[u8] = &share[0..PUBLICKEY_LENGTH];

Expand All @@ -122,12 +122,12 @@ pub fn decrypt_share(share: &[u8], key: &PrivateKey) -> Result<Vec<u8>, EncryptE
};

let private_key = agreement::EphemeralPrivateKey::generate(&agreement::ECDH_P256, &fake_rng)
.map_err(|_| EncryptError::KeyAgreementError)?;
.map_err(|_| EncryptError::KeyAgreement)?;

let symmetric_key_bytes = agreement::agree_ephemeral(
private_key,
&ephemeral_pub,
EncryptError::KeyAgreementError,
EncryptError::KeyAgreement,
|material| Ok(x963_kdf(material, empheral_pub_bytes)),
)?;

Expand Down Expand Up @@ -155,15 +155,15 @@ fn decrypt_aes_gcm(key: &[u8], nonce: &[u8], mut data: Vec<u8>) -> Result<Vec<u8
let cipher = Aes128::new(GenericArray::from_slice(key));
cipher
.decrypt_in_place(GenericArray::from_slice(nonce), &[], &mut data)
.map_err(|_| EncryptError::DecryptionError)?;
.map_err(|_| EncryptError::Decryption)?;
Ok(data)
}

fn encrypt_aes_gcm(key: &[u8], nonce: &[u8], mut data: Vec<u8>) -> Result<Vec<u8>, EncryptError> {
let cipher = Aes128::new(GenericArray::from_slice(key));
cipher
.encrypt_in_place(GenericArray::from_slice(nonce), &[], &mut data)
.map_err(|_| EncryptError::EncryptionError)?;
.map_err(|_| EncryptError::Encryption)?;
Ok(data)
}

Expand Down
51 changes: 51 additions & 0 deletions src/finite_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@

//! Finite field arithmetic over a prime field using a 32bit prime.

/// Possible errors from finite field operations.
#[derive(Debug, thiserror::Error)]
pub enum FiniteFieldError {
/// Input sizes do not match
#[error("input sizes do not match")]
InputSizeMismatch,
}

/// Newtype wrapper over u32
///
/// Implements the arithmetic over the finite prime field
Expand Down Expand Up @@ -215,3 +223,46 @@ fn test_arithmetic() {
assert_eq!(Field(432).pow(0.into()), 1);
assert_eq!(Field(0).pow(123.into()), 0);
}

/// Merge two vectors of fields by summing other_vector into accumulator.
///
/// # Errors
///
/// Fails if the two vectors do not have the same length.
pub fn merge_vector(
accumulator: &mut [Field],
other_vector: &[Field],
) -> Result<(), FiniteFieldError> {
if accumulator.len() != other_vector.len() {
return Err(FiniteFieldError::InputSizeMismatch);
}
for (a, o) in accumulator.iter_mut().zip(other_vector.iter()) {
*a += *o;
}

Ok(())
}

#[cfg(test)]
mod tests {
use super::*;
use crate::util::vector_with_length;
use assert_matches::assert_matches;

#[test]
fn test_accumulate() {
let mut lhs = vector_with_length(10);
lhs.iter_mut().for_each(|f| *f = Field(1));
let mut rhs = vector_with_length(10);
rhs.iter_mut().for_each(|f| *f = Field(2));

merge_vector(&mut lhs, &rhs).unwrap();

lhs.iter().for_each(|f| assert_eq!(*f, Field(3)));
rhs.iter().for_each(|f| assert_eq!(*f, Field(2)));

let wrong_len = vector_with_length(9);
let result = merge_vector(&mut lhs, &wrong_len);
assert_matches!(result, Err(FiniteFieldError::InputSizeMismatch));
}
}
56 changes: 36 additions & 20 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,24 @@

//! Prio server

use crate::encrypt::*;
use crate::finite_field::*;
use crate::polynomial::*;
use crate::prng;
use crate::util;
use crate::util::*;
use crate::{
encrypt::{decrypt_share, EncryptError, PrivateKey},
finite_field::{merge_vector, Field, FiniteFieldError},
polynomial::{poly_interpret_eval, PolyAuxMemory},
prng::extract_share_from_seed,
util::{deserialize, proof_length, unpack_proof, vector_with_length},
};

/// Possible errors from server operations
#[derive(Debug, thiserror::Error)]
pub enum ServerError {
/// Encryption/decryption error
#[error("encryption/decryption error")]
Encrypt(#[from] EncryptError),
/// Finite field operation error
#[error("finite field operation error")]
FiniteField(#[from] FiniteFieldError),
}

/// Auxiliary memory for constructing a
/// [`VerificationMessage`](struct.VerificationMessage.html)
Expand Down Expand Up @@ -62,13 +74,13 @@ impl Server {
}

/// Decrypt and deserialize
fn deserialize_share(&self, encrypted_share: &[u8]) -> Result<Vec<Field>, EncryptError> {
fn deserialize_share(&self, encrypted_share: &[u8]) -> Result<Vec<Field>, ServerError> {
let share = decrypt_share(encrypted_share, &self.private_key)?;
Ok(if self.is_first_server {
util::deserialize(&share)
deserialize(&share)
} else {
let len = util::proof_length(self.dimension);
prng::extract_share_from_seed(len, &share)
let len = proof_length(self.dimension);
extract_share_from_seed(len, &share)
})
}

Expand Down Expand Up @@ -102,14 +114,14 @@ impl Server {
share: &[u8],
v1: &VerificationMessage,
v2: &VerificationMessage,
) -> Result<bool, EncryptError> {
) -> Result<bool, ServerError> {
let share_field = self.deserialize_share(share)?;
let is_valid = is_valid_share(v1, v2);
if is_valid {
// add to the accumulator
for (a, s) in self.accumulator.iter_mut().zip(share_field.iter()) {
*a += *s;
}
// Add to the accumulator. share_field also includes the proof
// encoding, so we slice off the first dimension fields, which are
// the actual data share.
merge_vector(&mut self.accumulator, &share_field[..self.dimension])?;
}

Ok(is_valid)
Expand All @@ -125,11 +137,14 @@ impl Server {

/// Merge shares from another server.
///
/// This modifies the current accumulator
pub fn merge_total_shares(&mut self, other_total_shares: &[Field]) {
for (a, o) in self.accumulator.iter_mut().zip(other_total_shares.iter()) {
*a += *o;
}
/// This modifies the current accumulator.
///
/// # Errors
///
/// Returns an error if `other_total_shares.len()` is not equal to this
//// server's `dimension`.
pub fn merge_total_shares(&mut self, other_total_shares: &[Field]) -> Result<(), ServerError> {
Ok(merge_vector(&mut self.accumulator, other_total_shares)?)
}

/// Choose a random point for polynomial evaluation
Expand Down Expand Up @@ -234,6 +249,7 @@ pub fn is_valid_share(v1: &VerificationMessage, v2: &VerificationMessage) -> boo
#[cfg(test)]
mod tests {
use super::*;
use crate::util;

#[test]
fn test_validation() {
Expand Down

0 comments on commit e0247d0

Please sign in to comment.