Simplify MSM implementation, and small speedup #157

Merged 4 commits on Dec 30, 2020. Changes from 3 commits.
CHANGELOG.md: 1 change (1 addition, 0 deletions)

```diff
@@ -58,6 +58,7 @@ The main features of this release are:
 - #144 (ark-poly) Add serialization for polynomials and evaluations
 - #149 (ark-serialize) Add an impl of `CanonicalSerialize/Deserialize` for `String`.
 - #153 (ark-serialize) Add an impl of `CanonicalSerialize/Deserialize` for `Rc<T>`.
+- #157 (ark-ec) Speed up `variable_base_msm` by not relying on unnecessary normalization.
 - #158 (ark-serialize) Add an impl of `CanonicalSerialize/Deserialize` for `()`.

 ### Bug fixes
```
ec/src/msm/variable_base.rs: 78 changes (35 additions, 43 deletions)

```diff
@@ -9,13 +9,10 @@ use rayon::prelude::*;
 pub struct VariableBaseMSM;

 impl VariableBaseMSM {
-    fn msm_inner<G: AffineCurve>(
+    pub fn multi_scalar_mul<G: AffineCurve>(
         bases: &[G],
         scalars: &[<G::ScalarField as PrimeField>::BigInt],
-    ) -> G::Projective
-    where
-        G::Projective: ProjectiveCurve<Affine = G>,
-    {
+    ) -> G::Projective {
         let c = if scalars.len() < 32 {
             3
         } else {
@@ -28,6 +25,11 @@ impl VariableBaseMSM {
         let zero = G::Projective::zero();
         let window_starts: Vec<_> = (0..num_bits).step_by(c).collect();

+        let size = ark_std::cmp::min(bases.len(), scalars.len());
+        let scalars = &scalars[..size];
+        let bases = &bases[..size];
+        let scalars_and_bases = scalars.iter().zip(bases).filter(|(s, _)| !s.is_zero());
+
```
Member Author:
This change is because in some downstream projects (e.g. poly-commit), algorithms that use MSM are slower if the vector lengths are mismatched. The zip should fix this, but in benchmarks it doesn't.

(For that matter, neither does this new code, so I can remove it if that's better.)

Member:

Looks to me like this is due to line 19, where `c` is computed based only on the scalars' length? That computation should probably be moved down, and calculated based on `size`.

Member Author:

To clarify, this was happening when `bases.len() > scalars.len()`, so the `c` calculation wasn't affected.

(But you're right, we should move the `c` calculation down anyway.)
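A sketch of that reordering (not the PR's code): truncate to the common length first, then derive `c` from it. `window_size_heuristic` is a hypothetical stand-in for the crate's actual window-size formula, which is elided from the diff above.

```rust
// Hypothetical reordering: compute the shared length before choosing `c`.
fn pick_window_size(num_bases: usize, num_scalars: usize) -> usize {
    let size = num_bases.min(num_scalars);
    if size < 32 {
        3
    } else {
        window_size_heuristic(size)
    }
}

// Stand-in for the elided heuristic: roughly ln(size) + 2, a common choice
// for Pippenger-style window sizes.
fn window_size_heuristic(size: usize) -> usize {
    let log2 = 64 - (size as u64).leading_zeros() as usize;
    log2 * 69 / 100 + 2 // log2(x) * ln(2) is approximately ln(x)
}
```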

Member @ValarDragon commented on Dec 29, 2020:

Also, scalars should never be bigger than the number of bases, right? Perhaps an assert should be added for that?

Member Author:

Hm, not sure that's necessary; you could imagine a case where zero-padding makes `scalars` longer, but because the extras are zero, the result is unaffected.

(This might be happening in ark-poly-commit already.)

Member:

Hrmm, the function doesn't really make sense in the case where the extra scalars are non-zero. (And in the case where there are more zeros than bases, it's inefficient to allocate them.)

Regardless of whether this function accepts extra zeroes, imo we should (eventually) change poly-commit to not allocate extra zeros.
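To illustrate the padding scenario under discussion, here's a hedged sketch (assuming `ark_bls12_381` and the 0.x-era arkworks APIs; not part of this PR) showing that zero-padding the scalars leaves the result unchanged:

```rust
use ark_bls12_381::{Fr, G1Affine, G1Projective};
use ark_ec::{msm::VariableBaseMSM, ProjectiveCurve};
use ark_ff::{PrimeField, Zero};
use ark_std::UniformRand;

fn main() {
    let mut rng = ark_std::test_rng();
    let bases: Vec<G1Affine> = (0..8)
        .map(|_| G1Projective::rand(&mut rng).into_affine())
        .collect();
    let scalars: Vec<_> = (0..8).map(|_| Fr::rand(&mut rng).into_repr()).collect();

    // Zero-pad past the number of bases, as a poly-commit caller might.
    let mut padded = scalars.clone();
    padded.resize(bases.len() + 4, Fr::zero().into_repr());

    // Truncation to the common length drops the trailing zeros, and the
    // `!s.is_zero()` filter skips any zeros in range, so both calls agree.
    assert_eq!(
        VariableBaseMSM::multi_scalar_mul(&bases, &scalars),
        VariableBaseMSM::multi_scalar_mul(&bases, &padded),
    );
}
```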


```diff
         // Each window is of size `c`.
         // We divide up the bits 0..num_bits into windows of size `c`, and
         // in parallel process each such window.
@@ -36,42 +38,39 @@ impl VariableBaseMSM {
                 let mut res = zero;
                 // We don't need the "zero" bucket, so we only have 2^c - 1 buckets
                 let mut buckets = vec![zero; (1 << c) - 1];
-                scalars
-                    .iter()
-                    .zip(bases)
-                    .filter(|(s, _)| !s.is_zero())
-                    .for_each(|(&scalar, base)| {
-                        if scalar == fr_one {
-                            // We only process unit scalars once in the first window.
-                            if w_start == 0 {
-                                res.add_assign_mixed(base);
-                            }
-                        } else {
-                            let mut scalar = scalar;
-
-                            // We right-shift by w_start, thus getting rid of the
-                            // lower bits.
-                            scalar.divn(w_start as u32);
-
-                            // We mod the remaining bits by the window size.
-                            let scalar = scalar.as_ref()[0] % (1 << c);
-
-                            // If the scalar is non-zero, we update the corresponding
-                            // bucket.
-                            // (Recall that `buckets` doesn't have a zero bucket.)
-                            if scalar != 0 {
-                                buckets[(scalar - 1) as usize].add_assign_mixed(base);
-                            }
-                        }
-                    });
-                let buckets = G::Projective::batch_normalization_into_affine(&buckets);
+                scalars_and_bases.clone().for_each(|(&scalar, base)| {
+                    if scalar == fr_one {
+                        // We only process unit scalars once in the first window.
+                        if w_start == 0 {
+                            res.add_assign_mixed(base);
+                        }
+                    } else {
+                        let mut scalar = scalar;
+
+                        // We right-shift by w_start, thus getting rid of the
+                        // lower bits.
+                        scalar.divn(w_start as u32);
+
+                        // We mod the remaining bits by the window size.
+                        let scalar = scalar.as_ref()[0] % (1 << c);
+
+                        // If the scalar is non-zero, we update the corresponding
+                        // bucket.
+                        // (Recall that `buckets` doesn't have a zero bucket.)
+                        if scalar != 0 {
+                            buckets[(scalar - 1) as usize].add_assign_mixed(base);
+                        }
+                    }
+                });
```
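The shift-and-mod above extracts one base-2^c digit of the scalar per window. A self-contained illustration with a plain `u64` standing in for the `BigInt` (not from the PR):

```rust
fn main() {
    let scalar: u64 = 0b1011_0110_1101;
    let c: usize = 4; // window size in bits
    let num_bits: usize = 12;
    let mut recovered = 0u64;
    for w_start in (0..num_bits).step_by(c) {
        // Mirrors `scalar.divn(w_start)` followed by `% (1 << c)` in the diff.
        let digit = (scalar >> w_start) % (1u64 << c);
        recovered += digit << w_start;
    }
    // The per-window digits reconstruct the original scalar.
    assert_eq!(recovered, scalar);
}
```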

```diff
+                // We could first normalize `buckets` and then use mixed-addition
+                // here, but that's slower for the kinds of groups we care about
+                // (Short Weierstrass curves and Twisted Edwards curves).
                 let mut running_sum = G::Projective::zero();
-                for b in buckets.into_iter().rev() {
-                    running_sum.add_assign_mixed(&b);
-                    res += running_sum;
-                }
+                buckets.into_iter().rev().for_each(|b| {
+                    running_sum += &b;
+                    res += &running_sum;
+                });
                 res
             })
             .collect();
```
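The loop above uses a running sum so that bucket `j` (holding points whose window digit is `j + 1`) ends up counted `j + 1` times, using additions only. A self-contained illustration with `u64` weights in place of curve points (not from the PR):

```rust
fn main() {
    // Bucket j holds the sum of points whose window digit is j + 1.
    let buckets = [3u64, 5, 7, 11];
    let mut res = 0u64;
    let mut running_sum = 0u64;
    for b in buckets.iter().rev() {
        running_sum += b;
        res += running_sum;
    }
    // Equivalent to the weighted sum over digits: sum_j (j + 1) * b_j.
    let expected: u64 = buckets
        .iter()
        .enumerate()
        .map(|(j, b)| (j as u64 + 1) * b)
        .sum();
    assert_eq!(res, expected); // both are 78
}
```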
Member:

As an aside, we should probably parallelize this loop. (It will add some complexity, to handle the running sum's updates for subsequent chunks.)

Member Author:

Is that worthwhile? We're already parallelizing the outer loop (over the windows).

Member:

Hrmm, we should probably benchmark this. I thought it's a linear number of buckets in the degree.

Member Author:

> I thought it's a linear number of buckets in the degree.

It is, but the idea right now is that each thread gets its own window to operate over. Usually the number of windows is more than the number of threads, so the existing allocation is already fine for those cases. However, if the number of threads is higher, then we could try to leverage the extra threads for better parallelism, but we have to do so in a way that doesn't harm the normal mode of operation.

One idea would be to conditionally parallelize only if we have spare threads; this can be done via rayon's ThreadPool. The idea would be to allocate some number of threads to the pool (say 2), and then execute operations inside the pool. The thread-pool approach would ensure that all Rayon operations inside the pool use at most two threads.

Member @ValarDragon commented on Dec 29, 2020:

Ah, I see. I didn't realize it's already parallelized; then I agree it's probably unlikely to be a perf bottleneck.
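A minimal, runnable sketch of the ThreadPool idea floated above, with `u64` sums standing in for curve operations (rayon's `ThreadPoolBuilder` and `install` are real APIs; their application to this MSM loop is hypothetical):

```rust
use rayon::prelude::*;

fn main() {
    let buckets: Vec<u64> = (0..1024).collect();
    // Work launched via `pool.install` is confined to the pool's threads,
    // so an inner parallel step can't starve the outer per-window loop.
    let pool = rayon::ThreadPoolBuilder::new()
        .num_threads(2) // the "spare" threads allotted to the inner step
        .build()
        .unwrap();
    let total: u64 = pool.install(|| buckets.par_iter().sum());
    assert_eq!(total, 1023 * 1024 / 2);
}
```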

```diff
@@ -81,7 +80,7 @@ impl VariableBaseMSM {

         // We're traversing windows from high to low.
         lowest
-            + window_sums[1..]
+            + &window_sums[1..]
                 .iter()
                 .rev()
                 .fold(zero, |mut total, sum_i| {
@@ -92,11 +91,4 @@ impl VariableBaseMSM {
                     total
                 })
     }
-
-    pub fn multi_scalar_mul<G: AffineCurve>(
-        bases: &[G],
-        scalars: &[<G::ScalarField as PrimeField>::BigInt],
-    ) -> G::Projective {
-        Self::msm_inner(bases, scalars)
-    }
 }
```
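For reference, a usage sketch of the renamed public entry point (assuming `ark_bls12_381` as the concrete curve and the 0.x-era arkworks APIs; the PR itself only touches `ark-ec`):

```rust
use ark_bls12_381::{Fr, G1Affine, G1Projective};
use ark_ec::{msm::VariableBaseMSM, ProjectiveCurve};
use ark_ff::PrimeField;
use ark_std::UniformRand;

fn main() {
    let mut rng = ark_std::test_rng();
    let n = 1 << 10;
    let bases: Vec<G1Affine> = (0..n)
        .map(|_| G1Projective::rand(&mut rng).into_affine())
        .collect();
    let scalars: Vec<_> = (0..n).map(|_| Fr::rand(&mut rng).into_repr()).collect();

    // Computes sum_i scalars[i] * bases[i], keeping intermediate bucket sums
    // in projective form rather than batch-normalizing them, per this PR.
    let msm: G1Projective = VariableBaseMSM::multi_scalar_mul(&bases, &scalars);
    println!("{:?}", msm.into_affine());
}
```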