Skip to content

Commit

Permalink
Add Multi-scalar-multiplication using pippenger algorithm.
Browse files Browse the repository at this point in the history
  • Loading branch information
paulmillr committed Sep 2, 2024
1 parent ca3e550 commit 7c80242
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 4 deletions.
21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ import { secp256k1 } from '@noble/curves/secp256k1'; // ESM and Common.js
- [ed448, X448, decaf448](#ed448-x448-decaf448)
- [bls12-381](#bls12-381)
- [bn254 aka alt_bn128](#bn254-aka-alt_bn128)
- [Multi-scalar-multiplication](#multi-scalar-multiplication)
- [All available imports](#all-available-imports)
- [Accessing a curve's variables](#accessing-a-curves-variables)
- [Abstract API](#abstract-api)
Expand Down Expand Up @@ -309,6 +310,19 @@ different implementations of bn254 do it differently - there is no standard. Poi

For example usage, check out [the implementation of bn254 EVM precompiles](https://github.com/paulmillr/noble-curves/blob/3ed792f8ad9932765b84d1064afea8663a255457/test/bn254.test.js#L697).

#### Multi-scalar-multiplication

```ts
import { secp256k1 } from '@noble/curves/secp256k1';
const p = secp256k1.ProjectivePoint;
const points = [p.BASE, p.BASE.multiply(2n), p.BASE.multiply(4n), p.BASE.multiply(8n)];
p.msm(points, [3n, 5n, 7n, 11n]).equals(p.BASE.multiply(129n)); // 129*G
```

Pippenger algorithm is used underneath.
Multi-scalar-multiplication (MSM) is basically `(Pa + Qb + Rc + ...)`.
It's 10-30x faster vs naive addition for large amount of points.

#### All available imports

```typescript
Expand Down Expand Up @@ -460,6 +474,7 @@ interface ProjConstructor<T> extends GroupConstructor<ProjPointType<T>> {
fromAffine(p: AffinePoint<T>): ProjPointType<T>;
fromHex(hex: Hex): ProjPointType<T>;
fromPrivateKey(privateKey: PrivKey): ProjPointType<T>;
msm(points: ProjPointType[], scalars: bigint[]): ProjPointType<T>;
}
```

Expand Down Expand Up @@ -612,6 +627,7 @@ interface ExtPointConstructor extends GroupConstructor<ExtPointType> {
fromAffine(p: AffinePoint<bigint>): ExtPointType;
fromHex(hex: Hex): ExtPointType;
fromPrivateKey(privateKey: Hex): ExtPointType;
msm(points: ExtPointType[], scalars: bigint[]): ExtPointType;
}
```

Expand Down Expand Up @@ -813,6 +829,11 @@ utils.equalBytes(Uint8Array.from([0xde]), Uint8Array.from([0xde]));

The library has been independently audited:

- at version 1.6.0, in Sep 2024, by [cure53](https://cure53.de)
- PDFs: [in-repo](./audit/2024-09-01-cure53-audit-nbl4.pdf)
- [Changes since audit](https://github.com/paulmillr/noble-curves/compare/1.6.0..main)
- Scope: ed25519, ed448, their add-ons, bls12-381, bn254,
hash-to-curve, low-level primitives bls, tower, edwards, montgomery etc.
- at version 1.2.0, in Sep 2023, by [Kudelski Security](https://kudelskisecurity.com)
- PDFs: [offline](./audit/2023-09-kudelski-audit-starknet.pdf)
- [Changes since audit](https://github.com/paulmillr/noble-curves/compare/1.2.0..main)
Expand Down
53 changes: 53 additions & 0 deletions benchmark/msm_timings.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import { run, mark, compare, utils } from 'micro-bmark';
import { bls12_381 } from '../bls12-381.js';

run(async () => {
const g1 = bls12_381.G1.ProjectivePoint;
const bits = bls12_381.G1.CURVE.nBitLength - 1;
const ones = BigInt(`0b${'1'.repeat(bits)}`);

const onezero = BigInt(`0b${'10'.repeat(bits / 2)}`);
const one8zero = BigInt(`0b${'10000000'.repeat(bits / 8)}`);
// Single scalar
await compare('single point', 5000, {
zero: () => g1.msm([g1.BASE], [0n]),
one: () => g1.msm([g1.BASE], [1n]),
one0: () => g1.msm([g1.ZERO], [1n]),
small: () => g1.msm([g1.BASE], [123n]),
big: () => g1.msm([g1.BASE], [bls12_381.G1.CURVE.n - 1n]),
});
// Multiple
const points = [3n, 5n, 7n, 11n, 13n].map((i) => g1.BASE.multiply(i));
await compare('single point', 500, {
zero: () => g1.msm([g1.BASE, g1.BASE, g1.BASE, g1.BASE, g1.BASE], [0n, 0n, 0n, 0n, 0n]),
zero2: () => g1.msm([g1.ZERO, g1.ZERO, g1.ZERO, g1.ZERO, g1.ZERO], [0n, 0n, 0n, 0n, 0n]),
big: () =>
g1.msm(points, [
bls12_381.G1.CURVE.n - 1n,
bls12_381.G1.CURVE.n - 100n,
bls12_381.G1.CURVE.n - 200n,
bls12_381.G1.CURVE.n - 300n,
bls12_381.G1.CURVE.n - 400n,
]),
same_scalar: () => g1.msm(points, [ones, ones, ones, ones, ones]),
same_scalar2: () => g1.msm(points, [onezero, onezero, onezero, onezero, onezero]),
same_scalar3: () => g1.msm(points, [1n, 1n, 1n, 1n, 1n]),
same_scalar4: () => g1.msm(points, [one8zero, one8zero, one8zero, one8zero, one8zero]),
});
// Ok, and what about multiply itself?
await compare('basic multiply', 5000, {
'1*G1': () => g1.BASE.multiply(1n),
'(n-1)*G1': () => g1.BASE.multiply(bls12_381.G1.CURVE.n - 1n),
'ones*G1': () => g1.BASE.multiply(ones),
'onezero*G1': () => g1.BASE.multiply(onezero),
'one8zero*G1': () => g1.BASE.multiply(one8zero),
// Infinity
'1*Inf': () => g1.ZERO.multiply(1n),
'(n-1)*Inf': () => g1.ZERO.multiply(bls12_381.G1.CURVE.n - 1n),
'ones*Inf': () => g1.ZERO.multiply(ones),
'onezero*Inf': () => g1.ZERO.multiply(onezero),
'one8zero*Inf': () => g1.ZERO.multiply(one8zero),
});

utils.logMem();
});
3 changes: 3 additions & 0 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

58 changes: 57 additions & 1 deletion src/abstract/curve.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */
// Abelian group utilities
import { IField, validateField, nLength } from './modular.js';
import { validateObject } from './utils.js';
import { validateObject, bitLen } from './utils.js';
const _0n = BigInt(0);
const _1n = BigInt(1);

Expand Down Expand Up @@ -181,6 +181,62 @@ export function wNAF<T extends Group<T>>(c: GroupConstructor<T>, bits: number) {
};
}

/**
* Pippenger algorithm for multi-scalar multiplication (MSM).
* MSM is basically (Pa + Qb + Rc + ...).
* 30x faster vs naive addition on L=4096, 10x faster with precomputes.
* For N=254bit, L=1, it does: 1024 ADD + 254 DBL. For L=5: 1536 ADD + 254 DBL.
* Algorithmically constant-time (for same L), even when 1 point + scalar, or when scalar = 0.
* @param c Curve Point constructor
* @param field field over CURVE.N - important that it's not over CURVE.P
* @param points array of L curve points
* @param scalars array of L scalars (aka private keys / bigints)
*/
export function pippenger<T extends Group<T>>(
c: GroupConstructor<T>,
field: IField<bigint>,
points: T[],
scalars: bigint[]
): T {
// If we split scalars by some window (let's say 8 bits), every chunk will only
// take 256 buckets even if there are 4096 scalars, also re-uses double.
// TODO:
// - https://eprint.iacr.org/2024/750.pdf
// - https://tches.iacr.org/index.php/TCHES/article/view/10287
// 0 is accepted in scalars
if (!Array.isArray(points) || !Array.isArray(scalars) || scalars.length !== points.length)
throw new Error('arrays of scalars and points must have equal length');
scalars.forEach((s, i) => {
if (!field.isValid(s)) throw new Error(`wrong scalar at index ${i}`);
});
points.forEach((p, i) => {
if (!(p instanceof (c as any))) throw new Error(`wrong point at index ${i}`);
});
const wbits = bitLen(BigInt(points.length));
const windowSize = wbits > 12 ? wbits - 3 : wbits > 4 ? wbits - 2 : wbits ? 2 : 1; // in bits
const MASK = (1 << windowSize) - 1;
const buckets = new Array(MASK + 1).fill(c.ZERO); // +1 for zero array
const lastBits = Math.floor((field.BITS - 1) / windowSize) * windowSize;
let sum = c.ZERO;
for (let i = lastBits; i >= 0; i -= windowSize) {
buckets.fill(c.ZERO);
for (let j = 0; j < scalars.length; j++) {
const scalar = scalars[j];
const wbits = Number((scalar >> BigInt(i)) & BigInt(MASK));
buckets[wbits] = buckets[wbits].add(points[j]);
}
let resI = c.ZERO; // not using this will do small speed-up, but will lose ct
// Skip first bucket, because it is zero
for (let j = buckets.length - 1, sumI = c.ZERO; j > 0; j--) {
sumI = sumI.add(buckets[j]);
resI = resI.add(sumI);
}
sum = sum.add(resI);
if (i !== 0) for (let j = 0; j < windowSize; j++) sum = sum.double();
}
return sum as T;
}

// Generic BasicCurve interface: works even for polynomial fields (BLS): P, n, h would be ok.
// Though generator can be different (Fp2 / Fp6 for BLS).
export type BasicCurve<T> = {
Expand Down
18 changes: 16 additions & 2 deletions src/abstract/edwards.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
/*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */
// Twisted Edwards curve. The formula is: ax² + y² = 1 + dx²y²
import { AffinePoint, BasicCurve, Group, GroupConstructor, validateBasic, wNAF } from './curve.js';
import { mod } from './modular.js';
import {
AffinePoint,
BasicCurve,
Group,
GroupConstructor,
validateBasic,
wNAF,
pippenger,
} from './curve.js';
import { mod, Field } from './modular.js';
import * as ut from './utils.js';
import { ensureBytes, FHash, Hex, memoized, abool } from './utils.js';

Expand Down Expand Up @@ -70,6 +78,7 @@ export interface ExtPointConstructor extends GroupConstructor<ExtPointType> {
fromAffine(p: AffinePoint<bigint>): ExtPointType;
fromHex(hex: Hex): ExtPointType;
fromPrivateKey(privateKey: Hex): ExtPointType;
msm(points: ExtPointType[], scalars: bigint[]): ExtPointType;
}

/**
Expand Down Expand Up @@ -119,6 +128,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
} = CURVE;
const MASK = _2n << (BigInt(nByteLength * 8) - _1n);
const modP = Fp.create; // Function overrides
const Fn = Field(CURVE.n, CURVE.nBitLength);

// sqrt(u/v)
const uvRatio =
Expand Down Expand Up @@ -218,6 +228,10 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
const toInv = Fp.invertBatch(points.map((p) => p.ez));
return points.map((p, i) => p.toAffine(toInv[i])).map(Point.fromAffine);
}
// Multiscalar Multiplication
static msm(points: Point[], scalars: bigint[]) {
return pippenger(Point, Fn, points, scalars);
}

// "Private method", don't use it directly
_setWindowSize(windowSize: number) {
Expand Down
17 changes: 16 additions & 1 deletion src/abstract/weierstrass.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
/*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */
// Short Weierstrass curve. The formula is: y² = x³ + ax + b
import { AffinePoint, BasicCurve, Group, GroupConstructor, validateBasic, wNAF } from './curve.js';
import {
AffinePoint,
BasicCurve,
Group,
GroupConstructor,
validateBasic,
wNAF,
pippenger,
} from './curve.js';
import * as mod from './modular.js';
import * as ut from './utils.js';
import { CHash, Hex, PrivKey, ensureBytes, memoized, abool } from './utils.js';
Expand Down Expand Up @@ -85,6 +93,7 @@ export interface ProjConstructor<T> extends GroupConstructor<ProjPointType<T>> {
fromHex(hex: Hex): ProjPointType<T>;
fromPrivateKey(privateKey: PrivKey): ProjPointType<T>;
normalizeZ(points: ProjPointType<T>[]): ProjPointType<T>[];
msm(points: ProjPointType<T>[], scalars: bigint[]): ProjPointType<T>;
}

export type CurvePointsType<T> = BasicWCurve<T> & {
Expand Down Expand Up @@ -239,6 +248,7 @@ const _0n = BigInt(0), _1n = BigInt(1), _2n = BigInt(2), _3n = BigInt(3), _4n =
export function weierstrassPoints<T>(opts: CurvePointsType<T>): CurvePointsRes<T> {
const CURVE = validatePointOpts(opts);
const { Fp } = CURVE; // All curves has same field / group length as for now, but they can differ
const Fn = mod.Field(CURVE.n, CURVE.nBitLength);

const toBytes =
CURVE.toBytes ||
Expand Down Expand Up @@ -412,6 +422,11 @@ export function weierstrassPoints<T>(opts: CurvePointsType<T>): CurvePointsRes<T
return Point.BASE.multiply(normPrivateKeyToScalar(privateKey));
}

// Multiscalar Multiplication
static msm(points: Point[], scalars: bigint[]) {
return pippenger(Point, Fn, points, scalars);
}

// "Private method", don't use it directly
_setWindowSize(windowSize: number) {
wnaf.setWindowSize(this, windowSize);
Expand Down
27 changes: 27 additions & 0 deletions test/basic.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,33 @@ for (const name in CURVES) {
{ numRuns: NUM_RUNS }
)
);
should('MSM (basic)', () => {
equal(p.msm([p.BASE], [0n]), p.ZERO, '0*G');
equal(p.msm([], []), p.ZERO, 'empty');
equal(p.msm([p.ZERO], [123n]), p.ZERO, '123 * Infinity');
equal(p.msm([p.BASE], [123n]), p.BASE.multiply(123n), '123 * G');
const points = [p.BASE, p.BASE.multiply(2n), p.BASE.multiply(4n), p.BASE.multiply(8n)];
// 1*3 + 5*2 + 4*7 + 11*8 = 129
equal(p.msm(points, [3n, 5n, 7n, 11n]), p.BASE.multiply(129n), '129 * G');
});
should('MSM (rand)', () =>
fc.assert(
fc.property(fc.array(fc.tuple(FC_BIGINT, FC_BIGINT)), FC_BIGINT, (pairs) => {
let total = 0n;
const scalars = [];
const points = [];
for (const [ps, s] of pairs) {
points.push(p.BASE.multiply(ps));
scalars.push(s);
total += ps * s;
}
total = mod.mod(total, CURVE_ORDER);
const exp = total ? p.BASE.multiply(total) : p.ZERO;
equal(p.msm(points, scalars), exp, 'total');
}),
{ numRuns: NUM_RUNS }
)
);
});

for (const op of ['add', 'subtract']) {
Expand Down

0 comments on commit 7c80242

Please sign in to comment.