Skip to content

Commit

Permalink
Simplify MSM implementation, and small speedup (#157)
Browse files Browse the repository at this point in the history
  • Loading branch information
Pratyush authored Dec 30, 2020
1 parent 6a67c35 commit 64ec4fe
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 45 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ The main features of this release are:
- #144 (ark-poly) Add serialization for polynomials and evaluations
- #149 (ark-serialize) Add an impl of `CanonicalSerialize/Deserialize` for `String`.
- #153 (ark-serialize) Add an impl of `CanonicalSerialize/Deserialize` for `Rc<T>`.
- #157 (ark-ec) Speed up `variable_base_msm` by not relying on unnecessary normalization.
- #158 (ark-serialize) Add an impl of `CanonicalSerialize/Deserialize` for `()`.

### Bug fixes
Expand Down
85 changes: 40 additions & 45 deletions ec/src/msm/variable_base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,19 @@ use rayon::prelude::*;
pub struct VariableBaseMSM;

impl VariableBaseMSM {
fn msm_inner<G: AffineCurve>(
pub fn multi_scalar_mul<G: AffineCurve>(
bases: &[G],
scalars: &[<G::ScalarField as PrimeField>::BigInt],
) -> G::Projective
where
G::Projective: ProjectiveCurve<Affine = G>,
{
let c = if scalars.len() < 32 {
) -> G::Projective {
let size = ark_std::cmp::min(bases.len(), scalars.len());
let scalars = &scalars[..size];
let bases = &bases[..size];
let scalars_and_bases_iter = scalars.iter().zip(bases).filter(|(s, _)| !s.is_zero());

let c = if size < 32 {
3
} else {
super::ln_without_floats(scalars.len()) + 2
super::ln_without_floats(size) + 2
};

let num_bits = <G::ScalarField as PrimeField>::Params::MODULUS_BITS as usize;
Expand All @@ -36,42 +38,42 @@ impl VariableBaseMSM {
let mut res = zero;
// We don't need the "zero" bucket, so we only have 2^c - 1 buckets
let mut buckets = vec![zero; (1 << c) - 1];
scalars
.iter()
.zip(bases)
.filter(|(s, _)| !s.is_zero())
.for_each(|(&scalar, base)| {
if scalar == fr_one {
// We only process unit scalars once in the first window.
if w_start == 0 {
res.add_assign_mixed(base);
}
} else {
let mut scalar = scalar;
// This clone is cheap, because the iterator contains just a
// pointer and an index into the original vectors
scalars_and_bases_iter.clone().for_each(|(&scalar, base)| {
if scalar == fr_one {
// We only process unit scalars once in the first window.
if w_start == 0 {
res.add_assign_mixed(base);
}
} else {
let mut scalar = scalar;

// We right-shift by w_start, thus getting rid of the
// lower bits.
scalar.divn(w_start as u32);
// We right-shift by w_start, thus getting rid of the
// lower bits.
scalar.divn(w_start as u32);

// We mod the remaining bits by the window size.
let scalar = scalar.as_ref()[0] % (1 << c);
// We mod the remaining bits by the window size.
let scalar = scalar.as_ref()[0] % (1 << c);

// If the scalar is non-zero, we update the corresponding
// bucket.
// (Recall that `buckets` doesn't have a zero bucket.)
if scalar != 0 {
buckets[(scalar - 1) as usize].add_assign_mixed(base);
}
// If the scalar is non-zero, we update the corresponding
// bucket.
// (Recall that `buckets` doesn't have a zero bucket.)
if scalar != 0 {
buckets[(scalar - 1) as usize].add_assign_mixed(base);
}
});
let buckets = G::Projective::batch_normalization_into_affine(&buckets);
}
});

// Compute sum_{i in 0..num_buckets} (sum_{j in i..num_buckets} bucket[j])
// We could first normalize `buckets` and then use mixed-addition
// here, but that's slower for the kinds of groups we care about
// (Short Weierstrass curves and Twisted Edwards curves).
let mut running_sum = G::Projective::zero();
for b in buckets.into_iter().rev() {
running_sum.add_assign_mixed(&b);
res += running_sum;
}

buckets.into_iter().rev().for_each(|b| {
running_sum += &b;
res += &running_sum;
});
res
})
.collect();
Expand All @@ -81,7 +83,7 @@ impl VariableBaseMSM {

// We're traversing windows from high to low.
lowest
+ window_sums[1..]
+ &window_sums[1..]
.iter()
.rev()
.fold(zero, |mut total, sum_i| {
Expand All @@ -92,11 +94,4 @@ impl VariableBaseMSM {
total
})
}

pub fn multi_scalar_mul<G: AffineCurve>(
bases: &[G],
scalars: &[<G::ScalarField as PrimeField>::BigInt],
) -> G::Projective {
Self::msm_inner(bases, scalars)
}
}

0 comments on commit 64ec4fe

Please sign in to comment.