Skip to content

Commit

Permalink
Add VDAF execution to "test-util" (#905)
Browse files Browse the repository at this point in the history
  • Loading branch information
cjpatton authored Jan 16, 2024
1 parent 105f1c5 commit 0310ca4
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 143 deletions.
290 changes: 150 additions & 140 deletions src/vdaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -460,161 +460,171 @@ impl<F: FieldElement> Encode for AggregateShare<F> {
}
}

#[cfg(test)]
pub(crate) fn run_vdaf<V, M, const SEED_SIZE: usize>(
vdaf: &V,
agg_param: &V::AggregationParam,
measurements: M,
) -> Result<V::AggregateResult, VdafError>
where
V: Client<16> + Aggregator<SEED_SIZE, 16> + Collector,
M: IntoIterator<Item = V::Measurement>,
{
use rand::prelude::*;
let mut rng = thread_rng();
let mut verify_key = [0; SEED_SIZE];
rng.fill(&mut verify_key[..]);

let mut agg_shares: Vec<Option<V::AggregateShare>> = vec![None; vdaf.num_aggregators()];
let mut num_measurements: usize = 0;
for measurement in measurements.into_iter() {
num_measurements += 1;
let nonce = rng.gen();
let (public_share, input_shares) = vdaf.shard(&measurement, &nonce)?;
let out_shares = run_vdaf_prepare(
vdaf,
&verify_key,
agg_param,
&nonce,
public_share,
input_shares,
)?;
for (out_share, agg_share) in out_shares.into_iter().zip(agg_shares.iter_mut()) {
// Check serialization of output shares
let encoded_out_share = out_share.get_encoded().unwrap();
let round_trip_out_share =
V::OutputShare::get_decoded_with_param(&(vdaf, agg_param), &encoded_out_share)
.unwrap();
assert_eq!(
round_trip_out_share.get_encoded().unwrap(),
encoded_out_share
);
/// Utilities for testing VDAFs.
#[cfg(feature = "test-util")]
#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
pub mod test_utils {
use crate::codec::{Encode, ParameterizedDecode};

use super::{Aggregatable, Aggregator, Client, Collector, PrepareTransition, VdafError};

/// Execute the VDAF end-to-end and return the aggregate result.
pub fn run_vdaf<V, M, const SEED_SIZE: usize>(
vdaf: &V,
agg_param: &V::AggregationParam,
measurements: M,
) -> Result<V::AggregateResult, VdafError>
where
V: Client<16> + Aggregator<SEED_SIZE, 16> + Collector,
M: IntoIterator<Item = V::Measurement>,
{
use rand::prelude::*;
let mut rng = thread_rng();
let mut verify_key = [0; SEED_SIZE];
rng.fill(&mut verify_key[..]);

let mut agg_shares: Vec<Option<V::AggregateShare>> = vec![None; vdaf.num_aggregators()];
let mut num_measurements: usize = 0;
for measurement in measurements.into_iter() {
num_measurements += 1;
let nonce = rng.gen();
let (public_share, input_shares) = vdaf.shard(&measurement, &nonce)?;
let out_shares = run_vdaf_prepare(
vdaf,
&verify_key,
agg_param,
&nonce,
public_share,
input_shares,
)?;
for (out_share, agg_share) in out_shares.into_iter().zip(agg_shares.iter_mut()) {
// Check serialization of output shares
let encoded_out_share = out_share.get_encoded().unwrap();
let round_trip_out_share =
V::OutputShare::get_decoded_with_param(&(vdaf, agg_param), &encoded_out_share)
.unwrap();
assert_eq!(
round_trip_out_share.get_encoded().unwrap(),
encoded_out_share
);

let this_agg_share = V::AggregateShare::from(out_share);
if let Some(ref mut inner) = agg_share {
inner.merge(&this_agg_share)?;
} else {
*agg_share = Some(this_agg_share);
let this_agg_share = V::AggregateShare::from(out_share);
if let Some(ref mut inner) = agg_share {
inner.merge(&this_agg_share)?;
} else {
*agg_share = Some(this_agg_share);
}
}
}
}

for agg_share in agg_shares.iter() {
// Check serialization of aggregate shares
let encoded_agg_share = agg_share.as_ref().unwrap().get_encoded().unwrap();
let round_trip_agg_share =
V::AggregateShare::get_decoded_with_param(&(vdaf, agg_param), &encoded_agg_share)
.unwrap();
assert_eq!(
round_trip_agg_share.get_encoded().unwrap(),
encoded_agg_share
);
}

let res = vdaf.unshard(
agg_param,
agg_shares.into_iter().map(|option| option.unwrap()),
num_measurements,
)?;
Ok(res)
}
for agg_share in agg_shares.iter() {
// Check serialization of aggregate shares
let encoded_agg_share = agg_share.as_ref().unwrap().get_encoded().unwrap();
let round_trip_agg_share =
V::AggregateShare::get_decoded_with_param(&(vdaf, agg_param), &encoded_agg_share)
.unwrap();
assert_eq!(
round_trip_agg_share.get_encoded().unwrap(),
encoded_agg_share
);
}

#[cfg(test)]
pub(crate) fn run_vdaf_prepare<V, M, const SEED_SIZE: usize>(
vdaf: &V,
verify_key: &[u8; SEED_SIZE],
agg_param: &V::AggregationParam,
nonce: &[u8; 16],
public_share: V::PublicShare,
input_shares: M,
) -> Result<Vec<V::OutputShare>, VdafError>
where
V: Client<16> + Aggregator<SEED_SIZE, 16> + Collector,
M: IntoIterator<Item = V::InputShare>,
{
let public_share =
V::PublicShare::get_decoded_with_param(vdaf, &public_share.get_encoded().unwrap()).unwrap();
let input_shares = input_shares
.into_iter()
.map(|input_share| input_share.get_encoded().unwrap());

let mut states = Vec::new();
let mut outbound = Vec::new();
for (agg_id, input_share) in input_shares.enumerate() {
let (state, msg) = vdaf.prepare_init(
verify_key,
agg_id,
let res = vdaf.unshard(
agg_param,
nonce,
&public_share,
&V::InputShare::get_decoded_with_param(&(vdaf, agg_id), &input_share)
.expect("failed to decode input share"),
agg_shares.into_iter().map(|option| option.unwrap()),
num_measurements,
)?;
states.push(state);
outbound.push(msg.get_encoded().unwrap());
Ok(res)
}

let mut inbound = vdaf
.prepare_shares_to_prepare_message(
agg_param,
outbound.iter().map(|encoded| {
V::PrepareShare::get_decoded_with_param(&states[0], encoded)
.expect("failed to decode prep share")
}),
)?
.get_encoded()
.unwrap();

let mut out_shares = Vec::new();
loop {
/// Execute VDAF preparation for a single report and return the recovered output shares.
pub fn run_vdaf_prepare<V, M, const SEED_SIZE: usize>(
vdaf: &V,
verify_key: &[u8; SEED_SIZE],
agg_param: &V::AggregationParam,
nonce: &[u8; 16],
public_share: V::PublicShare,
input_shares: M,
) -> Result<Vec<V::OutputShare>, VdafError>
where
V: Client<16> + Aggregator<SEED_SIZE, 16> + Collector,
M: IntoIterator<Item = V::InputShare>,
{
let public_share =
V::PublicShare::get_decoded_with_param(vdaf, &public_share.get_encoded().unwrap())
.unwrap();
let input_shares = input_shares
.into_iter()
.map(|input_share| input_share.get_encoded().unwrap());

let mut states = Vec::new();
let mut outbound = Vec::new();
for state in states.iter_mut() {
match vdaf.prepare_next(
state.clone(),
V::PrepareMessage::get_decoded_with_param(state, &inbound)
.expect("failed to decode prep message"),
)? {
PrepareTransition::Continue(new_state, msg) => {
outbound.push(msg.get_encoded().unwrap());
*state = new_state
}
PrepareTransition::Finish(out_share) => {
out_shares.push(out_share);
for (agg_id, input_share) in input_shares.enumerate() {
let (state, msg) = vdaf.prepare_init(
verify_key,
agg_id,
agg_param,
nonce,
&public_share,
&V::InputShare::get_decoded_with_param(&(vdaf, agg_id), &input_share)
.expect("failed to decode input share"),
)?;
states.push(state);
outbound.push(msg.get_encoded().unwrap());
}

let mut inbound = vdaf
.prepare_shares_to_prepare_message(
agg_param,
outbound.iter().map(|encoded| {
V::PrepareShare::get_decoded_with_param(&states[0], encoded)
.expect("failed to decode prep share")
}),
)?
.get_encoded()
.unwrap();

let mut out_shares = Vec::new();
loop {
let mut outbound = Vec::new();
for state in states.iter_mut() {
match vdaf.prepare_next(
state.clone(),
V::PrepareMessage::get_decoded_with_param(state, &inbound)
.expect("failed to decode prep message"),
)? {
PrepareTransition::Continue(new_state, msg) => {
outbound.push(msg.get_encoded().unwrap());
*state = new_state
}
PrepareTransition::Finish(out_share) => {
out_shares.push(out_share);
}
}
}
}

if outbound.len() == vdaf.num_aggregators() {
// Another round is required before output shares are computed.
inbound = vdaf
.prepare_shares_to_prepare_message(
agg_param,
outbound.iter().map(|encoded| {
V::PrepareShare::get_decoded_with_param(&states[0], encoded)
.expect("failed to decode prep share")
}),
)?
.get_encoded()
.unwrap();
} else if outbound.is_empty() {
// Each Aggregator recovered an output share.
break;
} else {
panic!("Aggregators did not finish the prepare phase at the same time");
if outbound.len() == vdaf.num_aggregators() {
// Another round is required before output shares are computed.
inbound = vdaf
.prepare_shares_to_prepare_message(
agg_param,
outbound.iter().map(|encoded| {
V::PrepareShare::get_decoded_with_param(&states[0], encoded)
.expect("failed to decode prep share")
}),
)?
.get_encoded()
.unwrap();
} else if outbound.is_empty() {
// Each Aggregator recovered an output share.
break;
} else {
panic!("Aggregators did not finish the prepare phase at the same time");
}
}
}

Ok(out_shares)
Ok(out_shares)
}
}

#[cfg(test)]
Expand Down
2 changes: 1 addition & 1 deletion src/vdaf/poplar1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1500,7 +1500,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::vdaf::{equality_comparison_test, run_vdaf_prepare};
use crate::vdaf::{equality_comparison_test, test_utils::run_vdaf_prepare};
use assert_matches::assert_matches;
use rand::prelude::*;
use serde::Deserialize;
Expand Down
2 changes: 1 addition & 1 deletion src/vdaf/prio2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ mod tests {
use super::*;
use crate::vdaf::{
equality_comparison_test, fieldvec_roundtrip_test, prio2::test_vector::Priov2TestVector,
run_vdaf,
test_utils::run_vdaf,
};
use assert_matches::assert_matches;
use rand::prelude::*;
Expand Down
3 changes: 2 additions & 1 deletion src/vdaf/prio3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1578,7 +1578,8 @@ mod tests {
#[cfg(feature = "experimental")]
use crate::flp::gadgets::ParallelSumGadget;
use crate::vdaf::{
equality_comparison_test, fieldvec_roundtrip_test, run_vdaf, run_vdaf_prepare,
equality_comparison_test, fieldvec_roundtrip_test,
test_utils::{run_vdaf, run_vdaf_prepare},
};
use assert_matches::assert_matches;
#[cfg(feature = "experimental")]
Expand Down

0 comments on commit 0310ca4

Please sign in to comment.