From 08c703bd2c29bae4b1be7314c0188124c41c4a6b Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Fri, 25 Dec 2020 11:47:33 -0800 Subject: [PATCH 1/4] Improve API and possible speedup --- ec/src/msm/variable_base.rs | 75 ++++++++++++++++--------------------- 1 file changed, 32 insertions(+), 43 deletions(-) diff --git a/ec/src/msm/variable_base.rs b/ec/src/msm/variable_base.rs index 3fa5f7abd..fc24e2f22 100644 --- a/ec/src/msm/variable_base.rs +++ b/ec/src/msm/variable_base.rs @@ -9,13 +9,10 @@ use rayon::prelude::*; pub struct VariableBaseMSM; impl VariableBaseMSM { - fn msm_inner( + pub fn multi_scalar_mul( bases: &[G], scalars: &[::BigInt], - ) -> G::Projective - where - G::Projective: ProjectiveCurve, - { + ) -> G::Projective { let c = if scalars.len() < 32 { 3 } else { @@ -28,6 +25,11 @@ impl VariableBaseMSM { let zero = G::Projective::zero(); let window_starts: Vec<_> = (0..num_bits).step_by(c).collect(); + let size = ark_std::cmp::min(bases.len(), scalars.len()); + let scalars = &scalars[..size]; + let bases = &bases[..size]; + let scalars_and_bases = scalars.iter().zip(bases).filter(|(s, _)| !s.is_zero()); + // Each window is of size `c`. // We divide up the bits 0..num_bits into windows of size `c`, and // in parallel process each such window. @@ -36,42 +38,36 @@ impl VariableBaseMSM { let mut res = zero; // We don't need the "zero" bucket, so we only have 2^c - 1 buckets let mut buckets = vec![zero; (1 << c) - 1]; - scalars - .iter() - .zip(bases) - .filter(|(s, _)| !s.is_zero()) - .for_each(|(&scalar, base)| { - if scalar == fr_one { - // We only process unit scalars once in the first window. - if w_start == 0 { - res.add_assign_mixed(base); - } - } else { - let mut scalar = scalar; + scalars_and_bases.clone().for_each(|(&scalar, base)| { + if scalar == fr_one { + // We only process unit scalars once in the first window. + if w_start == 0 { + res.add_assign_mixed(base); + } + } else { + let mut scalar = scalar; - // We right-shift by w_start, thus getting rid of the - // lower bits. - scalar.divn(w_start as u32); + // We right-shift by w_start, thus getting rid of the + // lower bits. + scalar.divn(w_start as u32); - // We mod the remaining bits by the window size. - let scalar = scalar.as_ref()[0] % (1 << c); + // We mod the remaining bits by the window size. + let scalar = scalar.as_ref()[0] % (1 << c); - // If the scalar is non-zero, we update the corresponding - // bucket. - // (Recall that `buckets` doesn't have a zero bucket.) - if scalar != 0 { - buckets[(scalar - 1) as usize].add_assign_mixed(base); - } + // If the scalar is non-zero, we update the corresponding + // bucket. + // (Recall that `buckets` doesn't have a zero bucket.) + if scalar != 0 { + buckets[(scalar - 1) as usize].add_assign_mixed(base); } - }); - let buckets = G::Projective::batch_normalization_into_affine(&buckets); + } + }); let mut running_sum = G::Projective::zero(); - for b in buckets.into_iter().rev() { - running_sum.add_assign_mixed(&b); - res += running_sum; - } - + buckets.into_iter().rev().for_each(|b| { + running_sum += &b; + res += &running_sum; + }); res }) .collect(); @@ -81,7 +77,7 @@ impl VariableBaseMSM { // We're traversing windows from high to low. lowest - + window_sums[1..] + + &window_sums[1..] .iter() .rev() .fold(zero, |mut total, sum_i| { @@ -92,11 +88,4 @@ impl VariableBaseMSM { total }) } - - pub fn multi_scalar_mul( - bases: &[G], - scalars: &[::BigInt], - ) -> G::Projective { - Self::msm_inner(bases, scalars) - } } From abbfd91337fd74e38f8214544d97d302d76904ae Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Tue, 29 Dec 2020 15:02:09 -0800 Subject: [PATCH 2/4] Update CHANGELOG --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b670400a1..a4db2a794 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -58,6 +58,7 @@ The main features of this release are: - #144 (ark-poly) Add serialization for polynomials and evaluations - #149 (ark-serialize) Add an impl of `CanonicalSerialize/Deserialize` for `String`. - #153 (ark-serialize) Add an impl of `CanonicalSerialize/Deserialize` for `Rc`. +- #157 (ark-ec) Speed up `variable_base_msm` by not relying on unnecessary normalization. - #158 (ark-serialize) Add an impl of `CanonicalSerialize/Deserialize` for `()`. ### Bug fixes From 656bc6bd3f15c254df77db45d6d02e696a684c7e Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Tue, 29 Dec 2020 15:11:14 -0800 Subject: [PATCH 3/4] Add comment --- ec/src/msm/variable_base.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ec/src/msm/variable_base.rs b/ec/src/msm/variable_base.rs index fc24e2f22..96bb7f2e7 100644 --- a/ec/src/msm/variable_base.rs +++ b/ec/src/msm/variable_base.rs @@ -63,6 +63,9 @@ impl VariableBaseMSM { } }); + // We could first normalize `buckets` and then use mixed-addition + // here, but that's slower for the kinds of groups we care about + // (Short Weierstrass curves and Twisted Edwards curves). let mut running_sum = G::Projective::zero(); buckets.into_iter().rev().for_each(|b| { running_sum += &b; From 5434fcaf608fd94b006a6cc6986acf0154eb83f5 Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Tue, 29 Dec 2020 15:36:24 -0800 Subject: [PATCH 4/4] Comment --- ec/src/msm/variable_base.rs | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/ec/src/msm/variable_base.rs b/ec/src/msm/variable_base.rs index 96bb7f2e7..340353644 100644 --- a/ec/src/msm/variable_base.rs +++ b/ec/src/msm/variable_base.rs @@ -13,10 +13,15 @@ impl VariableBaseMSM { bases: &[G], scalars: &[::BigInt], ) -> G::Projective { - let c = if scalars.len() < 32 { + let size = ark_std::cmp::min(bases.len(), scalars.len()); + let scalars = &scalars[..size]; + let bases = &bases[..size]; + let scalars_and_bases_iter = scalars.iter().zip(bases).filter(|(s, _)| !s.is_zero()); + + let c = if size < 32 { 3 } else { - super::ln_without_floats(scalars.len()) + 2 + super::ln_without_floats(size) + 2 }; let num_bits = ::Params::MODULUS_BITS as usize; @@ -25,11 +30,6 @@ impl VariableBaseMSM { let zero = G::Projective::zero(); let window_starts: Vec<_> = (0..num_bits).step_by(c).collect(); - let size = ark_std::cmp::min(bases.len(), scalars.len()); - let scalars = &scalars[..size]; - let bases = &bases[..size]; - let scalars_and_bases = scalars.iter().zip(bases).filter(|(s, _)| !s.is_zero()); - // Each window is of size `c`. // We divide up the bits 0..num_bits into windows of size `c`, and // in parallel process each such window. @@ -38,7 +38,9 @@ impl VariableBaseMSM { let mut res = zero; // We don't need the "zero" bucket, so we only have 2^c - 1 buckets let mut buckets = vec![zero; (1 << c) - 1]; - scalars_and_bases.clone().for_each(|(&scalar, base)| { + // This clone is cheap, because the iterator contains just a + // pointer and an index into the original vectors + scalars_and_bases_iter.clone().for_each(|(&scalar, base)| { if scalar == fr_one { // We only process unit scalars once in the first window. if w_start == 0 { @@ -63,6 +65,7 @@ impl VariableBaseMSM { } }); + // Compute sum_{i in 0..num_buckets} (sum_{j in i..num_buckets} bucket[j]) // We could first normalize `buckets` and then use mixed-addition // here, but that's slower for the kinds of groups we care about // (Short Weierstrass curves and Twisted Edwards curves).