diff --git a/src/vdaf/mastic.rs b/src/vdaf/mastic.rs index 3024cb4a0..b6ec28d36 100644 --- a/src/vdaf/mastic.rs +++ b/src/vdaf/mastic.rs @@ -287,7 +287,7 @@ where nonce: &[u8; 16], vidpf_keys: [VidpfKey; 2], szk_random: [Seed; 2], - opt_random: Option>, + joint_random_opt: Option>, ) -> Result<(::PublicShare, Vec<::InputShare>), VdafError> { // Compute the measurement shares for each aggregator by generating VIDPF // keys for the measurement and evaluating each of them. @@ -298,58 +298,27 @@ where nonce, )?; - let leader_measurement_share = self - .vidpf - .eval( - &vidpf_keys[0], - &public_share, - &VidpfInput::from_bools(&[false]), - nonce, - )? - .share - + self - .vidpf - .eval( - &vidpf_keys[0], - &public_share, - &VidpfInput::from_bools(&[true]), - nonce, - )? - .share; - let helper_measurement_share = self - .vidpf - .eval( - &vidpf_keys[1], - &public_share, - &VidpfInput::from_bools(&[false]), - nonce, - )? - .share - + self - .vidpf - .eval( - &vidpf_keys[1], - &public_share, - &VidpfInput::from_bools(&[true]), - nonce, - )? - .share; - - let szk_proof_shares = self.szk.prove( + let leader_measurement_share = + self.vidpf.eval_root(&vidpf_keys[0], &public_share, nonce)?; + let helper_measurement_share = + self.vidpf.eval_root(&vidpf_keys[1], &public_share, nonce)?; + + let [leader_szk_proof_share, helper_szk_proof_share] = self.szk.prove( leader_measurement_share.as_ref(), helper_measurement_share.as_ref(), measurement_weight.as_ref(), szk_random, - opt_random, + joint_random_opt, nonce, )?; + let [leader_vidpf_key, helper_vidpf_key] = vidpf_keys; let leader_share = MasticInputShare:: { - vidpf_key: vidpf_keys[0].clone(), - proofs_share: szk_proof_shares[0].clone(), + vidpf_key: leader_vidpf_key, + proofs_share: leader_szk_proof_share, }; let helper_share = MasticInputShare:: { - vidpf_key: vidpf_keys[1].clone(), - proofs_share: szk_proof_shares[1].clone(), + vidpf_key: helper_vidpf_key, + proofs_share: helper_szk_proof_share, }; Ok((public_share, vec![leader_share, helper_share])) } @@ -371,10 +340,10 @@ where { fn shard( &self, - measurement: &(VidpfInput, T::Measurement), + (attribute, weight): &(VidpfInput, T::Measurement), nonce: &[u8; 16], ) -> Result<(Self::PublicShare, Vec), VdafError> { - if measurement.0.len() != self.bits { + if attribute.len() != self.bits { return Err(VdafError::Vidpf(VidpfError::InvalidAttributeLength)); } @@ -382,28 +351,29 @@ where VidpfKey::gen(VidpfServerId::S0)?, VidpfKey::gen(VidpfServerId::S1)?, ]; - let opt_random = match self.szk.has_joint_rand() { - true => Some(Seed::::generate()?), - false => None, + let joint_random_opt = if self.szk.has_joint_rand() { + Some(Seed::::generate()?) + } else { + None }; let szk_random = [ Seed::::generate()?, Seed::::generate()?, ]; - let encoded_measurement = self.encode_measurement(&measurement.1)?; + let encoded_measurement = self.encode_measurement(weight)?; if encoded_measurement.as_ref().len() != self.vidpf.weight_parameter { return Err(VdafError::Uncategorized( - "encoded_measurement is wrong length".to_string(), + "encoded_measurement is the wrong length".to_string(), )); } self.shard_with_random( - &measurement.0, - &self.encode_measurement(&measurement.1)?, + attribute, + &encoded_measurement, nonce, vidpf_keys, szk_random, - opt_random, + joint_random_opt, ) } } diff --git a/src/vidpf.rs b/src/vidpf.rs index b217fe7e0..ca550252c 100644 --- a/src/vidpf.rs +++ b/src/vidpf.rs @@ -276,6 +276,20 @@ impl Vidpf { Ok((next_state, y)) } + pub(crate) fn eval_root( + &self, + key: &VidpfKey, + public_share: &VidpfPublicShare, + nonce: &[u8; NONCE_SIZE], + ) -> Result { + Ok(self + .eval(key, public_share, &VidpfInput::from_bools(&[false]), nonce)? + .share + + self + .eval(key, public_share, &VidpfInput::from_bools(&[true]), nonce)? + .share) + } + fn prg(seed: &VidpfSeed, nonce: &[u8]) -> VidpfPrgOutput { let mut rng = XofFixedKeyAes128::seed_stream(&Seed(*seed), VidpfDomainSepTag::PRG, nonce);