pine: Use VDAF test driver from the prio crate
cjpatton committed Jul 3, 2024
1 parent 08c04f7 commit fbdde3b
Showing 1 changed file with 7 additions and 172 deletions.
179 changes: 7 additions & 172 deletions crates/daphne/src/pine/vdaf.rs
@@ -577,188 +577,23 @@ impl<F: FftFriendlyFieldElement, const PROOFS: u8> Collector for Pine<F, PROOFS>
}
}

// TODO Use upstream version:
// https://github.com/cloudflare/daphne/issues/437
// https://github.com/divviup/libprio-rs/pull/905
#[cfg(test)]
mod test_util {
use prio::codec::{Encode, ParameterizedDecode};

use super::{Aggregatable, Aggregator, Client, Collector, PrepareTransition, VdafError};

/// Execute the VDAF end-to-end and return the aggregate result.
pub fn run_vdaf<V, M, const SEED_SIZE: usize>(
vdaf: &V,
agg_param: &V::AggregationParam,
measurements: M,
) -> Result<V::AggregateResult, VdafError>
where
V: Client<16> + Aggregator<SEED_SIZE, 16> + Collector,
M: IntoIterator<Item = V::Measurement>,
{
use rand::prelude::*;
let mut rng = thread_rng();
let mut verify_key = [0; SEED_SIZE];
rng.fill(&mut verify_key[..]);

let mut agg_shares: Vec<Option<V::AggregateShare>> = vec![None; vdaf.num_aggregators()];
let mut num_measurements: usize = 0;
for measurement in measurements {
num_measurements += 1;
let nonce = rng.gen();
let (public_share, input_shares) = vdaf.shard(&measurement, &nonce)?;
let out_shares = run_vdaf_prepare(
vdaf,
&verify_key,
agg_param,
&nonce,
public_share,
input_shares,
)?;
for (out_share, agg_share) in out_shares.into_iter().zip(agg_shares.iter_mut()) {
// Check serialization of output shares
let encoded_out_share = out_share.get_encoded().unwrap();
let round_trip_out_share =
V::OutputShare::get_decoded_with_param(&(vdaf, agg_param), &encoded_out_share)
.unwrap();
assert_eq!(
round_trip_out_share.get_encoded().unwrap(),
encoded_out_share
);

let this_agg_share = V::AggregateShare::from(out_share);
if let Some(ref mut inner) = agg_share {
inner.merge(&this_agg_share)?;
} else {
*agg_share = Some(this_agg_share);
}
}
}

for agg_share in &agg_shares {
// Check serialization of aggregate shares
let encoded_agg_share = agg_share.as_ref().unwrap().get_encoded().unwrap();
let round_trip_agg_share =
V::AggregateShare::get_decoded_with_param(&(vdaf, agg_param), &encoded_agg_share)
.unwrap();
assert_eq!(
round_trip_agg_share.get_encoded().unwrap(),
encoded_agg_share
);
}

let res = vdaf.unshard(
agg_param,
agg_shares.into_iter().map(|option| option.unwrap()),
num_measurements,
)?;
Ok(res)
}

/// Execute VDAF preparation for a single report and return the recovered output shares.
pub fn run_vdaf_prepare<V, M, const SEED_SIZE: usize>(
vdaf: &V,
verify_key: &[u8; SEED_SIZE],
agg_param: &V::AggregationParam,
nonce: &[u8; 16],
public_share: V::PublicShare,
input_shares: M,
) -> Result<Vec<V::OutputShare>, VdafError>
where
V: Client<16> + Aggregator<SEED_SIZE, 16> + Collector,
M: IntoIterator<Item = V::InputShare>,
{
let public_share =
V::PublicShare::get_decoded_with_param(vdaf, &public_share.get_encoded().unwrap())
.unwrap();
let input_shares = input_shares
.into_iter()
.map(|input_share| input_share.get_encoded().unwrap());

let mut states = Vec::new();
let mut outbound = Vec::new();
for (agg_id, input_share) in input_shares.enumerate() {
let (state, msg) = vdaf.prepare_init(
verify_key,
agg_id,
agg_param,
nonce,
&public_share,
&V::InputShare::get_decoded_with_param(&(vdaf, agg_id), &input_share)
.expect("failed to decode input share"),
)?;
states.push(state);
outbound.push(msg.get_encoded().unwrap());
}

let mut inbound = vdaf
.prepare_shares_to_prepare_message(
agg_param,
outbound.iter().map(|encoded| {
V::PrepareShare::get_decoded_with_param(&states[0], encoded)
.expect("failed to decode prep share")
}),
)?
.get_encoded()
.unwrap();

let mut out_shares = Vec::new();
loop {
let mut outbound = Vec::new();
for state in &mut states {
match vdaf.prepare_next(
state.clone(),
V::PrepareMessage::get_decoded_with_param(state, &inbound)
.expect("failed to decode prep message"),
)? {
PrepareTransition::Continue(new_state, msg) => {
outbound.push(msg.get_encoded().unwrap());
*state = new_state;
}
PrepareTransition::Finish(out_share) => {
out_shares.push(out_share);
}
}
}

if outbound.len() == vdaf.num_aggregators() {
// Another round is required before output shares are computed.
inbound = vdaf
.prepare_shares_to_prepare_message(
agg_param,
outbound.iter().map(|encoded| {
V::PrepareShare::get_decoded_with_param(&states[0], encoded)
.expect("failed to decode prep share")
}),
)?
.get_encoded()
.unwrap();
} else if outbound.is_empty() {
// Each Aggregator recovered an output share.
break;
} else {
panic!("Aggregators did not finish the prepare phase at the same time");
}
}

Ok(out_shares)
}
}

#[cfg(test)]
mod tests {

use prio::{
codec::Decode,
field::Field128,
vdaf::{Aggregator, Client, Collector},
vdaf::{
test_utils::{run_vdaf, run_vdaf_prepare},
xof::Seed,
Aggregator, Client, Collector,
},
};

use crate::pine::{msg, Pine128, Pine64};
use crate::pine::{msg, vdaf::PineVec, Pine128, Pine64};

use assert_matches::assert_matches;

use super::{test_util::*, *};

#[test]
fn run_128() {
let dimension = 100;
…
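For reference, the sketch below illustrates how tests can drive a VDAF through the upstream helpers imported in this diff. It is a hedged illustration, not code from the commit: the module and function names (upstream_driver_sketch, end_to_end, prepare_one) are invented here, and the signatures of prio::vdaf::test_utils::{run_vdaf, run_vdaf_prepare} are assumed to mirror the local test_util copy deleted above (and to be available behind prio's test-util feature).

// A hedged sketch, not part of this commit: thin wrappers around the upstream
// test drivers. The calls below assume `run_vdaf` and `run_vdaf_prepare` keep
// the shape of the local copies deleted above.
#[cfg(test)]
mod upstream_driver_sketch {
    use prio::vdaf::{
        test_utils::{run_vdaf, run_vdaf_prepare},
        Aggregator, Client, Collector, VdafError,
    };
    use rand::prelude::*;

    /// End to end: shard every measurement, run the prepare phase, merge the
    /// aggregate shares, and unshard into the aggregate result.
    pub fn end_to_end<V, const SEED_SIZE: usize>(
        vdaf: &V,
        agg_param: &V::AggregationParam,
        measurements: Vec<V::Measurement>,
    ) -> Result<V::AggregateResult, VdafError>
    where
        V: Client<16> + Aggregator<SEED_SIZE, 16> + Collector,
    {
        run_vdaf(vdaf, agg_param, measurements)
    }

    /// Single report: shard one measurement, run only the prepare phase, and
    /// return the output shares recovered by each aggregator.
    pub fn prepare_one<V, const SEED_SIZE: usize>(
        vdaf: &V,
        verify_key: &[u8; SEED_SIZE],
        agg_param: &V::AggregationParam,
        measurement: &V::Measurement,
    ) -> Result<Vec<V::OutputShare>, VdafError>
    where
        V: Client<16> + Aggregator<SEED_SIZE, 16> + Collector,
    {
        let nonce: [u8; 16] = thread_rng().gen();
        let (public_share, input_shares) = vdaf.shard(measurement, &nonce)?;
        run_vdaf_prepare(vdaf, verify_key, agg_param, &nonce, public_share, input_shares)
    }
}

Inside the deleted driver, the per-report path is exactly what run_vdaf_prepare did for each measurement sharded by run_vdaf, so either wrapper can stand in for the removed test_util helpers.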
