by Gregor Mitscha-Baude
This code won 2nd place in the Wasm/MSM ZPrize -- see official results. The version that was submitted is frozen on the branch zprize.
To get started with the code, see how to use this repo. To contribute, see contributing
The multi-scalar multiplication (MSM) problem is: Given elliptic curve points $G_0, \ldots, G_{n-1}$ and scalars $s_0, \ldots, s_{n-1}$, compute

$$S = s_0 G_0 + \ldots + s_{n-1} G_{n-1},$$

where $s_i G_i$ denotes scalar multiplication of the point $G_i$ by the scalar $s_i$.
Here's the 2-minute summary of my approach:
- Written in JavaScript and raw WebAssembly text format (WAT)
- The reference implementation was improved by a factor of 5-8x
- On a reasonable machine, we can do a modular multiplication in < 100ns
- Experimented with Barrett reduction & Karatsuba, but ended up sticking with Montgomery, which I could make marginally faster
- A crucial insight was to use a non-standard limb size of 30 bits for representing field elements, to save carry operations, which are expensive in Wasm
- Like probably everyone in this contest, I use Pippenger / the bucket method, with batch-affine additions. I also use NAF scalars and GLV decomposition
- An interesting realization was that we can use batch-affine, not just for the bucket accumulation as shown by Aztec, but also for the entire bucket reduction step. Thus, curve arithmetic in my MSM is 99.9% affine!
- Laying out points in memory in the right order before doing batched additions seems like the way to go
Here are some performance timings, measured in node 16 on the CoreWeave server that ZPrize participants were given. We ran each instance 10 times and took the median and standard deviation:
Size | Reference (sec) | Ours (sec) | Speed-up |
---|---|---|---|
| | 2.84 | 0.37 | ~7.7x |
| | 9.59 | 1.38 | ~6.9x |
| | 32.9 | 4.98 | ~6.6x |
On my local machine (Intel i7), overall performance is a bit better than these numbers, but the relative gains are somewhat smaller, between 5-6x.
Below, I give a more detailed account of my implementation and convey some of my personal learnings. Further down, you'll find a section on how to use this repo.
First, a couple of words on the project architecture. I went into this competition with a specific assumption: that, to create code that runs fast in the browser, the best way to go is to write most of it in JavaScript, and only mix in hand-written Wasm for the low-level arithmetic and some hot loops. This is in contrast to more typical architectures where all the crypto-related code is written in a separate high-level language and compiled to Wasm for use in the JS ecosystem. Being a JS developer who professionally works in a code base where large parts are compiled to JS from OCaml and Rust, I developed a dislike for the impedance mismatch, bad debugging experience, and general complexity such an architecture creates. It's nice if you can click on a function definition and end up in the source code of that function, not in some auto-generated TS declaration file which hides an opaque blob of compiled Wasm. (Looking at layers of glue code and wasm-bindgen incantations, I feel a similar amount of pain for the Rust developer on the other side of the language gap.)
So, I started out by implementing everything from scratch, in JS – not Wasm yet, because it seemed too hard to find the right sequences of assembly when I was still figuring out the mathematics. After having the arithmetic working in JS, an interesting game began: "How much do we have to move to Wasm?"
Do you need any Wasm? There's a notion sometimes circulating that WebAssembly isn't really for performance (it's for bringing other languages to the web), and that perfectly-written JS which went through enough JIT cycles would be just as fast as Wasm. For cryptography at least, this is radically false. JS doesn't even have native 64-bit integers; the most performant option for modular multiplication is bigints. They're nice because they keep the code simple:
let z = (x * y) % p;
However, one such modular bigint multiplication, for 381-bit inputs, takes 550ns on my machine. The Montgomery multiplication I created in Wasm takes 87ns!
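To see where the 550ns comes from, here is a tiny stand-alone benchmark of the bigint route (numbers will vary by machine; `p` is the BLS12-381 base field modulus):

```js
// Micro-benchmark of modular bigint multiplication (runs in Node or a browser).
let p = 0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaabn;
let x = p - 1n, y = p >> 1n; // two arbitrary field elements
let N = 1_000_000;
let t0 = performance.now();
for (let i = 0; i < N; i++) {
  x = (x * y) % p;
}
let t1 = performance.now();
console.log(`${(((t1 - t0) / N) * 1e6).toFixed(0)} ns per multiplication`);
console.log(x & 1n); // use the result so the loop can't be optimized away
```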
We definitely want to have multiplication, addition, subtraction and low-level helpers like isEqual
in WebAssembly, using some custom byte representation for field elements. The funny thing is that this is basically enough! There are diminishing returns to putting anything beyond this lowest layer in Wasm. In fact, I was already close to Arkworks speed at the point where I had only the multiplication in Wasm, and was reading out field element limbs as bigints for routines like subtraction. However, reading out field elements is slow. What works well is if JS functions only operate with pointers to Wasm memory, never reading their content and just passing them from one Wasm function to the next. For the longest time while working on this competition, I had all slightly higher-level functions, like inversion, curve arithmetic etc., written in JS and operating in this way. This was good enough to be 3-4x faster than Arkworks, which is 100% Wasm!
Near the end, I put a lot of work into moving more critical path logic to Wasm, but this effort was wasteful. There's zero benefit in moving a routine like batchInverse
to Wasm -- I'll actually revert changes like that after the competition. The inverse
function is about the highest level that Wasm should operate on.
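To illustrate that style with a sketch (the names `multiply`, `inverse`, `copy` and the pointer arithmetic are stand-ins, not the repo's exact API): a batch inversion written in JS only ever shuffles pointers into Wasm memory and never looks at the field elements themselves.

```js
// Batch inversion (Montgomery's trick) in JS, operating purely on pointers.
// X, invX, scratch are byte offsets into Wasm memory; sizeField is the size
// of one field element in bytes; wasm.{copy, multiply, inverse} are exports.
function batchInverse(wasm, scratch, invX, X, n, sizeField) {
  let products = scratch;
  // products[i] = X[0] * ... * X[i]
  wasm.copy(products, X);
  for (let i = 1; i < n; i++) {
    wasm.multiply(products + i * sizeField, products + (i - 1) * sizeField, X + i * sizeField);
  }
  // one single inversion of the total product ...
  wasm.inverse(invX + (n - 1) * sizeField, products + (n - 1) * sizeField);
  // ... unrolled back into the individual inverses
  for (let i = n - 1; i > 0; i--) {
    wasm.multiply(invX + (i - 1) * sizeField, invX + i * sizeField, X + i * sizeField);
    wasm.multiply(invX + i * sizeField, invX + i * sizeField, products + (i - 1) * sizeField);
  }
}
```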
A major breakthrough in my work was when I changed the limb size of field elements from 32 to 30 bits – this decreases the time for a multiplication from 140ns to 87ns. Multiplications are the clear bottleneck, at 60%-80% of the total MSM runtime.
To understand why decreasing the limb size has such an impact, or come up with a change like that in the first place, we have to dive into the details of Montgomery multiplication - which I will do now.
Say our prime $p$ has bit length $N$ (for the BLS12-381 base field, $N = 381$). In Montgomery arithmetic, we don't store a field element $x$ itself but its Montgomery representation $xR \bmod p$, for a fixed constant $R > p$ called the Montgomery radix.
The point of this representation is that we can efficiently compute the Montgomery product $(xR)(yR)R^{-1} = (xy)R \bmod p$ – which is exactly the representation of the product $xy$.
The basic idea goes like this: You add a multiple of $p$ to the product, chosen such that the sum becomes divisible by $R$; this doesn't change the result modulo $p$, and the division by $R$ becomes exact, so it reduces to dropping the low bits.
The real algorithm only needs a refinement of this idea which works limb by limb, dividing by $2^w$ in each step rather than dividing by $R$ once at the end.
We store field elements as $n$ limbs of $w$ bits, $x = \sum_{i=0}^{n-1} x_i 2^{iw}$, with $nw \ge N$.
Now, the Montgomery radix is chosen as $R = 2^{nw}$, so that

$$xyR^{-1} = \sum_{i=0}^{n-1} x_i y \, 2^{iw} 2^{-nw}.$$
We can compute this sum iteratively:
- Initialize $S = 0$
- $S = (S + x_i y)\, 2^{-w}$ for $i = 0,\ldots,n-1$.
Note that the earlier terms get divided by $2^w$ more often, which produces exactly the powers $2^{iw} 2^{-nw}$ in the sum above; after the last step, $S = xy \, 2^{-nw} = xyR^{-1}$.
Since we are only interested in the end result modulo $p$, we are free to modify each step by adding a multiple of $p$. Similar to the non-iterative algorithm, we do $S = (S + x_i y + q_i p)\, 2^{-w}$, where $q_i$ is chosen such that $S + x_i y + q_i p \equiv 0 \pmod{2^w}$, which makes the division by $2^w$ exact.
Now, here comes the beauty: Since this condition is mod $2^w$, the right $q_i$ only depends on the lowest limbs: $q_i = \mu_0 \cdot \big((S_0 + x_i y_0) \bmod 2^w\big) \bmod 2^w$, with the precomputed constant $\mu_0 = (-p)^{-1} \bmod 2^w$.
Similarly, you can replace the full constant $\mu = (-p)^{-1} \bmod R$ of the non-iterative algorithm by just its lowest limb $\mu_0$ – everything above the lowest limb is cut off by the $\bmod\ 2^w$.
In full detail, this is the iterative algorithm for the Montgomery product:
- Initialize $S_j = 0$ for $j = 0,\ldots,n-1$
- for $i = 0,\ldots,n-1$, do:
  - $t = S_0 + x_i y_0$
  - $(\_, t') = t$
  - $(\_, q_i) = \mu_0 t'$
  - $(c, \_) = t + q_i p_0$
  - for $j = 1,\ldots,n-1$, do:
    - $(c, S_{j-1}) = S_j + x_i y_j + q_i p_j + c$
  - $S_{n-1} = c$
The notation $(c, l) = t$ stands for splitting a value $t$ into a high part $c = \lfloor t / 2^w \rfloor$ (the carry) and a low part $l = t \bmod 2^w$ of $w$ bits; an underscore means that part is discarded.
Note that, in the inner loop, we assign the result to $S_{j-1}$ rather than $S_j$ – this shift down by one limb is exactly the multiplication by $2^{-w}$ in each step.
Also, let's see why the iterative algorithm is much better than the naive algorithm: There, computing $q = \mu \cdot (xy \bmod R)$ and then the product $qp$ requires two more full-size multiplications on top of $xy$ itself, while the iterative algorithm interleaves everything and gets away with roughly $2n^2 + n$ limb products.
Another note: In the last step, we don't need another carry like $(c', S_{n-1}) = c$, because one can check that $c$ is already small enough to be stored as a limb.
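For reference, here is that algorithm transcribed to plain JS with `BigInt` limb arrays -- a sketch for readability, not the production code, which runs on `i64`s in Wasm, where the overflow concerns discussed next actually apply:

```js
// Iterative Montgomery product on n limbs of w bits each.
// Inputs: x, y, p as arrays of n BigInt limbs; mu0 = (-p)^(-1) mod 2^w.
function montmul(x, y, p, mu0, n, w) {
  const bw = BigInt(w);
  const mask = (1n << bw) - 1n;
  let S = new Array(n).fill(0n);
  for (let i = 0; i < n; i++) {
    // t = S_0 + x_i*y_0, q_i = mu0 * (t mod 2^w) mod 2^w
    let t = S[0] + x[i] * y[0];
    let qi = (mu0 * (t & mask)) & mask;
    // (c, _) = t + q_i*p_0 -- the low limb is 0 by choice of q_i
    let c = (t + qi * p[0]) >> bw;
    for (let j = 1; j < n; j++) {
      // (c, S_{j-1}) = S_j + x_i*y_j + q_i*p_j + c
      let tj = S[j] + x[i] * y[j] + qi * p[j] + c;
      S[j - 1] = tj & mask;
      c = tj >> bw;
    }
    S[n - 1] = c;
  }
  return S; // represents x*y*2^(-n*w) mod p (up to one conditional subtraction)
}
```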
Let's talk about carries, which form a part of this algorithm that is very tweakable. The first thing to note is that all of the operations above are implemented on 64-bit integers (`i64` in Wasm). To make multiplications like $x_i y_j$ fit into 64 bits, the limbs $x_i$ and $y_j$ can have at most 32 bits.
In WebAssembly, there is no such native "multiply-and-get-carry". Instead, the carry operation $(c, l) = t$ consists of two separate instructions:

- A right-shift (`i64.shr_u`) of $t$ by the constant $w$, to get $c$
- A bitwise AND (`i64.and`) of $t$ with the constant $2^w - 1$, to get $l$.

Also, every carry is associated with an addition, because the carry $c$ has to be added onto the next-higher limb.
Second, with 32-bit limbs, we need to add 1 carry after every product term, because products fill up the full 64 bits. (If we added two such 64-bit terms, we could get something with 65 bits; the 65th bit would overflow, i.e. get discarded, giving us wrong results.) It turns out that 1 carry is almost as heavy as 1 mul + add, so doing the carrying on every single term nearly doubles the cost of the whole multiplication.
How many carries do we need for smaller limb sizes $w$? If limbs have fewer than 32 bits, each product term has fewer than 64 bits, so we can add several terms into a 64-bit accumulator before it can overflow. Concretely, we can accumulate at most the following numbers of terms without any intermediate carries:
- $k = 2^0 = 1$ term for $w = 32$
- $k = 2^2 = 4$ terms for $w = 31$
- $k = 2^4 = 16$ terms for $w = 30$
- $k = 2^6 = 64$ terms for $w = 29$
- $k = 2^8 = 256$ terms for $w = 28$
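These numbers come from a simple overflow count: every term is a product of two $w$-bit limbs and hence smaller than $2^{2w}$, so $k$ of them can be accumulated in a 64-bit integer as long as

$$k \cdot 2^{2w} \le 2^{64}, \quad\text{i.e.}\quad k \le 2^{64-2w}.$$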
How many terms do we even have? In the worst case, during multiplication, a single limb accumulates up to $2n$ product terms ($n$ terms $x_i y_j$ plus $n$ terms $q_i p_j$). Here is how the numbers play out for different limb sizes:
- $w = 32$, $n = 12$, $N = 384$ $\Rightarrow$ max terms: 24, carry-free terms: 1
- $w = 31$, $n = 13$, $N = 403$ $\Rightarrow$ max terms: 26, carry-free terms: 4
- $w = 30$, $n = 13$, $N = 390$ $\Rightarrow$ max terms: 26, carry-free terms: 16
- $w = 29$, $n = 14$, $N = 406$ $\Rightarrow$ max terms: 28, carry-free terms: 64
- $w = 28$, $n = 14$, $N = 392$ $\Rightarrow$ max terms: 28, carry-free terms: 256
We see that starting at $w = 30$, a substantial fraction of the terms can be accumulated without carries, so most of the carry operations can be eliminated from the main loop.
The trade-off with using a smaller limb size is that we get a higher limb count $n$, and the number of multiplications grows like $n^2$, so we don't want to go lower than necessary.
I did experiments with several of these limb sizes, and $w = 30$ came out as the clear winner.
Now concretely, how does our algorithm have to be modified to use fewer carries? I'll show the version that's closest to the original algorithm. It has an additional parameter `nSafe`, the number of inner-loop iterations that can go without a carry. A carry is performed in step $j$ of the inner loop if `j % nSafe === 0`. In particular, at step 0 we always perform a carry, since we don't store that intermediate result, so we couldn't do a carry on it later.
- Initialize $S_j = 0$ for $j = 0,\ldots,n-1$
- for $i = 0,\ldots,n-1$, do:
  - $t = S_0 + x_i y_0$
  - $(\_, t') = t$
  - $(\_, q_i) = \mu_0 t'$
  - $(c, \_) = t + q_i p_0$ (always carry for $j = 0$)
  - for $j = 1,\ldots,n-2$, do:
    - $t = S_j + x_i y_j + q_i p_j$
    - add the carry from the last iteration: if `(j-1) % nSafe === 0`, then $t = t + c$
    - maybe do a carry in this iteration: if `j % nSafe === 0`, then $(c, S_{j-1}) = t$, else $S_{j-1} = t$
  - case that the $(n-2)$-th step does a carry, i.e. `(n-2) % nSafe === 0`:
    - $(c, S_{n-2}) = S_{n-1} + x_i y_{n-1} + q_i p_{n-1}$
    - $S_{n-1} = c$
  - if the $(n-2)$-th step does no carry, then $S_{n-1}$ never gets written to:
    - $S_{n-2} = x_i y_{n-1} + q_i p_{n-1}$
- Final round of carries to get back to $w$ bits per limb: set $c = 0$; for $i = 0,\ldots,n-1$, do: $(c, S_i) = S_i + c$
I encourage you to check for yourself that doing a carry every `nSafe` steps of the inner loop is one way to ensure that no more than `2*nSafe` product terms are ever added together.
In the actual implementation, the inner loop is unrolled, so the if conditions can be resolved at compile time and the places where carries happen are hard-coded in the Wasm code.
In our implementation, we use a sort of meta-programming for that: Our Wasm gets created by JavaScript which leverages a little ad-hoc library that mimics the WAT syntax in JS. In fact, the desire to test out implementations for different limb sizes, with complex compile-time conditions like above, was the initial motivation for starting to generate Wasm with JS; before that, I had written it by hand.
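To make this concrete, here is a hypothetical mini-generator (not the repo's actual helper library) which emits the unrolled WAT for the inner loop, resolving the carry conditions at generation time:

```js
// Sketch: generate unrolled WAT for the inner loop of the Montgomery product.
// Locals like $t, $c, $xi, $qi and $S1..$S{n-1} are assumed to be declared
// elsewhere; in the real code, y_j and p_j come from memory / constants.
function generateInnerLoop(n, nSafe, w) {
  let mask = (1n << BigInt(w)) - 1n;
  let lines = [];
  for (let j = 1; j < n - 1; j++) {
    // t = S_j + x_i*y_j + q_i*p_j
    lines.push(
      `(local.set $t (i64.add (local.get $S${j})`,
      `  (i64.add (i64.mul (local.get $xi) (local.get $y${j}))`,
      `           (i64.mul (local.get $qi) (local.get $p${j})))))`
    );
    if ((j - 1) % nSafe === 0) {
      // add the carry from the previous iteration
      lines.push(`(local.set $t (i64.add (local.get $t) (local.get $c)))`);
    }
    if (j % nSafe === 0) {
      // carry: c = t >> w, S_{j-1} = t & (2^w - 1)
      lines.push(`(local.set $c (i64.shr_u (local.get $t) (i64.const ${w})))`);
      lines.push(`(local.set $S${j - 1} (i64.and (local.get $t) (i64.const ${mask})))`);
    } else {
      lines.push(`(local.set $S${j - 1} (local.get $t))`);
    }
  }
  return lines.join("\n");
}
```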
My conclusion on this section is that if you implement cryptography in a new environment like Wasm, you have to rederive your algorithms from first principles. If you just port over well-known algorithms, like the "CIOS method", you will adopt implicit assumptions about what's efficient that went into crafting these algorithms, which might not hold in your environment.
I did a lot more experiments trying to find the fastest multiplication algorithm, which I want to mention briefly. Some time during the competition, it came to my attention that there are some brand-new findings about Barrett reduction, which is a completely different way of reducing products modulo $p$. This paper, plus some closer analysis I did within the framework it establishes, reveals that a multiplication plus Barrett reduction can be done with an effort comparable to a Montgomery multiplication.
An interesting sub-result of my analysis is that for many primes (in particular, ours), we can prove that the maximum error of the estimated quotient is 1, so a single correction step suffices.
An awesome property of Barrett reduction is that it literally performs the reduction $xy \bmod p$: the output is the canonical remainder, not shifted by a power of two as in Montgomery arithmetic, so field elements don't need to be kept in a special representation.
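For context, here is the classic high-level form of Barrett reduction as a BigInt sketch (the optimized variants discussed above work limb by limb, but the structure is the same: one multiplication by a precomputed constant, one multiplication by $p$, and a small correction):

```js
// Classic Barrett reduction, BigInt sketch. N is the bit length of p,
// m = floor(2^(2N) / p) is precomputed, and the input xy (a product of two
// reduced field elements) satisfies xy < 2^(2N).
const N = 381n;
function barrettReduce(xy, p, m) {
  let q = (xy * m) >> (2n * N); // q equals floor(xy / p) or is smaller by 1
  let r = xy - q * p;           // hence 0 <= r < 2p
  if (r >= p) r -= p;           // at most one conditional subtraction
  return r;
}
// precompute once: const m = (1n << (2n * N)) / p;
```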
Unfortunately, the fastest Barrett multiplication I was able to implement takes 99ns on my machine -- too much of a difference to Montgomery's 87ns to justify the switch. (At least, I could use the Barrett implementation for the GLV decomposition of scalars, where scalars have to be split up modulo the cube root of unity $\lambda$ in the scalar field.)
One reason for the "unreasonable effectiveness of Montgomery" I observed is that it can be structured so that all of the work happens in a single outer loop. And, for reasons unclear to me, implementing that outer loop with a Wasm loop
instruction is much faster than unrolling it; the inner loop, on the other hand, has to be unrolled. Ridiculously, my fully unrolled Montgomery product is 40% slower than the one with an outer loop -- even slower than the Barrett implementations, which are also unrolled. I wasn't able to get any of those voodoo gains by refactoring the Barrett implementation to use loops.
One hint I heard was that V8 (the JS engine) JIT-compiles vectorized instructions if operations are structured in just the right way. I haven't confirmed this myself, and I don't know if that's what's happening here. It would be great to find out, though.
I tried to use one layer of Karatsuba multiplication in the Barrett version. This is straightforward since, for Barrett, the multiplication and the reduction are two clearly separated steps. Karatsuba didn't help, though -- it was exactly as fast as the version with schoolbook multiplication. For Montgomery, I didn't try Karatsuba because I only understood very late that this was even possible. However, given the `loop` paradox, and that Karatsuba doesn't result in a single nice loop, I don't expect it to yield any benefit.
Let's move to the higher-level algorithms for computing the MSM:
The bucket method and the technique of batching affine additions is well-documented elsewhere, so I'll skip over most details of those.
Broadly, our implementation uses the Pippenger algorithm / bucket method, where scalars are sliced
into $K$ windows ("partitions") of $c$ bits each.
For each partition $k = 0,\ldots,K-1$, points are sorted into $L$ buckets according to the value $l \in \{1,\ldots,L\}$ of the corresponding scalar slice (points whose slice is $0$ are skipped).
After sorting the points, computation proceeds in three main steps:
1. Each bucket is accumulated into a single point, the bucket sum $B_{k,l}$, which is simply the sum of all points in the bucket.
2. The bucket sums of each partition $k$ are reduced into a partition sum $P_k = 1 \cdot B_{k,1} + 2 \cdot B_{k,2} + \ldots + L \cdot B_{k,L}$.
3. The partition sums are reduced into the final result, $S = P_0 + 2^c P_1 + \ldots + 2^{c(K-1)} P_{K-1}$.
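Putting the three steps together (and glossing over the NAF and GLV refinements), the whole method rests on the identity

$$\sum_i s_i G_i \;=\; \sum_{k=0}^{K-1} 2^{ck} \sum_{l=1}^{L} l \, B_{k,l}, \qquad B_{k,l} = \sum_{i \,:\, s_{i,k} = l} G_i,$$

where $s_{i,k}$ is the $k$-th $c$-bit slice of the scalar $s_i$, i.e. $s_i = \sum_k s_{i,k}\, 2^{ck}$; the inner sum over $l$ is exactly the partition sum $P_k$.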
We use batch-affine additions for step 1 (bucket accumulation), as pioneered by Zac Williamson in Aztec's barretenberg library: AztecProtocol/barretenberg#19. Thus, in this step we loop over all buckets, collect the pairs of points to add, and then do a batch-addition on all of those. This is done in multiple passes, until the points of each bucket are summed into a single point, in an implicit binary tree. In each pass, empty buckets and buckets with 1 remaining point are skipped; also, buckets of odd length have a dangling point at the end, which doesn't belong to a pair and is carried over to a later pass.
As a novelty, we also use batch-affine additions for all of step 2 (bucket reduction). More on that below.
We switch from an affine to a projective point representation between steps 2 and 3. Step 3 is so tiny (< 0.1% of the computation) that the performance of projective curve arithmetic becomes irrelevant.
The algorithm has a significant preparation phase, which happens before step 1, where we split scalars and sort points and such. Before splitting scalars into length-$c$ slices, we do a GLV decomposition, where each 256-bit scalar is split into two 128-bit chunks as $s = s_0 + s_1\lambda$, where $\lambda$ is a cube root of unity in the scalar field. The corresponding MSM term becomes $s G = s_0 G + s_1 (\lambda G)$, and $\lambda G$ is cheap to compute, because the map $G \mapsto \lambda G$ is an efficient curve endomorphism.
Other than processing inputs, the preparation phase is concerned with organizing points. This should be done in a way which:
- enables us to efficiently collect independent point pairs to add, in multiple successive passes over all buckets;
- makes memory access efficient when batch-adding pairs => ideally, the 2 points that form a pair, as well as consecutive pairs, are stored next to each other.
We address these two goals by copying all points to linear arrays; we do this K times, once for each partition.
Ordering in each of these arrays is achieved by performing a counting sort of all points with respect to their bucket index.
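For illustration, here is the index-level core of such a counting sort as a JS sketch (the real code additionally copies the points themselves into Wasm memory in this order):

```js
// Counting sort of point indices by bucket, for one partition.
// `bucket[i]` is the bucket index of point i (0 = skip); returns the point
// indices grouped by bucket, plus the start offset of each bucket.
function sortByBucket(bucket, L) {
  let n = bucket.length;
  let counts = new Array(L + 1).fill(0);
  for (let b of bucket) counts[b]++;
  // prefix sums -> start offset of each bucket in the sorted array
  let offsets = new Array(L + 2).fill(0);
  for (let b = 0; b <= L; b++) offsets[b + 1] = offsets[b] + counts[b];
  let sorted = new Array(n);
  let next = offsets.slice();
  for (let i = 0; i < n; i++) sorted[next[bucket[i]]++] = i;
  return { sorted, offsets };
}
```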
Between steps 1 and 2, there is a similar re-organization step. At the end of step 1, bucket sums are accumulated into the 0
locations of each original bucket, which are spread apart as far as the original buckets were long. Before step 2, we copy these bucket sums to a new linear array from 1 to L, for each partition. Doing this empirically reduces the runtime.
Here's a rough breakdown of the time spent in the 5 different phases of the algorithm. We split the preparation phase into two; the "summation steps" are the three steps also defined above.
% Runtime | Phase description |
---|---|
8% | Preparation phase 1 - input processing |
12% | Preparation phase 2 - copying points into bucket order |
65% | Summation step 1 - bucket accumulation |
15% | Summation step 2 - bucket reduction |
0% | Summation step 3 - final sum over partitions |
When you have a list of buckets to accumulate – how do you create a series of valid batches of independent additions, such that in the end, you have accumulated those points into one per bucket?
I want to describe this aspect because it's a bit confusing when you first encounter it, and the literature / code comments I found are also confusing, while the actual answer is super simple.
For simplicity, just look at one bucket, with points in an array:
x_0 | x_1 | x_2 | x_3 | x_4 | x_5
Here's what we do: When we encounter this bucket to collect addition pairs for the first batch, we just greedily take one pair after another, until we run out of pairs:
(x_0, x_1), (x_2, x_3), (x_4, x_5) --> addition pairs
For each collected pair, our batch addition routine add-assigns the second to the first point. So, after the first round, we can implicitly ignore every uneven-indexed point, because the entire sum is now contained at the even indices:
x_0 | ___ | x_2 | ___ | x_4 | ___
When we encounter this bucket for the next addition batch, we again greedily collect pairs starting from index 0. This time, only the even indices still hold points, so the two points of a pair are 2 indices apart. The last point can't be added to a pair, so it is skipped:
(x_0, x_2) --> addition pairs
After this round x_2
was added into x_0
. Now, we can ignore every index not divisible by 4:
x_0 | ___ | ___ | ___ | x_4 | ___
When collecting points in the third round, we take pairs of points that are 4 indices apart, which just gives us the final pair:
(x_0, x_4) --> addition pairs
We end up with the final bucket layout, which has all points accumulated into the first one, in a series of independent additions:
x_0 | ___ | ___ | ___ | __ | ___
When we encounter that bucket in every subsequent round, we will skip it, because its remaining length is not larger than 1.
This trivial algorithm sums up each bucket, in an implicit binary tree, in the minimum possible number of rounds. In the implementation, you walk over all buckets and do what I described here. Simple!
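In code, the pair collection for one bucket and one pass can look like this sketch, where `gap` starts at 1 and doubles after every pass, and `batchAddAssign` is a hypothetical stand-in for the batched affine addition that adds the second point of each pair into the first:

```js
// Collect independent addition pairs (gap indices apart) for one bucket.
function collectPairs(bucketStart, bucketLength, gap, pairs) {
  // only indices divisible by `gap` still hold points in this pass
  for (let i = 0; i + gap < bucketLength; i += 2 * gap) {
    pairs.push([bucketStart + i, bucketStart + i + gap]);
  }
}
// one pass: collect pairs over all buckets, then add them in a single batch
// let pairs = [];
// for (let b of buckets) collectPairs(b.start, b.length, gap, pairs);
// batchAddAssign(pairs); gap *= 2;
```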
Let's turn our attention to step 2 (bucket reduction). At the beginning of this step, we are given bucket sums $B_1, \ldots, B_L$, and we have to compute the partition sum $P = 1 \cdot B_1 + 2 \cdot B_2 + \ldots + L \cdot B_L$.
We actually need one such sum for every partition, but the partitions are fully independent, so we leave out the partition index $k$ in what follows.
There's a well-known algorithm for computing this sum with only about $2L$ additions:
- Set $R = 0$, $P = 0$.
- for $l = L, \ldots, 1$:
  - $R = R + B_l$
  - $P = P + R$
In each step $l$, the running sum becomes $R = B_L + B_{L-1} + \ldots + B_l$, so over the whole loop each $B_l$ is added into $P$ exactly $l$ times -- which is exactly the partition sum.
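As a quick sanity check, here is that algorithm with points modelled as plain numbers (so that point addition becomes integer addition):

```js
// Running-sum reduction of one partition; B is 1-indexed (B[0] unused).
function partitionSum(B) {
  let L = B.length - 1;
  let R = 0, P = 0;
  for (let l = L; l >= 1; l--) {
    R += B[l]; // R = B_L + ... + B_l
    P += R;    // each B_l ends up added l times
  }
  return P;
}
// partitionSum([0, 10, 20, 30]) === 1*10 + 2*20 + 3*30 === 140
```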
Now the obvious question: Can we use batch-affine additions here? Clearly, batching only pays off if we have many independent additions to perform at once.
The bad news is that every partial sum depends on the last one: both $R$ and $P$ are updated from their previous values in every step, so within one partition these additions form a single sequential chain, and the only obvious batching opportunity is across the $K$ partitions.
Let's quickly understand the trade-off with a napkin calculation: With projective arithmetic, we could use mixed additions for all the $R = R + B_l$ steps (the $B_l$ are affine points) and projective additions for $P = P + R$, with no inversions at all. With batching across partitions only, every batch has just $K$ independent additions, so each addition carries a $1/K$ share of a full inversion, which eats up most of what the cheaper affine formulas save.
So, it's clearly not worth it to use batch additions here, even if we account for the savings possible in computing all $K$ partition sums together.
However, what if we had a way to split the partition sum into independent sub-sums? Actually, we can do this:

$$\sum_{l=1}^{L} l B_l \;=\; \sum_{l=1}^{L/2} l B_l \;+\; \sum_{l=1}^{L/2} l B_{l + L/2} \;+\; \frac{L}{2} \sum_{l=1}^{L/2} B_{l + L/2}.$$

This is just the same sum with the indices written differently: An index in the upper half is written as $l + L/2$ with $l \le L/2$, and the factor $(l + L/2)$ is distributed over the second and third sum.
Voilà, the first two sums are both of the form of the original partition sum (with $L/2$ instead of $L$ buckets), and they can be computed independently from each other. We have split our partition into a lower half and an upper half. The third sum is essentially free: it is exactly the final running sum $R$ of the second sub-sum, and it only has to be multiplied by $L/2$, which takes a logarithmic number of doublings.
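Again modelling points as plain numbers, a quick check that the split computes the same value as the original partition sum:

```js
// Split the partition sum into lower half, upper half, and a correction term.
function splitPartitionSum(B) {
  let L = B.length - 1, h = L / 2; // assume L is even; B is 1-indexed
  let lower = 0, upper = 0, top = 0;
  for (let l = 1; l <= h; l++) {
    lower += l * B[l];     // sum over the lower half, original form
    upper += l * B[l + h]; // sum over the upper half, same form
    top += B[l + h];       // plain sum of the upper-half buckets
  }
  return lower + upper + h * top; // equals partitionSum(B)
}
```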
In summary, we can split a partition in two independent halves, at the cost of a logarithmic number of doublings, plus 2 additions to add the three sums back together. These extra doublings/additions don't even have to be affine, since they can be done at the end, when we're ready to leave affine land, so they are really negligible.
We don't have to stop there: We can split each of the sub-partitions again, and so on, recursively. We can do this until we have enough independent sub-sums that the cost of the inversions is amortized. This lets us easily amortize the inversion, and we get the full benefit of batch-affine additions for the entire bucket reduction step.
This is implemented in src/msm.js
, reduceBucketsAffine
. Unfortunately, I didn't realize until writing this down that the extra doublings/additions don't have to be affine; I use batched-affine for those as well, which is probably just slightly suboptimal. Also, I should add that with a well-chosen window size, the bucket reduction step is 3-5x cheaper than the bucket accumulation step, so shaving off 25% of it ends up saving only <5% of overall runtime.
I hope to polish up this repo to become a go-to library for fast finite field and elliptic curve arithmetic, usable across the JS ecosystem. There is some work to do on this and every contribution is highly welcome 🙌