diff --git a/mithril-stm/tests/integration.rs b/mithril-stm/tests/test_stm_protocol.rs similarity index 50% rename from mithril-stm/tests/integration.rs rename to mithril-stm/tests/test_stm_protocol.rs index c29d64f4534..b61a8a05a2f 100644 --- a/mithril-stm/tests/integration.rs +++ b/mithril-stm/tests/test_stm_protocol.rs @@ -1,6 +1,7 @@ use mithril_stm::key_reg::KeyReg; use mithril_stm::stm::{ - Stake, StmClerk, StmInitializer, StmParameters, StmSig, StmSigner, StmVerificationKey, + Stake, StmAggrSig, StmAggrVerificationKey, StmClerk, StmInitializer, StmParameters, StmSig, + StmSigner, StmVerificationKey, }; use mithril_stm::AggregationError; @@ -11,23 +12,11 @@ use rayon::prelude::*; type H = Blake2b; -#[test] -fn test_full_protocol() { - let nparties = 32; - let mut rng = ChaCha20Rng::from_seed([0u8; 32]); - let mut msg = [0u8; 16]; - rng.fill_bytes(&mut msg); - - ////////////////////////// - // initialization phase // - ////////////////////////// - - let params = StmParameters { - k: 357, - m: 2642, - phi_f: 0.2, - }; - +fn initialization_phase( + nparties: u64, + mut rng: ChaCha20Rng, + params: StmParameters, +) -> (Vec>, Vec<(StmVerificationKey, Stake)>) { let parties = (0..nparties) .into_iter() .map(|_| 1 + (rng.next_u64() % 9999)) @@ -35,32 +24,42 @@ fn test_full_protocol() { let mut key_reg = KeyReg::init(); - let mut ps: Vec = Vec::with_capacity(nparties as usize); + let mut initializers: Vec = Vec::with_capacity(nparties as usize); + let mut reg_parties: Vec<(StmVerificationKey, Stake)> = Vec::with_capacity(nparties as usize); + for stake in parties { let p = StmInitializer::setup(params, stake, &mut rng); key_reg.register(stake, p.verification_key()).unwrap(); reg_parties.push((p.verification_key().vk, stake)); - ps.push(p); + initializers.push(p); } let closed_reg = key_reg.close(); - let ps = ps + let signers = initializers .into_par_iter() .map(|p| p.new_signer(closed_reg.clone()).unwrap()) .collect::>>(); - ///////////////////// - // operation phase // - ///////////////////// + (signers, reg_parties) +} - let sigs = ps +fn operation_phase( + params: StmParameters, + signers: Vec>, + reg_parties: Vec<(StmVerificationKey, Stake)>, + msg: [u8; 32], +) -> ( + Result, AggregationError>, + StmAggrVerificationKey, +) { + let sigs = signers .par_iter() .filter_map(|p| p.sign(&msg)) .collect::>(); - let clerk = StmClerk::from_signer(&ps[0]); + let clerk = StmClerk::from_signer(&signers[0]); let avk = clerk.compute_avk(); // Check all parties can verify every sig @@ -71,13 +70,35 @@ fn test_full_protocol() { ); } - // Aggregate with random parties let msig = clerk.aggregate(&sigs, &msg); + (msig, avk) +} + +#[test] +fn test_full_protocol() { + let nparties = 32; + let mut rng = ChaCha20Rng::from_seed([0u8; 32]); + let mut msg = [0u8; 32]; + rng.fill_bytes(&mut msg); + + ////////////////////////// + // initialization phase // + ////////////////////////// + + let params = StmParameters { + k: 357, + m: 2642, + phi_f: 0.2, + }; + + let (signers, reg_parties) = initialization_phase(nparties, rng.clone(), params); + let (msig, avk) = operation_phase(params, signers, reg_parties, msg); + match msig { Ok(aggr) => { println!("Aggregate ok"); - assert!(aggr.verify(&msg, &clerk.compute_avk(), ¶ms).is_ok()); + assert!(aggr.verify(&msg, &avk, ¶ms).is_ok()); } Err(AggregationError::NotEnoughSignatures(n, k)) => { println!("Not enough signatures"); @@ -88,3 +109,34 @@ fn test_full_protocol() { } } } + +#[test] +fn test_full_protocol_batch_verify() { + let batch_size = 5; + let mut rng = ChaCha20Rng::from_seed([0u8; 32]); + + let mut aggr_avks = Vec::new(); + let mut aggr_stms = Vec::new(); + let mut batch_msgs = Vec::new(); + let mut batch_params = Vec::new(); + + let params = StmParameters { + k: 357, + m: 2642, + phi_f: 0.2, + }; + + for _ in 0..batch_size { + let mut msg = [0u8; 32]; + rng.fill_bytes(&mut msg); + let nparties = rng.next_u64() % 33; + let (signers, reg_parties) = initialization_phase(nparties, rng.clone(), params); + let operation = operation_phase(params, signers, reg_parties, msg); + + aggr_avks.push(operation.1); + aggr_stms.push(operation.0.unwrap()); + batch_msgs.push(msg.to_vec()); + batch_params.push(params); + } + assert!(StmAggrSig::batch_verify(&aggr_stms, &batch_msgs, &aggr_avks, &batch_params).is_ok()); +}