Simplify MSM implementation, and small speedup #157
Changes from 3 commits
```diff
@@ -9,13 +9,10 @@ use rayon::prelude::*;
 pub struct VariableBaseMSM;

 impl VariableBaseMSM {
-    fn msm_inner<G: AffineCurve>(
+    pub fn multi_scalar_mul<G: AffineCurve>(
         bases: &[G],
         scalars: &[<G::ScalarField as PrimeField>::BigInt],
-    ) -> G::Projective
-    where
-        G::Projective: ProjectiveCurve<Affine = G>,
-    {
+    ) -> G::Projective {
         let c = if scalars.len() < 32 {
             3
         } else {
```
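The window size `c` chosen above trades off work per pass against the number of passes: Pippenger's algorithm does roughly `num_bits / c` windows, each costing about `n` bucket additions plus `2^c` additions for the bucket reduction, which is why `c` should grow (roughly logarithmically) with the number of scalars. The exact formula used in the `else` branch is truncated in the diff above, so here is only a hedged sketch of the underlying cost model, with a hypothetical brute-force minimizer in place of the crate's closed-form choice:

```rust
// Hypothetical cost model for Pippenger's algorithm (a sketch, not the
// crate's actual formula): for n scalars of `num_bits` bits, window size
// `c` gives ~ceil(num_bits / c) windows, each costing ~n bucket additions
// plus ~2^c additions to reduce the buckets.
fn estimated_cost(n: u64, num_bits: u64, c: u64) -> u64 {
    let windows = (num_bits + c - 1) / c;
    windows * (n + (1u64 << c))
}

// Pick the window size with the lowest estimated cost.
fn best_window_size(n: u64, num_bits: u64) -> u64 {
    (1..=24).min_by_key(|&c| estimated_cost(n, num_bits, c)).unwrap()
}

fn main() {
    // Larger instances favour larger windows.
    let small = best_window_size(16, 256);
    let large = best_window_size(1 << 20, 256);
    assert!(small < large);
    println!("c_small = {}, c_large = {}", small, large);
}
```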
```diff
@@ -28,6 +25,11 @@ impl VariableBaseMSM {
         let zero = G::Projective::zero();
         let window_starts: Vec<_> = (0..num_bits).step_by(c).collect();

+        let size = ark_std::cmp::min(bases.len(), scalars.len());
+        let scalars = &scalars[..size];
+        let bases = &bases[..size];
+        let scalars_and_bases = scalars.iter().zip(bases).filter(|(s, _)| !s.is_zero());
+
         // Each window is of size `c`.
         // We divide up the bits 0..num_bits into windows of size `c`, and
         // in parallel process each such window.
```
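Note that `zip` by itself already stops at the shorter of the two iterators, so the explicit `size` truncation mainly makes the intent visible; the `filter` then skips zero scalars before any group arithmetic happens. A minimal integer stand-in for this pairing logic (the function name is illustrative only):

```rust
// Integer stand-in for (scalar, base) pairing: zip stops at the shorter
// slice, and zero scalars are dropped before any "group" work happens.
fn paired_nonzero(scalars: &[u64], bases: &[u64]) -> Vec<(u64, u64)> {
    let size = scalars.len().min(bases.len());
    let (scalars, bases) = (&scalars[..size], &bases[..size]);
    scalars
        .iter()
        .zip(bases)
        .filter(|(s, _)| **s != 0)
        .map(|(&s, &b)| (s, b))
        .collect()
}

fn main() {
    // Mismatched lengths: the extra base is ignored, the zero scalar dropped.
    let pairs = paired_nonzero(&[3, 0, 7], &[10, 20, 30, 40]);
    assert_eq!(pairs, vec![(3, 10), (7, 30)]);
}
```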
```diff
@@ -36,42 +38,39 @@ impl VariableBaseMSM {
                 let mut res = zero;
                 // We don't need the "zero" bucket, so we only have 2^c - 1 buckets
                 let mut buckets = vec![zero; (1 << c) - 1];
-                scalars
-                    .iter()
-                    .zip(bases)
-                    .filter(|(s, _)| !s.is_zero())
-                    .for_each(|(&scalar, base)| {
-                        if scalar == fr_one {
-                            // We only process unit scalars once in the first window.
-                            if w_start == 0 {
-                                res.add_assign_mixed(base);
-                            }
-                        } else {
-                            let mut scalar = scalar;
-
-                            // We right-shift by w_start, thus getting rid of the
-                            // lower bits.
-                            scalar.divn(w_start as u32);
-
-                            // We mod the remaining bits by the window size.
-                            let scalar = scalar.as_ref()[0] % (1 << c);
-
-                            // If the scalar is non-zero, we update the corresponding
-                            // bucket.
-                            // (Recall that `buckets` doesn't have a zero bucket.)
-                            if scalar != 0 {
-                                buckets[(scalar - 1) as usize].add_assign_mixed(base);
-                            }
-                        }
-                    });
-                let buckets = G::Projective::batch_normalization_into_affine(&buckets);
+                scalars_and_bases.clone().for_each(|(&scalar, base)| {
+                    if scalar == fr_one {
+                        // We only process unit scalars once in the first window.
+                        if w_start == 0 {
+                            res.add_assign_mixed(base);
+                        }
+                    } else {
+                        let mut scalar = scalar;
+
+                        // We right-shift by w_start, thus getting rid of the
+                        // lower bits.
+                        scalar.divn(w_start as u32);
+
+                        // We mod the remaining bits by the window size.
+                        let scalar = scalar.as_ref()[0] % (1 << c);
+
+                        // If the scalar is non-zero, we update the corresponding
+                        // bucket.
+                        // (Recall that `buckets` doesn't have a zero bucket.)
+                        if scalar != 0 {
+                            buckets[(scalar - 1) as usize].add_assign_mixed(base);
+                        }
+                    }
+                });

+                // We could first normalize `buckets` and then use mixed-addition
+                // here, but that's slower for the kinds of groups we care about
+                // (Short Weierstrass curves and Twisted Edwards curves).
                 let mut running_sum = G::Projective::zero();
-                for b in buckets.into_iter().rev() {
-                    running_sum.add_assign_mixed(&b);
-                    res += running_sum;
-                }
+                buckets.into_iter().rev().for_each(|b| {
+                    running_sum += &b;
+                    res += &running_sum;
+                });
                 res
             })
             .collect();
```

**Review thread (on the bucket-reduction loop):**

> **Reviewer:** As an aside, we should probably parallelize this loop. (It will add some complexity, to handle the running sum's updates for subsequent chunks.)
>
> **Author:** Is that worthwhile? We're already parallelizing the outer loop (over the windows).
>
> **Reviewer:** Hmm, we should probably benchmark this. I thought it was a linear number of buckets in the degree.
>
> **Author:** It is, but the idea right now is that each thread gets its own window to operate over. Usually the number of windows is more than the number of threads, so the existing allocation is already fine for those cases. However, if the number of threads is higher, then we could try to leverage the extra threads for better parallelism; but we have to do so in a way that doesn't harm the normal mode of operation. One idea would be to conditionally parallelize only if we have spare threads; this can be done via rayon's […]
>
> **Reviewer:** Ah, I see. I didn't realize it's already parallelized; then I agree it's probably unlikely to be a perf bottleneck.
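The running-sum trick in the hunk above deserves a note: traversing buckets from the highest index to the lowest while accumulating a suffix sum means bucket `i` ends up added `i + 1` times in total, so the loop computes `Σ (i + 1) · buckets[i]` with only two additions per bucket instead of a scalar multiplication per bucket. An integer sketch, with `u64` addition standing in for group addition:

```rust
// Reduce buckets to sum_i (i + 1) * buckets[i] using a running suffix sum:
// each bucket is folded into `running_sum` once, and `running_sum` is added
// into `res` once per bucket, so bucket i contributes exactly (i + 1) times.
fn reduce_buckets(buckets: &[u64]) -> u64 {
    let mut res = 0u64;
    let mut running_sum = 0u64;
    for &b in buckets.iter().rev() {
        running_sum += b;
        res += running_sum;
    }
    res
}

fn main() {
    let buckets = [5u64, 2, 9, 1];
    // Naive weighted sum: 1*5 + 2*2 + 3*9 + 4*1 = 40.
    let naive: u64 = buckets
        .iter()
        .enumerate()
        .map(|(i, &b)| (i as u64 + 1) * b)
        .sum();
    assert_eq!(reduce_buckets(&buckets), naive);
}
```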
```diff
@@ -81,7 +80,7 @@ impl VariableBaseMSM {

         // We're traversing windows from high to low.
         lowest
-            + window_sums[1..]
+            + &window_sums[1..]
                 .iter()
                 .rev()
                 .fold(zero, |mut total, sum_i| {
```
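The fold above recombines the per-window sums from high to low: doubling the accumulator `c` times between additions multiplies it by `2^c`, so window `i` ends up scaled by `2^(c·i)`. This is a Horner-style evaluation of `Σ sum_i · 2^(c·i)`. An integer sketch (with multiplication by 2 standing in for point doubling; the function name is illustrative):

```rust
// Horner-style recombination of per-window sums: doubling `total` c times
// per window multiplies it by 2^c, so window i ends up scaled by 2^(c*i).
fn recombine(window_sums: &[u64], c: u32) -> u64 {
    let lowest = window_sums[0];
    lowest
        + window_sums[1..].iter().rev().fold(0u64, |mut total, &sum_i| {
            total += sum_i;
            for _ in 0..c {
                total *= 2; // stand-in for total.double_in_place()
            }
            total
        })
}

fn main() {
    // Expected: 3 + 5*2^4 + 7*2^8 = 3 + 80 + 1792 = 1875.
    assert_eq!(recombine(&[3u64, 5, 7], 4), 1875);
}
```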
```diff
@@ -92,11 +91,4 @@ impl VariableBaseMSM {
             total
         })
     }
-
-    pub fn multi_scalar_mul<G: AffineCurve>(
-        bases: &[G],
-        scalars: &[<G::ScalarField as PrimeField>::BigInt],
-    ) -> G::Projective {
-        Self::msm_inner(bases, scalars)
-    }
 }
```
**Review thread (on the `size` truncation):**

> **Author:** This change is because in some downstream projects (poly-commit), algorithms that use MSM are slower if the vectors are mismatched. The `zip` should fix this, but in benchmarks it doesn't. (For that matter, neither does this new code, so I can remove it if that's better.)
>
> **Reviewer:** Looks to me like this is due to line 19, where `c` is set based only on the scalars' length? That should probably be moved down and calculated based on `size`.
>
> **Author:** To clarify, this was happening when `bases.len() > scalars.len()`, so the `c` calculation wasn't affected. (But you're right, we should move the `c` calculation down anyway.)
>
> **Reviewer:** Also, scalars should never outnumber the bases, right? Perhaps an assert should be added for that?
>
> **Author:** Hmm, not sure that's necessary; you could imagine a case where zero-padding makes `scalars` larger, but because the extras are zero, it doesn't affect the result. (This might be happening in `ark-poly-commit` already.)
>
> **Reviewer:** Hmm, the function doesn't really make sense in the case where the extras are non-zero. (And in the case where there are more zeros than bases, it's inefficient to allocate them.) Regardless of whether this function accepts extra zeros, IMO we should (eventually) change poly-commit to not allocate the extra zeros.
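The zero-padding point above can be checked directly on a toy model: because zero scalars are filtered out, appending zeros to `scalars` (up to the number of bases) leaves the result unchanged, while extra *non-zero* scalars beyond the bases would be silently truncated. A sketch with integer multiplication standing in for scalar multiplication:

```rust
// Toy MSM over integers: sum of scalar * base over the shorter of the two
// slices, skipping zero scalars (mirroring the zip + filter in the PR).
fn toy_msm(scalars: &[u64], bases: &[u64]) -> u64 {
    let size = scalars.len().min(bases.len());
    scalars[..size]
        .iter()
        .zip(&bases[..size])
        .filter(|(s, _)| **s != 0)
        .map(|(&s, &b)| s * b)
        .sum()
}

fn main() {
    let bases = [10u64, 20, 30, 40];
    // Zero-padding the scalars does not change the result...
    assert_eq!(toy_msm(&[3, 7], &bases), toy_msm(&[3, 7, 0, 0], &bases));
    // ...but a non-zero scalar beyond the bases is silently dropped,
    // which is why the truncation semantics deserve scrutiny.
    assert_eq!(toy_msm(&[3, 7, 0, 0, 9], &bases), toy_msm(&[3, 7], &bases));
}
```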