diff --git a/bindings/rust/src/lib.rs b/bindings/rust/src/lib.rs index 729e0d82..d62b4386 100644 --- a/bindings/rust/src/lib.rs +++ b/bindings/rust/src/lib.rs @@ -1580,6 +1580,50 @@ macro_rules! sig_variant_impl { } } + impl MultiPoint for [PublicKey] { + type Output = AggregatePublicKey; + + fn mult(&self, scalars: &[u8], nbits: usize) -> Self::Output { + Self::Output { + point: unsafe { + core::mem::transmute::<&[_], &[$pk_aff]>(self) + } + .mult(scalars, nbits), + } + } + + fn add(&self) -> Self::Output { + Self::Output { + point: unsafe { + core::mem::transmute::<&[_], &[$pk_aff]>(self) + } + .add(), + } + } + } + + impl MultiPoint for [Signature] { + type Output = AggregateSignature; + + fn mult(&self, scalars: &[u8], nbits: usize) -> Self::Output { + Self::Output { + point: unsafe { + core::mem::transmute::<&[_], &[$sig_aff]>(self) + } + .mult(scalars, nbits), + } + } + + fn add(&self) -> Self::Output { + Self::Output { + point: unsafe { + core::mem::transmute::<&[_], &[$sig_aff]>(self) + } + .add(), + } + } + } + #[cfg(test)] mod tests { use super::*; @@ -1942,6 +1986,78 @@ macro_rules! sig_variant_impl { assert_eq!(sk_des.sign(b"asdf", b"qwer", b"zxcv"), sig); } } + + #[test] + fn test_multi_point() { + let dst = b"BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_"; + let num_pks = 10; + + let seed = [0u8; 32]; + let mut rng = ChaCha20Rng::from_seed(seed); + + // Create public keys + let sks: Vec<_> = + (0..num_pks).map(|_| gen_random_key(&mut rng)).collect(); + + let pks = + sks.iter().map(|sk| sk.sk_to_pk()).collect::>(); + let pks_refs: Vec<&PublicKey> = + pks.iter().map(|pk| pk).collect(); + + // Create random message for pks to all sign + let msg_len = (rng.next_u64() & 0x3F) + 1; + let mut msg = vec![0u8; msg_len as usize]; + rng.fill_bytes(&mut msg); + + // Generate signature for each key pair + let sigs = sks + .iter() + .map(|sk| sk.sign(&msg, dst, &[])) + .collect::>(); + let sigs_refs: Vec<&Signature> = + sigs.iter().map(|s| s).collect(); + + // create random values + let mut rands: Vec = Vec::with_capacity(32 * num_pks); + for _ in 0..num_pks { + let mut r = rng.next_u64(); + while r == 0 { + // Reject zero as it is used for multiplication. + r = rng.next_u64(); + } + rands.extend_from_slice(&r.to_le_bytes()); + } + + // Sanity test each current single signature + let errs = sigs + .iter() + .zip(pks.iter()) + .map(|(s, pk)| (s.verify(true, &msg, dst, &[], pk, true))) + .collect::>(); + assert_eq!(errs, vec![BLST_ERROR::BLST_SUCCESS; num_pks]); + + // sanity test aggregated signature + let agg_pk = AggregatePublicKey::aggregate(&pks_refs, false) + .unwrap() + .to_public_key(); + let agg_sig = AggregateSignature::aggregate(&sigs_refs, false) + .unwrap() + .to_signature(); + let err = agg_sig.verify(true, &msg, dst, &[], &agg_pk, true); + assert_eq!(err, BLST_ERROR::BLST_SUCCESS); + + // test multi-point aggregation using add + let agg_pk = pks.add().to_public_key(); + let agg_sig = sigs.add().to_signature(); + let err = agg_sig.verify(true, &msg, dst, &[], &agg_pk, true); + assert_eq!(err, BLST_ERROR::BLST_SUCCESS); + + // test multi-point aggregation using mult + let agg_pk = pks.mult(&rands, 64).to_public_key(); + let agg_sig = sigs.mult(&rands, 64).to_signature(); + let err = agg_sig.verify(true, &msg, dst, &[], &agg_pk, true); + assert_eq!(err, BLST_ERROR::BLST_SUCCESS); + } } }; }